summary refs log tree commit diff
path: root/lua/99/providers.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/99/providers.lua')
-rw-r--r-- lua/99/providers.lua 147
1 files changed, 80 insertions, 67 deletions
diff --git a/lua/99/providers.lua b/lua/99/providers.lua
index 8875ca8..63e64aa 100644
--- a/lua/99/providers.lua
+++ b/lua/99/providers.lua
@@ -1,7 +1,7 @@
--- @class _99.Providers.Observer
--- @field on_stdout fun(line: string): nil
--- @field on_stderr fun(line: string): nil
---- @field on_complete fun(status: _99.Prompt.EndingState, res: string): nil
+--- @field on_complete fun(status: _99.Prompt.EndingState, res: string|table): nil
--- @field on_start fun(): nil
--- @param fn fun(...: any): nil
@@ -62,54 +62,6 @@ end
--- @field predictedPerTkMs number
--- @field predictedPerSec number
---- @param raw table
---- @return _A4.OpenAIResponse
-local function parse_openai_response(raw)
- local choice = raw.choices[1]
-
- --- @type _A4.OpenAIResponse
- local response = {
- completion = {
- finishReason = choice.finish_reason,
- index = choice.index,
- message = {
- role = choice.message.role,
- content = choice.message.content,
- },
- },
- metadata = {
- created = raw.created,
- model = raw.model,
- systemFingerPrint = raw.system_fingerprint,
- object = raw.object,
- usage = {
- completionTks = raw.usage.completion_tokens,
- promptTks = raw.usage.prompt_tokens,
- totalTks = raw.usage.total_tokens,
- promptTksDetails = {
- cachedTks = raw.usage.prompt_tokens_details.cached_tokens,
- },
- },
- },
- id = {
- id = raw.id,
- },
- timings = {
- cacheN = raw.timings.cache_n,
- promptN = raw.timings.prompt_n,
- promptMs = raw.timings.prompt_ms,
- promptPerTkMs = raw.timings.prompt_per_token_ms,
- promptPerSec = raw.timings.prompt_per_second,
- predictedN = raw.timings.predicted_n,
- predictedMs = raw.timings.predicted_ms,
- predictedPerTkMs = raw.timings.predicted_per_token_ms,
- predictedPerSec = raw.timings.predicted_per_second,
- },
- }
-
- return response
-end
-
--- @class _99.Providers.BaseProvider
--- @field _build_command fun(self: _99.Providers.BaseProvider, query: string, context: _99.Prompt): string[]
--- @field _get_provider_name fun(self: _99.Providers.BaseProvider): string
@@ -123,8 +75,8 @@ end
--- @param data string
--- @return string
-function BaseProvider:_clean_response(data)
- local _data, _ = data:gsub("```[a-z]*\n", ""):gsub(" ```", "")
+function BaseProvider._clean_response(data)
+ local _data, _ = data:gsub("```[a-z]*\n", ""):gsub("```", "")
return _data
end
@@ -164,9 +116,6 @@ function BaseProvider:_retrieve_response(context)
end
end
--- TODO: Remember that we are ditching the tmp_file; likely 90%
--- You can just grab the json data and dump it in logs honestly
--- files only make sense for agentic flows
--- @param query string
--- @param context _99.Prompt
--- @param observer _99.Providers.Observer
@@ -192,21 +141,25 @@ function BaseProvider:make_request(query, context, observer)
command,
{ text = true },
vim.schedule_wrap(function(obj)
- if obj.code ~= 0 then
- once_complete("failed", obj.stderr)
+ logger:debug("exit callback fired", "code", tostring(obj.code))
+ logger:debug("schedule_wrap", "stdout", tostring(obj.stdout))
+ logger:debug("schedule_wrap", "stderr", tostring(obj.stderr))
+
+ if context:is_cancelled() then
+ once_complete("cancelled", "")
return
end
- local data = parse_openai_response(vim.json.decode(obj.stdout))
- if self._get_provider_name == "BareMetalProvider" then
- data.completion.message.content =
- self:_clean_response(data.completion.message.content)
- context.req_response = data.completion.message.content
+ if obj.code ~= 0 then
+ once_complete("failed", obj.stderr or "")
+ return
end
- once_complete("success", data.completion.message.content)
+
+ once_complete("success", obj.stdout)
end)
)
+ logger:debug("proc spawned", "proc_id", tostring(proc))
context:_set_process(proc)
end
@@ -400,7 +353,8 @@ function GeminiCLIProvider._get_default_model()
end
--- @class BareMetalProvider : _99.Providers.BaseProvider
---- @field _clean_response fun(self: _99.Providers.BaseProvider, data : string): string
+--- @field _clean_response fun(data : string): string
+--- @field _parse_openai_response fun(raw: table): _A4.OpenAIResponse
local BareMetalProvider = setmetatable({}, { __index = BaseProvider })
--- @param query string
@@ -412,11 +366,22 @@ function BareMetalProvider._build_command(_, query, context)
"-s",
context.endpoint,
"-H",
- '"Content-Type: applications/json"',
+ "Content-Type: application/json",
"-d",
- '"{"messages":[{"role":"system","content":"You are a strict code completion backend for Neovim. Your input will be a code snippet, function signature, or a comment requesting code. CRITICAL DIRECTIONS: 1. Output ONLY valid, executable programming code. 2. Do NOT wrap your response in markdown code blocks. 3. Do NOT include any conversational filler, explanations, greetings, or sign-offs. 4. If you cannot fulfill the request, output nothing or a code comment explaining why."},{"role":"user","content":"'
- .. query
- .. '"}],"temperature":0.0,"stream":false}"',
+ vim.json.encode({
+ messages = {
+ {
+ role = "system",
+ content = "You are a strict code completion backend for Neovim. Your input will be a code snippet, function signature, or a comment requesting code. CRITICAL DIRECTIONS: 1. Output ONLY valid, executable programming code. 2. Do NOT wrap your response in markdown code blocks. 3. Do NOT include any conversational filler, explanations, greetings, or sign-offs. 4. If you cannot fulfill the request, output nothing or a code comment explaining why.",
+ },
+ {
+ role = "user",
+ content = query,
+ },
+ },
+ temperature = 0.0,
+ stream = false,
+ }),
}
end
@@ -436,6 +401,54 @@ function BareMetalProvider.fetch_models(callback)
}, nil)
end
+--- @param raw table
+--- @return _A4.OpenAIResponse
+function BareMetalProvider._parse_openai_response(raw)
+ local choice = raw.choices[1]
+
+ --- @type _A4.OpenAIResponse
+ local response = {
+ completion = {
+ finishReason = choice.finish_reason,
+ index = choice.index,
+ message = {
+ role = choice.message.role,
+ content = choice.message.content,
+ },
+ },
+ metadata = {
+ created = raw.created,
+ model = raw.model,
+ systemFingerPrint = raw.system_fingerprint,
+ object = raw.object,
+ usage = {
+ completionTks = raw.usage.completion_tokens,
+ promptTks = raw.usage.prompt_tokens,
+ totalTks = raw.usage.total_tokens,
+ promptTksDetails = {
+ cachedTks = raw.usage.prompt_tokens_details.cached_tokens,
+ },
+ },
+ },
+ id = {
+ id = raw.id,
+ },
+ timings = {
+ cacheN = raw.timings.cache_n,
+ promptN = raw.timings.prompt_n,
+ promptMs = raw.timings.prompt_ms,
+ promptPerTkMs = raw.timings.prompt_per_token_ms,
+ promptPerSec = raw.timings.prompt_per_second,
+ predictedN = raw.timings.predicted_n,
+ predictedMs = raw.timings.predicted_ms,
+ predictedPerTkMs = raw.timings.predicted_per_token_ms,
+ predictedPerSec = raw.timings.predicted_per_second,
+ },
+ }
+
+ return response
+end
+
return {
BaseProvider = BaseProvider,
OpenCodeProvider = OpenCodeProvider,