diff options
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/99/init.lua | 8 | ||||
| -rw-r--r-- | lua/99/providers.lua | 6 | ||||
| -rw-r--r-- | lua/99/state.lua | 2 | ||||
| -rw-r--r-- | lua/99/test/providers_spec.lua | 18 |
4 files changed, 34 insertions, 0 deletions
diff --git a/lua/99/init.lua b/lua/99/init.lua index 0364c6c..884250f 100644 --- a/lua/99/init.lua +++ b/lua/99/init.lua @@ -54,6 +54,7 @@ end --- @field in_flight_options? _99.StatusWindow.Opts --- @field md_files? string[] --- @field provider? _99.Providers.BaseProvider +--- @field provider_extra_args? string[] --- @field display_errors? boolean --- @field auto_add_skills? boolean --- @field completion? _99.Completion @@ -438,6 +439,13 @@ function _99.setup(opts) end end + if opts.provider_extra_args then + assert( + type(opts.provider_extra_args) == "table", + "opts.provider_extra_args must be a table" + ) + end + if opts.md_files then assert(type(opts.md_files) == "table", "opts.md_files is not a table") for _, md in ipairs(opts.md_files) do diff --git a/lua/99/providers.lua b/lua/99/providers.lua index 5eab070..1ea9cda 100644 --- a/lua/99/providers.lua +++ b/lua/99/providers.lua @@ -71,6 +71,12 @@ function BaseProvider:make_request(query, context, observer) ) local command = self:_build_command(query, context) + local extra_args = context._99 and context._99.provider_extra_args or {} + if #extra_args > 0 then + local query_arg = table.remove(command) + vim.list_extend(command, extra_args) + table.insert(command, query_arg) + end logger:debug("make_request", "command", command) local proc = vim.system( diff --git a/lua/99/state.lua b/lua/99/state.lua index a67acee..b4ac692 100644 --- a/lua/99/state.lua +++ b/lua/99/state.lua @@ -30,6 +30,7 @@ end --- @field ai_stdout_rows number --- @field display_errors boolean --- @field provider_override _99.Providers.BaseProvider? +--- @field provider_extra_args string[] --- @field rules _99.Agents.Rules --- @field tracking _99.State.Tracking --- @field __tmp_dir string | nil @@ -73,6 +74,7 @@ function State.new(opts) local _99_state = setmetatable(props, State) --[[@as _99.State]] _99_state.provider_override = opts.provider + _99_state.provider_extra_args = opts.provider_extra_args or {} _99_state.completion = opts.completion or default_completion() _99_state.completion.custom_rules = _99_state.completion.custom_rules or {} _99_state.completion.files = _99_state.completion.files or {} diff --git a/lua/99/test/providers_spec.lua b/lua/99/test/providers_spec.lua index 5d0436c..803cd05 100644 --- a/lua/99/test/providers_spec.lua +++ b/lua/99/test/providers_spec.lua @@ -152,6 +152,24 @@ describe("providers", function() end) end) + describe("provider_extra_args", function() + it("stores provider_extra_args on state", function() + local _99 = require("99") + _99.setup({ + provider_extra_args = { "--no-session-persistence" }, + }) + local state = _99.__get_state() + eq({ "--no-session-persistence" }, state.provider_extra_args) + end) + + it("defaults provider_extra_args to empty table", function() + local _99 = require("99") + _99.setup({}) + local state = _99.__get_state() + eq({}, state.provider_extra_args) + end) + end) + describe("BaseProvider", function() it("all providers have make_request", function() eq("function", type(Providers.OpenCodeProvider.make_request)) |
