diff options
| author | theprimeagain <the.primeagen@gmail.com> | 2026-02-27 16:25:58 -0700 |
|---|---|---|
| committer | theprimeagain <the.primeagen@gmail.com> | 2026-02-27 16:25:58 -0700 |
| commit | ad07a4dd1f00b651874e9e6d3a249731d223134d (patch) | |
| tree | 71ac8b4088cf87c228dbd2b3a5c114c8f996cd6f /lua | |
| parent | e16d69eedd3ed77e8998a908d21bc0d653102a27 (diff) | |
| download | a4-ad07a4dd1f00b651874e9e6d3a249731d223134d.tar.xz a4-ad07a4dd1f00b651874e9e6d3a249731d223134d.zip | |
tracking upgrade, rest of program not updated
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/99/init.lua | 7 | ||||
| -rw-r--r-- | lua/99/prompt.lua | 31 | ||||
| -rw-r--r-- | lua/99/state.lua | 94 | ||||
| -rw-r--r-- | lua/99/state/tracking.lua | 96 | ||||
| -rw-r--r-- | lua/99/test/prompt_spec.lua | 27 | ||||
| -rw-r--r-- | lua/99/utils.lua | 17 |
6 files changed, 181 insertions, 91 deletions
diff --git a/lua/99/init.lua b/lua/99/init.lua index 2310784..895b369 100644 --- a/lua/99/init.lua +++ b/lua/99/init.lua @@ -431,13 +431,6 @@ function _99.stop_all_requests() end end -function _99.clear_all_marks() - for _, mark in ipairs(_99_state.__active_marks or {}) do - mark:delete() - end - _99_state.__active_marks = {} -end - function _99.clear_previous_requests() _99_state:clear_history() end diff --git a/lua/99/prompt.lua b/lua/99/prompt.lua index 23b2309..4f9fe38 100644 --- a/lua/99/prompt.lua +++ b/lua/99/prompt.lua @@ -28,6 +28,10 @@ local filetype_map = { --- @alias _99.Prompt.State "ready" | "requesting" | _99.Prompt.EndingState --- @alias _99.Prompt.Cleanup fun(): nil +--- @class _99.Prompt.Serialized +--- @field data _99.Prompt.Data +--- @field user_prompt string + --- @class _99.Prompt.Data.Search --- @field type "search" --- @field qfix_items _99.Search.Result[] @@ -51,7 +55,7 @@ local filetype_map = { --- @field xid number TODO: we should probably get rid of this. The request pattern is not quite correct --- @field tutorial string[] ---- @class _99.Prompt.Properties +--- @class _99.Prompt --- @field md_file_names string[] --- @field model string --- @field user_prompt string @@ -65,8 +69,6 @@ local filetype_map = { --- @field marks table<string, _99.Mark> --- @field logger _99.Logger --- @field xid number - ---- @class _99.Prompt : _99.Prompt.Properties --- @field clean_ups (fun(): nil)[] --- @field _99 _99.State ---@diagnostic disable-next-line: undefined-doc-name @@ -108,6 +110,29 @@ function Prompt.todo(_99) end --- @param _99 _99.State +--- @param data _99.Prompt.Serialized +--- @return _99.Prompt +function Prompt.deserialize(_99, data) + local prompt = setmetatable({ + _99 = _99, + data = data.data, + operation = data.data.type, + user_prompt = data.user_prompt, + xid = get_id(), + }, Prompt) + assert(prompt:valid(), "prompt is not valid from data") + return prompt +end + +--- @return _99.Prompt.Serialized +function Prompt:serialize() + return { + data = self.data, + user_prompt = self.user_prompt, + } +end + +--- @param _99 _99.State --- @return _99.Prompt function Prompt.vibe(_99) _99:refresh_rules() diff --git a/lua/99/state.lua b/lua/99/state.lua index 050a5c4..b3ff5fe 100644 --- a/lua/99/state.lua +++ b/lua/99/state.lua @@ -1,7 +1,9 @@ local utils = require("99.utils") local Agents = require("99.extensions.agents") local Extensions = require("99.extensions") +local Tracking = require("99.state.tracking") +local _99_STATE_FILE = "99-state" local function default_completion() return { source = nil, custom_rules = {} } end @@ -28,9 +30,7 @@ end --- @field display_errors boolean --- @field provider_override _99.Providers.BaseProvider? --- @field rules _99.Agents.Rules ---- @field __request_history _99.Prompt[] ---- @field __request_by_id table<number, _99.Prompt> ---- @field __active_marks _99.Mark[] +--- @field tracking _99.State.Tracking --- @field __tmp_dir string | nil local State = {} State.__index = State @@ -43,8 +43,6 @@ local function create() ai_stdout_rows = 3, display_errors = false, provider_override = nil, - __request_history = {}, - __request_by_id = {}, tmp_dir = nil, } end @@ -63,21 +61,16 @@ end --- @param opts _99.Options --- @return _99.StateProps | nil local function read_state_from_tmp(opts) - local state_file = utils.named_tmp_file(get_tmp_dir(opts), "99-state") - local fd = vim.uv.fs_open(state_file, "r", 438) - if not fd then - return nil - end + local state_file = utils.named_tmp_file(get_tmp_dir(opts), _99_STATE_FILE) return utils.read_file_json_safe(state_file) --[[@as _99.StateProps]] end --- @param opts _99.Options --- @return _99.State function State.new(opts) - local props = read_state_from_tmp(opts) or create() + local props = create() local _99_state = setmetatable(props, State) --[[@as _99.State]] - _99_state.in_flight_options = opts.in_flight_options or { enable = true } _99_state.provider_override = opts.provider _99_state.completion = opts.completion or default_completion() _99_state.completion.custom_rules = _99_state.completion.custom_rules or {} @@ -86,10 +79,18 @@ function State.new(opts) --- TODO: Prompt overrides would be a great thing, we just have to get there --- for now, i am going to have this as just a hardcoded ... thing _99_state.prompts = require("99.prompt-settings") + _99_state.tracking = Tracking.new(_99_state) return _99_state end +function State:sync() + local tracking = self.tracking:serialize() + local tmp = self:tmp_dir() + local file = utils.named_tmp_file(tmp, _99_STATE_FILE) + utils.write_file_json_safe(tracking, file) +end + --- @return string function State:tmp_dir() return get_tmp_dir(self) @@ -109,73 +110,4 @@ function State:refresh_rules() Extensions.refresh(self) end ---- @param context _99.Prompt -function State:track_prompt_request(context) - assert(context:valid(), "context is not valid") - table.insert(self.__request_history, context) - self.__request_by_id[context.xid] = context -end - ---- @return number -function State:completed_prompts() - local count = 0 - for _, entry in ipairs(self.__request_history) do - if entry.state ~= "requesting" then - count = count + 1 - end - end - return count -end - ---- @return _99.Prompt[] -function State:requests() - return self.__request_history -end - -function State:clear_history() - local keep = {} - for _, entry in ipairs(self.__request_history) do - if entry.state == "requesting" then - table.insert(keep, entry) - else - self.__request_by_id[entry.xid] = nil - end - end - self.__request_history = keep -end - ---- @param mark _99.Mark -function State:add_mark(mark) - table.insert(self.__active_marks, mark) -end - -function State:clear_marks() - for _, active_mark in ipairs(self.__active_marks or {}) do - active_mark:delete() - end - self.__active_marks = {} -end - -function State:active_request_count() - local count = 0 - for _, r in pairs(self.__request_history) do - if r.state == "requesting" then - count = count + 1 - end - end - return count -end - ---- @param type "search" | "visual" | "tutorial" ---- @return _99.Prompt[] -function State:request_by_type(type) - local out = {} --[[ @as _99.Prompt[] ]] - for _, r in ipairs(self.__request_history) do - if r.operation == type then - table.insert(out, r) - end - end - return out -end - return State diff --git a/lua/99/state/tracking.lua b/lua/99/state/tracking.lua new file mode 100644 index 0000000..a0420e8 --- /dev/null +++ b/lua/99/state/tracking.lua @@ -0,0 +1,96 @@ +local Prompt = require("99.prompt") + +--- @class _99.State.Tracking.Serialized +--- @field requests _99.Prompt.Serialized[] + +--- @class _99.State.Tracking +--- @field history _99.Prompt[] +--- @field id_to_request table<number, _99.Prompt> +local Tracking = {} + +--- @param _99 _99.State +--- @param previous_state _99.State.Tracking.Serialized | nil +--- @return _99.State.Tracking +function Tracking.new(_99, previous_state) + local tracking = setmetatable({}, Tracking) --[[ @as _99.State.Tracking]] + + tracking.history = {} + tracking.id_to_request = {} + + if not previous_state then + return tracking + end + + for _, d in ipairs(previous_state.requests or {}) do + local prompt = Prompt.deserialize(_99, d) + table.insert(tracking.history, prompt) + tracking.id_to_request[prompt.xid] = prompt + end + + return tracking +end + +--- @param context _99.Prompt +function Tracking:track_prompt_request(context) + assert(context:valid(), "context is not valid") + table.insert(self.history, context) + self.id_to_request[context.xid] = context +end + +--- @return number +function Tracking:completed_requests() + local count = 0 + for _, entry in ipairs(self.history) do + if entry.state ~= "requesting" then + count = count + 1 + end + end + return count +end + +function Tracking:clear_history() + local keep = {} + for _, entry in ipairs(self.history) do + if entry.state == "requesting" then + table.insert(keep, entry) + else + self.id_to_request[entry.xid] = nil + end + end + self.history = keep +end + +function Tracking:active_request_count() + local count = 0 + for _, r in pairs(self.history) do + if r.state == "requesting" then + count = count + 1 + end + end + return count +end + +--- @param type "search" | "visual" | "tutorial" +--- @return _99.Prompt[] +function Tracking:request_by_type(type) + local out = {} --[[ @as _99.Prompt[] ]] + for _, r in ipairs(self.history) do + if r.operation == type then + table.insert(out, r) + end + end + return out +end + +--- @return _99.State.Tracking.Serialized +function Tracking:serialize() + local requests = {} + for _, r in ipairs(self.history) do + table.insert(requests, r:serialize()) + end + return { + requests = requests, + } +end + +return Tracking diff --git a/lua/99/test/prompt_spec.lua b/lua/99/test/prompt_spec.lua new file mode 100644 index 0000000..09de83e --- /dev/null +++ b/lua/99/test/prompt_spec.lua @@ -0,0 +1,27 @@ +-- luacheck: globals describe it assert +local _99 = require("99") +local test_utils = require("99.test.test_utils") +local Prompt = require("99.prompt") +local eq = assert.are.same + +describe("prompt", function() + it("should deserialize a serialized prompt", function() + local provider = test_utils.TestProvider.new() + _99.setup(test_utils.get_test_setup_options({}, provider)) + + local state = _99.__get_state() + local prompt = Prompt.deserialize(state, { + user_prompt = "find important changes", + data = { + type = "search", + qfix_items = {}, + response = "", + }, + }) + + eq("search", prompt.operation) + eq("search", prompt.data.type) + eq("find important changes", prompt.user_prompt) + eq("search: find important changes", prompt:summary()) + end) +end) diff --git a/lua/99/utils.lua b/lua/99/utils.lua index 2a51793..88abcea 100644 --- a/lua/99/utils.lua +++ b/lua/99/utils.lua @@ -57,4 +57,21 @@ function M.read_file_json_safe(path) end end +--- @param obj table +---@param path string +function M.write_file_json_safe(obj, path) + local ok, fh = pcall(io.open, path, "w") + if not ok or not fh then + return + end + + local obj_str = "" + ok, obj_str = pcall(vim.json.encode, obj) + if not ok then + return + end + + pcall(fh.write, fh, obj_str) +end + return M |
