diff options
| author | theprimeagain <the.primeagen@gmail.com> | 2026-02-21 09:28:39 -0700 |
|---|---|---|
| committer | theprimeagain <the.primeagen@gmail.com> | 2026-02-21 11:44:12 -0700 |
| commit | 3787c3dc34a1a9b818b3e71afa02823f5bec96c3 (patch) | |
| tree | 87068473862e70b6ad749a2d9782b069fb7d4e70 | |
| parent | f03187fb3c28a33ae85f587c56d73fd4d56f2a6a (diff) | |
| download | a4-3787c3dc34a1a9b818b3e71afa02823f5bec96c3.tar.xz a4-3787c3dc34a1a9b818b3e71afa02823f5bec96c3.zip | |
Refactoring of context and request entry into a single thing prompt
| -rw-r--r-- | AGENTS.md | 4 | ||||
| -rw-r--r-- | TODO.md | 29 | ||||
| -rw-r--r-- | lua/99/editor/treesitter.lua | 6 | ||||
| -rw-r--r-- | lua/99/extensions/work/worker.lua | 87 | ||||
| -rw-r--r-- | lua/99/geo.lua | 9 | ||||
| -rw-r--r-- | lua/99/init.lua | 196 | ||||
| -rw-r--r-- | lua/99/ops/clean-up.lua | 2 | ||||
| -rw-r--r-- | lua/99/ops/make-prompt.lua | 2 | ||||
| -rw-r--r-- | lua/99/ops/over-range.lua | 11 | ||||
| -rw-r--r-- | lua/99/ops/search.lua | 23 | ||||
| -rw-r--r-- | lua/99/ops/tutorial.lua | 26 | ||||
| -rw-r--r-- | lua/99/prompt-settings.lua | 13 | ||||
| -rw-r--r-- | lua/99/prompt.lua | 415 | ||||
| -rw-r--r-- | lua/99/providers.lua | 63 | ||||
| -rw-r--r-- | lua/99/request-context.lua | 181 | ||||
| -rw-r--r-- | lua/99/request/init.lua | 138 | ||||
| -rw-r--r-- | lua/99/test/providers_spec.lua | 8 | ||||
| -rw-r--r-- | lua/99/test/request_spec.lua | 16 | ||||
| -rw-r--r-- | lua/99/test/test_utils.lua | 16 | ||||
| -rw-r--r-- | lua/99/test/visual_spec.lua | 28 | ||||
| -rw-r--r-- | lua/99/utils.lua | 27 | ||||
| -rw-r--r-- | scripts/tests/minimal.vim | 5 |
22 files changed, 712 insertions, 593 deletions
@@ -1,3 +1,7 @@ * always use neovim provided functions * this is not a standard lua project. package resolution and all things related to lua and std should be ignored in favor of neovim and its utitlites. +## Testing +* make lua_test +* make pr_ready + @@ -0,0 +1,29 @@ +* Prompt should generate its prompt via prompt() instead of passing it into the provider as a string... +* Search Items should be editable. That way i can mark them off as finished + * use capture input style to "mark" them as done. [x] as done, or delete line + * this should mean that when we revive the work menu list, it reflects the new reality +* Search item navigation. We should just be able to next("search") to navigate the searches + * tutorials, searches, and visuals should all have their own history + * clean history should be on a vertical as well +* Vibe Work + * takes the search results and asks the AI to implement those changes. + * this should use the new "vibe" interface i want to make + * something i have ran into, maybe its useful, but being able to do the following + * search -> partial select -> vibe +* vibe interface + * makes changes, and then describes each edit in a tmp file such that it can be loaded into memory and transfered to quickfixlist + * be able to have a diff view? live view toggle? +* state of state + * maybe this needs to be persisted as json in a tmp file such that we can restore it upon opening. I could see this being super useful +* some sort of interface that i can peruse the types of requests made + * filter by type + * display all + * enter opens up the request + * delete removes the request from history +* search qfix notes should be added as marks + * there will be a need for smarter mark management. +* stop all requests do not seem to stop active requests... +* add an add_data method to context in which when you set the data it: + * asserts if you included a type + * initialized with the proper type + * adds the fields one at a time diff --git a/lua/99/editor/treesitter.lua b/lua/99/editor/treesitter.lua index b0db2e9..5cb9b1c 100644 --- a/lua/99/editor/treesitter.lua +++ b/lua/99/editor/treesitter.lua @@ -31,7 +31,7 @@ local function tree_root(buffer, lang) return tree:root() end ---- @param context _99.RequestContext +--- @param context _99.Prompt --- @param cursor _99.Point --- @return _99.treesitter.TSNode | nil function M.fn_call(context, cursor) @@ -103,7 +103,7 @@ end --- @param ts_node _99.treesitter.TSNode ---@param cursor _99.Point ----@param context _99.RequestContext +---@param context _99.Prompt ---@return _99.treesitter.Function function Function.from_ts_node(ts_node, cursor, context) local ok, query = @@ -138,7 +138,7 @@ function Function.from_ts_node(ts_node, cursor, context) return setmetatable(func, Function) end ---- @param context _99.RequestContext +--- @param context _99.Prompt --- @param cursor _99.Point --- @return _99.treesitter.Function? function M.containing_function(context, cursor) diff --git a/lua/99/extensions/work/worker.lua b/lua/99/extensions/work/worker.lua index 5cce21c..780b2f2 100644 --- a/lua/99/extensions/work/worker.lua +++ b/lua/99/extensions/work/worker.lua @@ -1,4 +1,5 @@ local Window = require("99.window") +local utils = require("99.utils") --- @class _99.Extension.Worker local M = {} @@ -6,6 +7,57 @@ local M = {} --- @class _99.WorkOpts --- @field description string | nil +--- @return string +local function get_work_item_file() + local _99 = require("99") + local state = _99.__get_state() + local tmp = state:tmp_dir() + return utils.named_tmp_file(tmp, "work-item") +end + +--- @return string | nil +local function read_work_item() + local ok, file = pcall(io.open, get_work_item_file(), "r") + if not ok or not file then + return nil + end + --- @type string + local contents + ok, contents = pcall(file.read, file, "*a") + pcall(file.close, file) + + if not ok then + return nil + end + return contents +end + +--- @param success boolean +---@param result string +local function set_work_item_cb(success, result) + if not success then + return + end + M.current_work_item = result + + local file = io.open(get_work_item_file(), "w") + if file then + file:write(result) + file:close() + else + error("unable to save work item") + end +end + +function M.updated_work() + local work = M.current_work_item + or "Put in the description of the work you want to complete" + Window.capture_input(" Work ", { + cb = set_work_item_cb, + content = vim.split(work, "\n"), + }) +end + --- @param opts _99.WorkOpts | nil function M.set_work(opts) opts = opts or {} @@ -14,13 +66,7 @@ function M.set_work(opts) M.current_work_item = description else Window.capture_input(" Work ", { - cb = function(success, result) - if not success then - return - end - M.current_work_item = result - end, - + cb = set_work_item_cb, content = { "Put in the description of the work you want to complete" }, }) end @@ -36,13 +82,29 @@ function M.craft_prompt(worker) return string.format( [[ <YourGoal> -You are to take the current git diff and git diff --staged and figure out what is +<OrderedSteps> +<Step> +Inspect and understand all changed code +* git diff +* git diff --staged +* commits that have not been pushed to remote +</Step> + +<Step> +Take the current pending and commited changes and figure out what is left to change to complete the work item. The work item is described in <Description> -Carefully review everything in git diff and git diff --staged and <Description> before you respond. +Carefully review all the changes and <Description> before you respond. respond with proper Search Format described in <Rule> and an example in <Output> If you see bugs, also report those +</Step> + +<Step> +if there are steps to test the project. run the tests and add to the list the failures +and how to fix them +</Step> +</OrderedSteps> </YourGoal> <Description> %s @@ -53,11 +115,16 @@ If you see bugs, also report those end function M.work() + local _99 = require("99") + if M.current_work_item == nil then + M.current_work_item = read_work_item() + end + assert( M.current_work_item, 'you must call "set_work" and set your current work item before calling this' ) - local _99 = require("99") + M.last_work_search = _99.search({ additional_prompt = M.craft_prompt(M), }) diff --git a/lua/99/geo.lua b/lua/99/geo.lua index 33e0cb1..48e60ad 100644 --- a/lua/99/geo.lua +++ b/lua/99/geo.lua @@ -204,6 +204,10 @@ function Point.from_mark(mark) }, Point) end +function Point.zero() + return Point.from_0_based(0, 0) +end + --- @class _99.Range --- @field start _99.Point --- @field end_ _99.Point @@ -353,6 +357,11 @@ function Range:to_string() ) end +--- @return _99.Range +function Range.zero() + return Range:new(0, Point.zero(), Point.zero()) +end + return { Point = Point, Range = Range, diff --git a/lua/99/init.lua b/lua/99/init.lua index a512ab6..7356d7b 100644 --- a/lua/99/init.lua +++ b/lua/99/init.lua @@ -3,15 +3,12 @@ local Level = require("99.logger.level") local ops = require("99.ops") local Languages = require("99.language") local Window = require("99.window") -local get_id = require("99.id") -local RequestContext = require("99.request-context") +local Prompt = require("99.prompt") local geo = require("99.geo") local Range = geo.Range -local Point = geo.Point local Extensions = require("99.extensions") local Agents = require("99.extensions.agents") local Providers = require("99.providers") -local time = require("99.time") local Throbber = require("99.ops.throbber") ---@param path_or_rule string | _99.Agents.Rule @@ -41,25 +38,6 @@ local function process_opts(opts) return opts end ---- @alias _99.Cleanup fun(): nil - ---- @class _99.RequestEntry.Data.Search ---- @field type "search" ---- @field qfix_items _99.Search.Result[] - ---- @class _99.RequestEntry.Data.Visual ---- @field type "visual" - --- luacheck: ignore ---- @alias _99.RequestEntry.Data _99.RequestEntry.Data.Search | _99.RequestEntry.Data.Tutorial | _99.RequestEntry.Data.Visual - ---- @class _99.RequestEntry ---- @field context _99.RequestContext ---- @field status _99.Request.State ---- @field point _99.Point ---- @field started_at number ---- @field operation_data _99.RequestEntry.Data | nil - --- @class _99.StateProps --- @field model string --- @field md_files string[] @@ -71,9 +49,7 @@ end --- @field auto_add_skills boolean --- @field provider_override _99.Providers.BaseProvider? --- @field __view_log_idx number ---- @field __request_history _99.RequestEntry[] ---- @field __request_by_id table<number, _99.RequestEntry> ---- @field tmp_dir string | nil +--- @field __tmp_dir string | nil --- @return _99.StateProps local function create_99_state() @@ -128,10 +104,10 @@ end --- @field auto_add_skills boolean --- @field rules _99.Agents.Rules --- @field __view_log_idx number ---- @field __request_history _99.RequestEntry[] ---- @field __request_by_id table<number, _99.RequestEntry> +--- @field __request_history _99.Prompt[] +--- @field __request_by_id table<number, _99.Prompt> --- @field __active_marks _99.Mark[] ---- @field tmp_dir string | nil +--- @field __tmp_dir string | nil local _99_State = {} _99_State.__index = _99_State @@ -142,6 +118,15 @@ function _99_State.new() return setmetatable(props, _99_State) end +--- @return string +function _99_State:tmp_dir() + local tmp_dir = self.__tmp_dir or "./tmp" + if tmp_dir then + tmp_dir = vim.fn.expand(tmp_dir) + end + return tmp_dir +end + --- TODO: This is something to understand. I bet that this is going to need --- a lot of performance tuning. I am just reading every file, and this could --- take a decent amount of time if there are lots of rules. @@ -156,63 +141,25 @@ function _99_State:refresh_rules() Extensions.refresh(self) end ---- @param tutorial _99.RequestEntry.Data.Tutorial -function _99_State:open_tutorial(tutorial) end - ---- @param context _99.RequestContext ---- @return _99.RequestEntry -function _99_State:track_request(context) - assert( - context.operation, - "must have an operation defined to track the request" - ) - - local point = context.range and context.range.start or Point:from_cursor() - local entry = { - context = context, - status = "requesting", - point = point, - started_at = time.now(), - operation_data = nil, - } - table.insert(self.__request_history, entry) - self.__request_by_id[context.xid] = entry - return entry -end - ---- @param context _99.RequestContext ---- @param status _99.Request.ResponseState -function _99_State:finish_request(context, status) - local id = context.xid - local entry = self.__request_by_id[id] - if not entry then - return - end - - entry.status = status +--- @param tutorial _99.Prompt.Data.Tutorial +function _99_State:open_tutorial(tutorial) + -- TODO: this is a task item for when i start the work item "Tutorial Navigation" + _ = self + _ = tutorial end ---- @param context _99.RequestContext ----@param data _99.RequestEntry.Data -function _99_State:add_data(context, data) - local id = context.xid - local entry = self.__request_by_id[id] - if not entry then - return - end - local logger = Logger:set_id(id) - logger:assert( - entry.context.operation == data.type, - "the data type is not the same as the operation" - ) - entry.operation_data = data +--- @param context _99.Prompt +function _99_State:track_request(context) + assert(context:valid(), "context is not valid") + table.insert(self.__request_history, context) + self.__request_by_id[context.xid] = context end --- @return number function _99_State:previous_request_count() local count = 0 for _, entry in ipairs(self.__request_history) do - if entry.status ~= "requesting" then + if entry.state ~= "requesting" then count = count + 1 end end @@ -222,10 +169,10 @@ end function _99_State:clear_previous_requests() local keep = {} for _, entry in ipairs(self.__request_history) do - if entry.status == "requesting" then + if entry.state == "requesting" then table.insert(keep, entry) else - self.__request_by_id[entry.context.xid] = nil + self.__request_by_id[entry.xid] = nil end end self.__request_history = keep @@ -239,7 +186,7 @@ end function _99_State:active_request_count() local count = 0 for _, r in pairs(self.__request_history) do - if r.status == "requesting" then + if r.state == "requesting" then count = count + 1 end end @@ -247,11 +194,11 @@ function _99_State:active_request_count() end --- @param type "search" | "visual" | "tutorial" ---- @return _99.RequestEntry.Data +--- @return _99.Prompt.Data function _99_State:get_request_data_by_type(type) local out = {} for _, r in ipairs(self.__request_history) do - local data = r.operation_data + local data = r.data if data and data.type == type then table.insert(out, data) end @@ -279,9 +226,9 @@ local function set_selection_marks() ) end ---- @param cb fun(context: _99.RequestContext, o: _99.ops.Opts?): nil +--- @param cb fun(context: _99.Prompt, o: _99.ops.Opts?): nil --- @param name string ---- @param context _99.RequestContext +--- @param context _99.Prompt --- @param opts _99.ops.Opts --- @param capture_content string[] | nil local function capture_prompt(cb, name, context, opts, capture_content) @@ -316,17 +263,6 @@ local function capture_prompt(cb, name, context, opts, capture_content) }) end ---- @param operation_name string ---- @return _99.RequestContext -local function get_context(operation_name) - _99_state:refresh_rules() - local trace_id = get_id() - local context = RequestContext.from_current_buffer(_99_state, trace_id) - context.operation = operation_name - context.logger:debug("99 Request", "method", operation_name) - return context -end - function _99.info() local info = {} _99_state:refresh_rules() @@ -344,7 +280,7 @@ function _99.info() Window.display_centered_message(info) end ---- @param tutorials _99.RequestEntry.Data.Tutorial[] +--- @param tutorials _99.Prompt.Data.Tutorial[] --- @return string[] local function tutorial_to_string(tutorials) local out = {} @@ -359,39 +295,34 @@ end function _99.open_tutorial(xid, opts) opts = opts or { split_direction = "vertical" } if xid == nil then + --- @type _99.Prompt.Data.Tutorial[] local tutorials = _99_state:get_request_data_by_type("tutorial") if #tutorials == 0 then print("no tutorials available") + return elseif #tutorials == 1 then - local data = tutorials[1].operation_data + local data = tutorials[1] assert(data, "tutorial is malformed") Window.create_split(data.tutorial, data.buffer, opts) + return else - local context = get_context("tutorial-lookup") - capture_prompt(function(_, o) - local response = o.additional_prompt - local lines = vim.split(response, "\n") - for _, l in ipairs(lines) do - local id = tonumber(vim.split(l, ":")[1]) - if not id then - error( - "do not alter the tutoria lines, just delete the ones you dont want" - ) - end - local tut = _99_state.__request_by_id[id] - local data = tut and tut.operation_data - assert(data and data.type == "tutorial", "invalid tutorial selected") - Window.create_split(data.tutorial, data.buffer, opts) - end - end, "Select Tutorial", context, {}, tutorial_to_string(tutorials)) + --- TODO: Complete this task when i work through tutorials + error([[not implemented. right now tutorials are not sccrollable. +This is a later change required. I want a next/prev tutorial navigation +much like qfix list. then i to have a capture input style window where you +can press enter +]]) end return end - local tutorial = _99_state.__request_by_id[xid] - local data = tutorial and tutorial.operation_data - assert(data and data.type == "tutorial", "cannot open a non tutorial") - Window.create_split(data.tutorial, data.buffer, opts) + --- @type _99.Prompt | nil + local context = _99_state.__request_by_id[xid] + assert(context, "could not find request") + assert(context.state == "success", "tutorial found had a non success state") + + local tutorial = context:tutorial_data() + Window.create_split(tutorial.tutorial, tutorial.buffer, opts) end --- @param path string @@ -405,7 +336,7 @@ end --- @return number function _99.search(opts) local o = process_opts(opts) --[[ @as _99.ops.SearchOpts ]] - local context = get_context("search") + local context = Prompt.search(_99_state) if o.additional_prompt then ops.search(context, o) else @@ -417,7 +348,7 @@ end --- @param opts _99.ops.Opts function _99.tutorial(opts) opts = process_opts(opts) - local context = get_context("tutorial") + local context = Prompt.tutorial(_99_state) if opts.additional_prompt then ops.tutorial(context, opts) else @@ -428,7 +359,7 @@ end --- @param opts _99.ops.Opts? function _99.visual(opts) opts = process_opts(opts) - local context = get_context("visual") + local context = Prompt.visual(_99_state) local function perform_range() set_selection_marks() local range = Range.from_visual_selection() @@ -480,9 +411,9 @@ end --- @field text string function _99.stop_all_requests() - for _, request in pairs(_99_state.__request_by_id) do - if request.status == "requesting" then - request.context:stop() + for _, c in pairs(_99_state.__request_by_id) do + if c.state == "requesting" then + c:stop() end end end @@ -496,14 +427,11 @@ end --- @param xid number | nil function _99.qfix_search_results(xid) - --- @type _99.RequestEntry + --- @type _99.Prompt local entry = _99_state.__request_by_id[xid] assert(entry, "qfix_search_results could not find id: " .. xid) - local data = entry.operation_data - assert(data, "there must be data associated with request entry") - assert(data.type == "search", "the operation_data must be type search") - + local data = entry:search_data() local items = data.qfix_items vim.fn.setqflist({}, "r", { title = "99 Search Results", items = items }) vim.cmd("copen") @@ -561,9 +489,9 @@ local function show_in_flight_requests() throb .. " requests(" .. tostring(count) .. ") " .. throb, } - for _, r in pairs(_99_state.__request_by_id) do - if r.status == "requesting" then - table.insert(lines, r.context.operation) + for _, c in pairs(_99_state.__request_by_id) do + if c.state == "requesting" then + table.insert(lines, c.operation) end end @@ -628,7 +556,7 @@ function _99.setup(opts) if opts.tmp_dir then assert(type(opts.tmp_dir) == "string", "opts.tmp_dir must be a string") end - _99_state.tmp_dir = opts.tmp_dir + _99_state.__tmp_dir = opts.tmp_dir _99_state.display_errors = opts.display_errors or false _99_state:refresh_rules() diff --git a/lua/99/ops/clean-up.lua b/lua/99/ops/clean-up.lua index cbf077a..56a7e11 100644 --- a/lua/99/ops/clean-up.lua +++ b/lua/99/ops/clean-up.lua @@ -1,6 +1,6 @@ local M = {} ---- @alias _99.Providers.on_complete fun(status: _99.Request.ResponseState, response: string): nil +--- @alias _99.Providers.on_complete fun(status: _99.Prompt.EndingState, response: string): nil --- @class _99.Providers.PartialObserver --- @field on_complete _99.Providers.on_complete --- @field on_stdout? fun(line: string): nil diff --git a/lua/99/ops/make-prompt.lua b/lua/99/ops/make-prompt.lua index 4a17687..091eade 100644 --- a/lua/99/ops/make-prompt.lua +++ b/lua/99/ops/make-prompt.lua @@ -1,7 +1,7 @@ local Completions = require("99.extensions.completions") local Agents = require("99.extensions.agents") ---- @param context _99.RequestContext +--- @param context _99.Prompt --- @param prompt string --- @param opts _99.ops.Opts --- @return string, _99.Reference[] diff --git a/lua/99/ops/over-range.lua b/lua/99/ops/over-range.lua index 7070857..1738270 100644 --- a/lua/99/ops/over-range.lua +++ b/lua/99/ops/over-range.lua @@ -1,4 +1,3 @@ -local Request = require("99.request") local RequestStatus = require("99.ops.request_status") local Mark = require("99.ops.marks") local geo = require("99.geo") @@ -11,18 +10,18 @@ local make_observer = CleanUp.make_observer local Range = geo.Range local Point = geo.Point ---- @param context _99.RequestContext +--- @param context _99.Prompt --- @param range _99.Range --- @param opts? _99.ops.Opts local function over_range(context, range, opts) opts = opts or {} local logger = context.logger:set_area("visual") - local request = Request.new(context) local top_mark = Mark.mark_above_range(range) local bottom_mark = Mark.mark_point(range.buffer, range.end_) context.marks.top_mark = top_mark context.marks.bottom_mark = bottom_mark + context.data.range = range logger:debug( "visual request start", @@ -44,19 +43,19 @@ local function over_range(context, range, opts) top_status:stop() bottom_status:stop() context:clear_marks() - request:cancel() + context:stop() end) local system_cmd = context._99.prompts.prompts.visual_selection(range) local prompt, refs = make_prompt(context, system_cmd, opts) - request:add_prompt_content(prompt) + context:add_prompt_content(prompt) context:add_references(refs) context:add_clean_up(clean_up) top_status:start() bottom_status:start() - request:start(make_observer(clean_up, { + context:start_request(make_observer(clean_up, { on_complete = function(status, response) if status == "cancelled" then logger:debug("request cancelled for visual selection, removing marks") diff --git a/lua/99/ops/search.lua b/lua/99/ops/search.lua index d6e5428..dda70d1 100644 --- a/lua/99/ops/search.lua +++ b/lua/99/ops/search.lua @@ -1,4 +1,3 @@ -local Request = require("99.request") local make_prompt = require("99.ops.make-prompt") local CleanUp = require("99.ops.clean-up") @@ -36,7 +35,7 @@ local function parse_line(line) } end ---- @param context _99.RequestContext +--- @param context _99.Prompt --- @param response string local function create_search_locations(context, response) local lines = vim.split(response, "\n") @@ -48,10 +47,10 @@ local function create_search_locations(context, response) table.insert(qf_list, res) end end - context._99:add_data(context, { + context.data = { type = "search", qfix_items = qf_list, - }) + } if #qf_list > 0 then require("99").qfix_search_results(context.xid) @@ -60,28 +59,32 @@ local function create_search_locations(context, response) end end ---- @param context _99.RequestContext +--- @param context _99.Prompt ---@param opts _99.ops.SearchOpts local function search(context, opts) opts = opts or {} local logger = context.logger:set_area("search") - local request = Request.new(context) - logger:debug("search", "with opts", opts.additional_prompt) local clean_up = make_clean_up(function() - request:cancel() + context:stop() end) local prompt, refs = make_prompt(context, context._99.prompts.prompts.semantic_search(), opts) - request:add_prompt_content(prompt) + context:add_prompt_content(prompt) context:add_references(refs) context:add_clean_up(clean_up) - request:start(make_observer(clean_up, function(status, response) + --- TODO: part of the context request clean up there needs to be a refactoring of + --- make observer... it really should just be within the context observer creation. + --- same with cleanup.. that should just be clean_ups from context, instead of a + --- once cleanup function wrapper. + --- + --- i think an interface, CleanUpI could be something that is worth it :) + context:start_request(make_observer(clean_up, function(status, response) if status == "cancelled" then logger:debug("request cancelled for search") elseif status == "failed" then diff --git a/lua/99/ops/tutorial.lua b/lua/99/ops/tutorial.lua index 320366f..9889801 100644 --- a/lua/99/ops/tutorial.lua +++ b/lua/99/ops/tutorial.lua @@ -1,4 +1,3 @@ -local Request = require("99.request") local CleanUp = require("99.ops.clean-up") local Window = require("99.window") local make_prompt = require("99.ops.make-prompt") @@ -6,21 +5,14 @@ local make_prompt = require("99.ops.make-prompt") local make_clean_up = CleanUp.make_clean_up local make_observer = CleanUp.make_observer ---- @class _99.RequestEntry.Data.Tutorial ---- @field type "tutorial" ---- @field buffer number ---- @field window number ---- @field xid number ---- @field tutorial string[] - ---- @param context _99.RequestContext +--- @param context _99.Prompt ---@param response string ----@return _99.RequestEntry.Data.Tutorial +---@return _99.Prompt.Data.Tutorial local function open_tutorial(context, response) local content = vim.split(response, "\n") local win = Window.create_split(content) - --- @type _99.RequestEntry.Data.Tutorial + --- @type _99.Prompt.Data.Tutorial local data = { type = "tutorial", buffer = win.buffer, @@ -28,11 +20,11 @@ local function open_tutorial(context, response) xid = context.xid, tutorial = content, } - context._99:add_data(context, data) + context.data = data return data end ---- @param context _99.RequestContext +--- @param context _99.Prompt ---@param opts _99.ops.Opts local function tutorial(context, opts) opts = opts or {} @@ -40,20 +32,18 @@ local function tutorial(context, opts) local logger = context.logger:set_area("tutorial") logger:debug("starting", "with opts", opts) - local request = Request.new(context) - local clean_up = make_clean_up(function() - request:cancel() + context:stop() end) local prompt, refs = make_prompt(context, context._99.prompts.prompts.tutorial(), opts) context:add_references(refs) - request:add_prompt_content(prompt) + context:add_prompt_content(prompt) context:add_clean_up(clean_up) - request:start(make_observer(clean_up, function(status, response) + context:start_request(make_observer(clean_up, function(status, response) vim.schedule(clean_up) if status == "cancelled" then logger:debug("cancelled") diff --git a/lua/99/prompt-settings.lua b/lua/99/prompt-settings.lua index 03f82a2..f079fc1 100644 --- a/lua/99/prompt-settings.lua +++ b/lua/99/prompt-settings.lua @@ -136,17 +136,14 @@ local prompt_settings = { ) end, - ---@param context _99.RequestContext + ---@param full_path string + ---@param range _99.Range ---@return string - get_file_location = function(context) - context.logger:assert( - context.range, - "get_file_location requires range specified" - ) + get_file_location = function(full_path, range) return string.format( "<Location><File>%s</File><Function>%s</Function></Location>", - context.full_path, - context.range:to_string() + full_path, + range:to_string() ) end, diff --git a/lua/99/prompt.lua b/lua/99/prompt.lua new file mode 100644 index 0000000..d4f5095 --- /dev/null +++ b/lua/99/prompt.lua @@ -0,0 +1,415 @@ +local BaseProvider = require("99.providers") +local Logger = require("99.logger.logger") +local utils = require("99.utils") +local random_file = utils.random_file +local copy = utils.copy +local get_id = require("99.id") +local Range = require("99.geo").Range +local Time = require("99.time") + +local filetype_map = { + typescriptreact = "typescript", +} + +-- luacheck: ignore +--- @alias _99.Prompt.Data _99.Prompt.Data.Search | _99.Prompt.Data.Tutorial | _99.Prompt.Data.Visual +--- @alias _99.Prompt.Operation "visual" | "tutorial" | "search" +--- @alias _99.Prompt.EndingState "failed" | "success" | "cancelled" +--- @alias _99.Prompt.State "ready" | "requesting" | _99.Prompt.EndingState +--- @alias _99.Prompt.Cleanup fun(): nil + +--- @class _99.Prompt.Data.Search +--- @field type "search" +--- @field qfix_items _99.Search.Result[] + +--- @class _99.Prompt.Data.Visual +--- @field type "visual" +--- @field buffer number +--- @field file_type string +--- @field range _99.Range + +--- @class _99.Prompt.Data.Tutorial +--- @field type "tutorial" +--- @field buffer number +--- @field window number +--- @field xid number TODO: we should probably get rid of this. The request pattern is not quite correct +--- @field tutorial string[] + +--- @class _99.Prompt +--- @field md_file_names string[] +--- @field model string +--- @field operation _99.Prompt.Operation +--- @field state _99.Prompt.State +--- @field full_path string +--- @field started_at number +--- @field data _99.Prompt.Data +--- @field agent_context string[] +--- @field tmp_file string +--- @field marks table<string, _99.Mark> +--- @field logger _99.Logger +--- @field xid number +--- @field clean_ups (fun(): nil)[] +--- @field _99 _99.State +---@diagnostic disable-next-line: undefined-doc-name +--- @field _proc vim.SystemObj? +local Prompt = {} +Prompt.__index = Prompt + +--- @type _99.Prompt[] +Prompt.__previous_contexts = {} + +--- @type table<number, _99.Prompt> +Prompt.__context_by_id = {} + +--- @param context _99.Prompt +--- @param _99 _99.State +local function set_defaults(context, _99) + local xid = get_id() + local full_path = vim.api.nvim_buf_get_name(0) + + context.state = "ready" + context._99 = _99 + context.clean_ups = {} + context.md_file_names = copy(_99.md_files) + context.model = _99.model + context.agent_context = {} + context.tmp_file = random_file(_99:tmp_dir()) + context.logger = Logger:set_id(xid) + context.xid = xid + context.full_path = full_path + context.marks = {} + context.started_at = Time.now() +end + +--- TODO: Work item for "TODO implementation" +function Prompt.todo(_99) + _ = _99 + assert(false, "not implemented") +end + +function Prompt.vibe(_99, opts) + _ = _99 + _ = opts + assert(false, "not implemented") +end + +--- @param _99 _99.State +--- @return _99.Prompt +function Prompt.visual(_99) + _99:refresh_rules() + + local file_type = vim.bo[0].ft + local buffer = vim.api.nvim_get_current_buf() + file_type = filetype_map[file_type] or file_type + + local mds = {} + for _, md in ipairs(_99.md_files) do + table.insert(mds, md) + end + + --- @type _99.Prompt + local context = setmetatable({}, Prompt) + set_defaults(context, _99) + context.operation = "visual" + context.data = { + type = "visual", + buffer = buffer, + file_type = file_type, + range = Range.zero(), + } + context.logger:debug("99 Request", "method", "visual") + + return context +end + +--- @param _99 _99.State +--- @return _99.Prompt +function Prompt.tutorial(_99) + _99:refresh_rules() + + --- @type _99.Prompt + local context = setmetatable({}, Prompt) + set_defaults(context, _99) + context.operation = "tutorial" + context.data = { + type = "tutorial", + xid = context.xid, -- TODO: i want to get rid of this when i implement rehydration of the data. + buffer = 0, + window = 0, + tutorial = {}, + } + context.logger:debug("99 Request", "method", "tutorial") + + return context +end + +--- @param _99 _99.State +--- @return _99.Prompt +function Prompt.search(_99) + _99:refresh_rules() + + --- @type _99.Prompt + local context = setmetatable({}, Prompt) + set_defaults(context, _99) + context.operation = "search" + context.data = { + type = "search", + qfix_items = {}, + } + context.logger:debug("99 Request", "method", "search") + + return context +end + +--- @param obs _99.Providers.Observer | nil +function Prompt:_observer(obs) + return { + on_start = function() + self.state = "requesting" + self._99:track_request(self) + + if obs then + obs.on_start() + end + end, + on_complete = function(status, res) + self.state = status + if obs then + obs.on_complete(status, res) + end + end, + on_stderr = function(line) + if obs then + obs.on_stderr(line) + end + end, + on_stdout = function(line) + if obs then + obs.on_stdout(line) + end + end, + } +end + +--- @return boolean +function Prompt:valid() + local t = self.data.type + return t == "visual" or t == "search" or t == "tutorial" +end + +--- @param observer _99.Providers.Observer? +function Prompt:start_request(observer) + local l = self.logger + l:assert( + self.state == "ready", + 'state is not "ready" when attempting to start a request' + ) + + local ok = self:finalize() + l:assert(ok, "context failed to finalize") + + --- TODO: create a prompt context class that can actually organize. + --- do not do this during the request context refactoring, but next + local prompt = table.concat(self.agent_context, "\n") + local obs = self:_observer(observer) + local provider = self._99.provider_override or BaseProvider.OpenCodeProvider + + self:save_prompt(prompt) + l:debug("start", "prompt", prompt) + + provider:make_request(prompt, self, obs) +end + +function Prompt:is_cancelled() + return self.state == "cancelled" +end + +---@diagnostic disable-next-line: undefined-doc-name +--- @param proc vim.SystemObj? +function Prompt:_set_process(proc) + self._proc = proc +end + +function Prompt:cancel() + if self:is_cancelled() then + return + end + + self.logger:debug("cancel") + self.state = "cancelled" + local proc = self._proc + ---@diagnostic disable-next-line: undefined-field + if proc and proc.pid then + self._proc = nil + pcall(function() + local sigterm = (vim.uv and vim.uv.constants and vim.uv.constants.SIGTERM) + or 15 + ---@diagnostic disable-next-line: undefined-field + proc:kill(sigterm) + end) + end +end + +--- @return _99.Prompt.Data.Visual +function Prompt:visual_data() + assert( + self.data.type == "visual", + "you cannot get visual data if its not type visual" + ) + return self.data --[[@as _99.Prompt.Data.Visual]] +end + +--- @return _99.Prompt.Data.Tutorial +function Prompt:tutorial_data() + assert( + self.data.type == "tutorial", + "you cannot get tutorial data if its not type tutorial" + ) + return self.data --[[@as _99.Prompt.Data.Tutorial]] +end + +--- @return _99.Prompt.Data.Search +function Prompt:search_data() + assert( + self.data.type == "search", + "you cannot get search data if its not type search" + ) + return self.data --[[@as _99.Prompt.Data.Search]] +end + +function Prompt:stop() + self:cancel() + for _, cb in ipairs(self.clean_ups) do + cb() + end +end + +--- @param clean_up fun(): nil +function Prompt:add_clean_up(clean_up) + table.insert(self.clean_ups, clean_up) +end + +--- @param md_file_name string +--- @return self +function Prompt:add_md_file_name(md_file_name) + table.insert(self.md_file_names, md_file_name) + return self +end + +--- @param content string +--- @return self +function Prompt:add_prompt_content(content) + table.insert(self.agent_context, content) + return self +end + +--- @param refs _99.Reference[] +function Prompt:add_references(refs) + for _, ref in ipairs(refs) do + self.logger:debug("adding reference to context") + table.insert(self.agent_context, ref.content) + end +end + +function Prompt:_read_md_files() + local cwd = vim.uv.cwd() + local dir = vim.fn.fnamemodify(self.full_path, ":h") + + while dir:find(cwd, 1, true) == 1 do + for _, md_file_name in ipairs(self.md_file_names) do + local md_path = dir .. "/" .. md_file_name + local file = io.open(md_path, "r") + if file then + local content = file:read("*a") + file:close() + self.logger:info( + "Context#adding md file to the context", + "md_path", + md_path + ) + table.insert(self.agent_context, content) + end + end + + if dir == cwd then + break + end + + dir = vim.fn.fnamemodify(dir, ":h") + end +end + +--- @return string[] +function Prompt:content() + return self.agent_context +end + +--- @return boolean +function Prompt:_ready_request_files() + local response_file = self.tmp_file + local prompt_file = self.tmp_file .. "-prompt" + + local dir = vim.fs.dirname(prompt_file) + + if dir and not vim.uv.fs_stat(dir) then + vim.fn.mkdir(dir, "p") + end + + local files = { prompt_file, response_file } + for _, f in ipairs(files) do + local file = io.open(f, "w") + if file then + file:write("") + file:close() + else + self.logger:error("unable to create prompt file") + return false + end + end + return true +end + +--- @param prompt string +function Prompt:save_prompt(prompt) + local prompt_file = self.tmp_file .. "-prompt" + local file = io.open(prompt_file, "w") + if file then + file:write(prompt) + file:close() + self.logger:debug("saved prompt to file", "path", prompt_file) + else + self.logger:error("failed to save prompt", "path", prompt_file) + end +end + +--- @return boolean, self +function Prompt:finalize() + if self:_ready_request_files() == false then + return false, self + end + self:_read_md_files() + + local ok, visual_data = pcall(self.visual_data, self) + if ok then + local f_loc = + self._99.prompts.get_file_location(self.full_path, visual_data.range) + table.insert(self.agent_context, f_loc) + table.insert( + self.agent_context, + self._99.prompts.get_range_text(visual_data.range) + ) + end + table.insert( + self.agent_context, + self._99.prompts.tmp_file_location(self.tmp_file) + ) + return true, self +end + +function Prompt:clear_marks() + for _, mark in pairs(self.marks) do + mark:delete() + end +end + +return Prompt diff --git a/lua/99/providers.lua b/lua/99/providers.lua index 3a2c3ec..24c3d28 100644 --- a/lua/99/providers.lua +++ b/lua/99/providers.lua @@ -1,7 +1,7 @@ --- @class _99.Providers.Observer --- @field on_stdout fun(line: string): nil --- @field on_stderr fun(line: string): nil ---- @field on_complete fun(status: _99.Request.ResponseState, res: string): nil +--- @field on_complete fun(status: _99.Prompt.EndingState, res: string): nil --- @field on_start fun(): nil --- @param fn fun(...: any): nil @@ -18,7 +18,7 @@ local function once(fn) end --- @class _99.Providers.BaseProvider ---- @field _build_command fun(self: _99.Providers.BaseProvider, query: string, request: _99.Request): string[] +--- @field _build_command fun(self: _99.Providers.BaseProvider, query: string, context: _99.Prompt): string[] --- @field _get_provider_name fun(self: _99.Providers.BaseProvider): string local BaseProvider = {} @@ -27,10 +27,10 @@ function BaseProvider.fetch_models(callback) callback(nil, "This provider does not support listing models") end ---- @param request _99.Request -function BaseProvider:_retrieve_response(request) - local logger = request.logger:set_area(self:_get_provider_name()) - local tmp = request.context.tmp_file +--- @param context _99.Prompt +function BaseProvider:_retrieve_response(context) + local logger = context.logger:set_area(self:_get_provider_name()) + local tmp = context.tmp_file local success, result = pcall(function() return vim.fn.readfile(tmp) end) @@ -53,24 +53,23 @@ function BaseProvider:_retrieve_response(request) end --- @param query string ---- @param request _99.Request +--- @param context _99.Prompt --- @param observer _99.Providers.Observer -function BaseProvider:make_request(query, request, observer) +function BaseProvider:make_request(query, context, observer) observer.on_start() - local logger = request.logger:set_area(self:_get_provider_name()) - logger:debug("make_request", "tmp_file", request.context.tmp_file) + local logger = context.logger:set_area(self:_get_provider_name()) + logger:debug("make_request", "tmp_file", context.tmp_file) local once_complete = once( --- @param status "success" | "failed" | "cancelled" ---@param text string function(status, text) - request.state = status observer.on_complete(status, text) end ) - local command = self:_build_command(query, request) + local command = self:_build_command(query, context) logger:debug("make_request", "command", command) local proc = vim.system( @@ -79,7 +78,7 @@ function BaseProvider:make_request(query, request, observer) text = true, stdout = vim.schedule_wrap(function(err, data) logger:debug("stdout", "data", data) - if request:is_cancelled() then + if context:is_cancelled() then once_complete("cancelled", "") return end @@ -92,7 +91,7 @@ function BaseProvider:make_request(query, request, observer) end), stderr = vim.schedule_wrap(function(err, data) logger:debug("stderr", "data", data) - if request:is_cancelled() then + if context:is_cancelled() then once_complete("cancelled", "") return end @@ -105,7 +104,7 @@ function BaseProvider:make_request(query, request, observer) end), }, vim.schedule_wrap(function(obj) - if request:is_cancelled() then + if context:is_cancelled() then once_complete("cancelled", "") logger:debug("on_complete: request has been cancelled") return @@ -121,7 +120,7 @@ function BaseProvider:make_request(query, request, observer) ) else vim.schedule(function() - local ok, res = self:_retrieve_response(request) + local ok, res = self:_retrieve_response(context) if ok then once_complete("success", res) else @@ -135,23 +134,23 @@ function BaseProvider:make_request(query, request, observer) end) ) - request:_set_process(proc) + context:_set_process(proc) end --- @class OpenCodeProvider : _99.Providers.BaseProvider local OpenCodeProvider = setmetatable({}, { __index = BaseProvider }) --- @param query string ---- @param request _99.Request +--- @param context _99.Prompt --- @return string[] -function OpenCodeProvider._build_command(_, query, request) +function OpenCodeProvider._build_command(_, query, context) return { "opencode", "run", "--agent", "build", "-m", - request.context.model, + context.model, query, } end @@ -183,14 +182,14 @@ end local ClaudeCodeProvider = setmetatable({}, { __index = BaseProvider }) --- @param query string ---- @param request _99.Request +--- @param context _99.Prompt --- @return string[] -function ClaudeCodeProvider._build_command(_, query, request) +function ClaudeCodeProvider._build_command(_, query, context) return { "claude", "--dangerously-skip-permissions", "--model", - request.context.model, + context.model, "--print", query, } @@ -228,10 +227,10 @@ end local CursorAgentProvider = setmetatable({}, { __index = BaseProvider }) --- @param query string ---- @param request _99.Request +--- @param context _99.Prompt --- @return string[] -function CursorAgentProvider._build_command(_, query, request) - return { "cursor-agent", "--model", request.context.model, "--print", query } +function CursorAgentProvider._build_command(_, query, context) + return { "cursor-agent", "--model", context.model, "--print", query } end --- @return string @@ -269,15 +268,15 @@ end local KiroProvider = setmetatable({}, { __index = BaseProvider }) --- @param query string ---- @param request _99.Request +--- @param context _99.Prompt --- @return string[] -function KiroProvider._build_command(_, query, request) +function KiroProvider._build_command(_, query, context) return { "kiro-cli", "chat", "--no-interactive", "--model", - request.context.model, + context.model, "--trust-all-tools", query, } @@ -297,9 +296,9 @@ end local GeminiCLIProvider = setmetatable({}, { __index = BaseProvider }) --- @param query string ---- @param request _99.Request +--- @param context _99.Prompt --- @return string[] -function GeminiCLIProvider._build_command(_, query, request) +function GeminiCLIProvider._build_command(_, query, context) return { "gemini", "--approval-mode", @@ -307,7 +306,7 @@ function GeminiCLIProvider._build_command(_, query, request) -- https://geminicli.com/docs/core/policy-engine/#default-policies "auto_edit", "--model", - request.context.model, + context.model, "--prompt", query, } diff --git a/lua/99/request-context.lua b/lua/99/request-context.lua deleted file mode 100644 index 48d4489..0000000 --- a/lua/99/request-context.lua +++ /dev/null @@ -1,181 +0,0 @@ -local Logger = require("99.logger.logger") -local utils = require("99.utils") -local random_file = utils.random_file - ---- @class _99.RequestContext ---- @field md_file_names string[] ---- @field ai_context string[] ---- @field model string ---- @field tmp_file string ---- @field full_path string ---- @field buffer number ---- @field file_type string ---- @field marks table<string, _99.Mark> ---- @field logger _99.Logger ---- @field xid number ---- @field range _99.Range? ---- @field operation string? ---- @field clean_ups (fun(): nil)[] ---- @field _99 _99.State -local RequestContext = {} -RequestContext.__index = RequestContext - ---- @param _99 _99.State ---- @param xid number ---- @return _99.RequestContext -function RequestContext.from_current_buffer(_99, xid) - local buffer = vim.api.nvim_get_current_buf() - local full_path = vim.api.nvim_buf_get_name(buffer) - local file_type = vim.bo[buffer].ft - - if file_type == "typescriptreact" then - file_type = "typescript" - end - - local mds = {} - for _, md in ipairs(_99.md_files) do - table.insert(mds, md) - end - - local tmp_dir = _99.tmp_dir - if tmp_dir then - tmp_dir = vim.fn.expand(tmp_dir) - end - - return setmetatable({ - _99 = _99, - clean_ups = {}, - md_file_names = mds, - ai_context = {}, - tmp_file = random_file(tmp_dir), - buffer = buffer, - full_path = full_path, - file_type = file_type, - logger = Logger:set_id(xid), - xid = xid, - model = _99.model, - marks = {}, - }, RequestContext) -end - -function RequestContext:stop() - for _, cb in ipairs(self.clean_ups) do - cb() - end -end - ---- @param clean_up fun(): nil -function RequestContext:add_clean_up(clean_up) - table.insert(self.clean_ups, clean_up) -end - ---- @param md_file_name string ---- @return self -function RequestContext:add_md_file_name(md_file_name) - table.insert(self.md_file_names, md_file_name) - return self -end - ---- @param refs _99.Reference[] -function RequestContext:add_references(refs) - for _, ref in ipairs(refs) do - self.logger:debug("adding reference to context") - table.insert(self.ai_context, ref.content) - end -end - -function RequestContext:_read_md_files() - local cwd = vim.uv.cwd() - local dir = vim.fn.fnamemodify(self.full_path, ":h") - - while dir:find(cwd, 1, true) == 1 do - for _, md_file_name in ipairs(self.md_file_names) do - local md_path = dir .. "/" .. md_file_name - local file = io.open(md_path, "r") - if file then - local content = file:read("*a") - file:close() - self.logger:info( - "Context#adding md file to the context", - "md_path", - md_path - ) - table.insert(self.ai_context, content) - end - end - - if dir == cwd then - break - end - - dir = vim.fn.fnamemodify(dir, ":h") - end -end - ---- @return string[] -function RequestContext:content() - return self.ai_context -end - ---- @return boolean -function RequestContext:_ready_request_files() - local response_file = self.tmp_file - local prompt_file = self.tmp_file .. "-prompt" - - local dir = vim.fs.dirname(prompt_file) - - if dir and not vim.uv.fs_stat(dir) then - vim.fn.mkdir(dir, "p") - end - - local files = { prompt_file, response_file } - for _, f in ipairs(files) do - local file = io.open(f, "w") - if file then - file:write("") - file:close() - else - self.logger:error("unable to create prompt file") - return false - end - end - return true -end - ---- @param prompt string -function RequestContext:save_prompt(prompt) - local prompt_file = self.tmp_file .. "-prompt" - local file = io.open(prompt_file, "w") - if file then - file:write(prompt) - file:close() - self.logger:debug("saved prompt to file", "path", prompt_file) - else - self.logger:error("failed to save prompt", "path", prompt_file) - end -end - ---- @return boolean, self -function RequestContext:finalize() - if self:_ready_request_files() == false then - return false, self - end - self:_read_md_files() - if self.range then - table.insert(self.ai_context, self._99.prompts.get_file_location(self)) - table.insert(self.ai_context, self._99.prompts.get_range_text(self.range)) - end - table.insert( - self.ai_context, - self._99.prompts.tmp_file_location(self.tmp_file) - ) - return true, self -end - -function RequestContext:clear_marks() - for _, mark in pairs(self.marks) do - mark:delete() - end -end - -return RequestContext diff --git a/lua/99/request/init.lua b/lua/99/request/init.lua deleted file mode 100644 index 8080211..0000000 --- a/lua/99/request/init.lua +++ /dev/null @@ -1,138 +0,0 @@ ---- @alias _99.Request.ResponseState "failed" | "success" | "cancelled" ---- @alias _99.Request.State "ready" | "requesting" | _99.Request.ResponseState - -local Providers = require("99.providers") - ---- @class _99.Request.Opts ---- @field model string ---- @field tmp_file string ---- @field provider _99.Providers.BaseProvider? ---- @field xid number - ---- @class _99.Request.Config ---- @field model string ---- @field tmp_file string ---- @field provider _99.Providers.BaseProvider ---- @field xid number - ---- @class _99.Request ---- @field context _99.RequestContext ---- @field state _99.Request.State ---- @field provider _99.Providers.BaseProvider ---- @field logger _99.Logger ---- @field _content string[] ----@diagnostic disable-next-line: undefined-doc-name ---- @field _proc vim.SystemObj? -local Request = {} -Request.__index = Request - ---- @param context _99.RequestContext ---- @return _99.Request -function Request.new(context) - local provider = context._99.provider_override or Providers.OpenCodeProvider - return setmetatable({ - context = context, - provider = provider, - state = "ready", - logger = context.logger:set_area("Request"), - _content = {}, - _proc = nil, - }, Request) -end - ----@diagnostic disable-next-line: undefined-doc-name ---- @param proc vim.SystemObj? -function Request:_set_process(proc) - self._proc = proc -end - -function Request:cancel() - if self.state == "success" or self.state == "failed" then - return - end - - self.logger:debug("cancel") - self.state = "cancelled" - local proc = self._proc - ---@diagnostic disable-next-line: undefined-field - if proc and proc.pid then - self._proc = nil - pcall(function() - local sigterm = (vim.uv and vim.uv.constants and vim.uv.constants.SIGTERM) - or 15 - ---@diagnostic disable-next-line: undefined-field - proc:kill(sigterm) - end) - end -end - -function Request:is_cancelled() - return self.state == "cancelled" -end - ---- @param content string ---- @return self -function Request:add_prompt_content(content) - table.insert(self._content, content) - return self -end - ---- @param r _99.Request ---- @param obs _99.Providers.Observer | nil -local function observer_from_request(r, obs) - local context = r.context - return { - on_start = function() - r.state = "requesting" - context._99:track_request(context) - if obs then - obs.on_start() - end - end, - on_complete = function(status, res) - r.state = status - context._99:finish_request(context, status) - if obs then - obs.on_complete(status, res) - end - end, - on_stderr = function(line) - if obs then - obs.on_stderr(line) - end - end, - on_stdout = function(line) - if obs then - obs.on_stdout(line) - end - end, - } -end - ---- @param observer _99.Providers.Observer? -function Request:start(observer) - self.logger:assert( - self.state == "ready", - "request is not in state ready when attempting to start a request" - ) - local ok = self.context:finalize() - self.logger:assert( - ok, - "request has failed due to context finalization: check logs for more details" - ) - - for _, content in ipairs(self.context.ai_context) do - self:add_prompt_content(content) - end - local prompt = table.concat(self._content, "\n") - - self.context:save_prompt(prompt) - self.logger:debug("start", "prompt", prompt) - self.provider:make_request( - prompt, - self, - observer_from_request(self, observer) - ) -end - -return Request diff --git a/lua/99/test/providers_spec.lua b/lua/99/test/providers_spec.lua index 9c72d0e..5d0436c 100644 --- a/lua/99/test/providers_spec.lua +++ b/lua/99/test/providers_spec.lua @@ -5,7 +5,7 @@ local Providers = require("99.providers") describe("providers", function() describe("OpenCodeProvider", function() it("builds correct command with model", function() - local request = { context = { model = "anthropic/claude-sonnet-4-5" } } + local request = { model = "anthropic/claude-sonnet-4-5" } local cmd = Providers.OpenCodeProvider._build_command(nil, "test query", request) eq({ @@ -29,7 +29,7 @@ describe("providers", function() describe("ClaudeCodeProvider", function() it("builds correct command with model", function() - local request = { context = { model = "anthropic/claude-sonnet-4-5" } } + local request = { model = "anthropic/claude-sonnet-4-5" } local cmd = Providers.ClaudeCodeProvider._build_command(nil, "test query", request) eq({ @@ -49,7 +49,7 @@ describe("providers", function() describe("CursorAgentProvider", function() it("builds correct command with model", function() - local request = { context = { model = "anthropic/claude-sonnet-4-5" } } + local request = { model = "anthropic/claude-sonnet-4-5" } local cmd = Providers.CursorAgentProvider._build_command(nil, "test query", request) eq({ @@ -68,7 +68,7 @@ describe("providers", function() describe("GeminiCLIProvider", function() it("builds correct command with model", function() - local request = { context = { model = "gemini-2.5-pro" } } + local request = { model = "gemini-2.5-pro" } local cmd = Providers.GeminiCLIProvider._build_command(nil, "test query", request) eq({ diff --git a/lua/99/test/request_spec.lua b/lua/99/test/request_spec.lua index 3c58ecd..0565b51 100644 --- a/lua/99/test/request_spec.lua +++ b/lua/99/test/request_spec.lua @@ -13,22 +13,18 @@ describe("request test", function() it("should replace visual selection with AI response", function() local p = test_utils.test_setup(content, 2, 1, "lua") local state = _99.__get_state() - local Request = require("99.request") - local RequestContext = require("99.request-context") + local Prompt = require("99.prompt") - local context = RequestContext.from_current_buffer(state, 100) - context.operation = "test_request" + local context = Prompt.visual(state) context:finalize() - local request = Request.new(context) - local finished_called = false local finished_status = nil - eq("ready", request.state) + eq("ready", context.state) eq(0, state:active_request_count()) - request:start({ + context:start_request({ on_start = function() print("on_start") end, @@ -42,13 +38,13 @@ describe("request test", function() test_utils.next_frame() eq(1, state:active_request_count()) - eq("requesting", request.state) + eq("requesting", context.state) p:resolve("success", " return 'implemented!'") assert.is_true(finished_called) eq(0, state:active_request_count()) - eq("success", request.state) + eq("success", context.state) eq("success", finished_status) end) end) diff --git a/lua/99/test/test_utils.lua b/lua/99/test/test_utils.lua index a1459a4..e5484f1 100644 --- a/lua/99/test/test_utils.lua +++ b/lua/99/test/test_utils.lua @@ -24,7 +24,7 @@ M.created_files = {} --- @class _99.test.ProviderRequest --- @field query string ---- @field request _99.Request +--- @field prompt _99.Prompt --- @field observer _99.Providers.Observer --- @field logger _99.Logger @@ -38,29 +38,29 @@ function TestProvider.new() end --- @param query string ----@param request _99.Request +---@param prompt _99.Prompt ---@param observer _99.Providers.Observer? -function TestProvider:make_request(query, request, observer) - local logger = request.context.logger:set_area("TestProvider") - logger:debug("make_request", "tmp_file", request.context.tmp_file) +function TestProvider:make_request(query, prompt, observer) + local logger = prompt.logger:set_area("TestProvider") + logger:debug("make_request", "tmp_file", prompt.tmp_file) observer = observer or DevNullObserver observer.on_start() self.request = { query = query, - request = request, + prompt = prompt, observer = observer, logger = logger, } end ---- @param status _99.Request.ResponseState +--- @param status _99.Prompt.EndingState --- @param result string function TestProvider:resolve(status, result) assert(self.request, "you cannot call resolve until make_request is called") - if self.request.request:is_cancelled() then + if self.request.prompt:is_cancelled() then self.request.observer.on_complete("cancelled", result) else self.request.observer.on_complete(status, result) diff --git a/lua/99/test/visual_spec.lua b/lua/99/test/visual_spec.lua index 2100539..7f17458 100644 --- a/lua/99/test/visual_spec.lua +++ b/lua/99/test/visual_spec.lua @@ -49,9 +49,7 @@ describe("visual", function() local state = _99.__get_state() local visual_fn = require("99.ops.over-range") - local context = - require("99.request-context").from_current_buffer(state, 100) - context.operation = "test_op" + local context = require("99.prompt").visual(state) visual_fn(context, range, { additional_prompt = "test prompt", @@ -84,9 +82,7 @@ describe("visual", function() local state = _99.__get_state() local visual_fn = require("99.ops.over-range") - local context = - require("99.request-context").from_current_buffer(state, 200) - context.operation = "test_op" + local context = require("99.prompt").visual(state) visual_fn(context, range, { additional_prompt = "test prompt", }) @@ -112,9 +108,7 @@ describe("visual", function() local p, buffer, range = setup(content, 2, 1, 2, 23) local visual_fn = require("99.ops.over-range") local state = _99.__get_state() - local context = - require("99.request-context").from_current_buffer(state, 300) - context.operation = "test_op" + local context = require("99.prompt").visual(state, 300) visual_fn(context, range, { additional_prompt = "test prompt", @@ -122,14 +116,14 @@ describe("visual", function() eq(content, r(buffer)) - assert.is_false(p.request.request:is_cancelled()) + assert.is_false(p.request.prompt:is_cancelled()) assert.is_not_nil(p.request) - assert.is_not_nil(p.request.request) + assert.is_not_nil(p.request.prompt) _99.stop_all_requests() test_utils.next_frame() - assert.is_true(p.request.request:is_cancelled()) + assert.is_true(p.request.prompt:is_cancelled()) p:resolve("success", " return 'should not appear'") test_utils.next_frame() @@ -142,9 +136,7 @@ describe("visual", function() local p, buffer, range = setup(content, 2, 1, 2, 23) local visual_fn = require("99.ops.over-range") local state = _99.__get_state() - local context = - require("99.request-context").from_current_buffer(state, 400) - context.operation = "test_op" + local context = require("99.prompt").visual(state) visual_fn(context, range, { additional_prompt = "test prompt", @@ -163,9 +155,7 @@ describe("visual", function() local p, buffer, range = setup(content, 2, 1, 2, 23) local visual_fn = require("99.ops.over-range") local state = _99.__get_state() - local context = - require("99.request-context").from_current_buffer(state, 500) - context.operation = "test_op" + local context = require("99.prompt").visual(state) visual_fn(context, range, { additional_prompt = "test prompt", @@ -174,7 +164,7 @@ describe("visual", function() eq(content, r(buffer)) -- Manually cancel and resolve as cancelled - p.request.request:cancel() + p.request.prompt:cancel() p:resolve("cancelled", "Request was cancelled") test_utils.next_frame() diff --git a/lua/99/utils.lua b/lua/99/utils.lua index 0882489..dfe31eb 100644 --- a/lua/99/utils.lua +++ b/lua/99/utils.lua @@ -1,11 +1,28 @@ local M = {} ---- TODO: some people change their current working directory as they open new ---- directories. if this is still the case in neovim land, then we will need ---- to make the _99_state have the project directory. + +function M.copy(t) + assert(type(t) == "table", "passed in non table into table") + local out = {} + for k, v in pairs(t) do + out[k] = v + end + for i, v in ipairs(t) do + out[i] = v + end + return out +end + +--- @param dir string --- @return string function M.random_file(dir) - local directory = dir or (vim.uv.cwd() .. "/tmp") - return string.format("%s/99-%d", directory, math.floor(math.random() * 10000)) + return string.format("%s/99-%d", dir, math.floor(math.random() * 10000)) +end + +--- @param dir string +--- @param name string +--- @return string +function M.named_tmp_file(dir, name) + return string.format("%s/99-%s", dir, name) end return M diff --git a/scripts/tests/minimal.vim b/scripts/tests/minimal.vim index b960417..f7675ec 100644 --- a/scripts/tests/minimal.vim +++ b/scripts/tests/minimal.vim @@ -37,15 +37,10 @@ local required_parsers = { 'lua', 'typescript' } local function missing_parsers(parsers) local missing = {} local buf = vim.api.nvim_create_buf(false, true) - print('[minimal.vim] Checking for missing parsers...') for _, lang in ipairs(parsers) do - print('[minimal.vim] Checking parser for: ' .. lang) local ok, err = pcall(vim.treesitter.get_parser, buf, lang) if not ok then - print('[minimal.vim] Parser NOT found for ' .. lang .. ': ' .. tostring(err)) table.insert(missing, lang) - else - print('[minimal.vim] Parser FOUND for ' .. lang) end end vim.api.nvim_buf_delete(buf, { force = true }) |
