summaryrefslogtreecommitdiff
path: root/lua
diff options
context:
space:
mode:
authortheprimeagain <the.primeagen@gmail.com>2026-02-27 16:25:58 -0700
committertheprimeagain <the.primeagen@gmail.com>2026-02-27 16:25:58 -0700
commitad07a4dd1f00b651874e9e6d3a249731d223134d (patch)
tree71ac8b4088cf87c228dbd2b3a5c114c8f996cd6f /lua
parente16d69eedd3ed77e8998a908d21bc0d653102a27 (diff)
downloada4-ad07a4dd1f00b651874e9e6d3a249731d223134d.tar.xz
a4-ad07a4dd1f00b651874e9e6d3a249731d223134d.zip
tracking upgrade, rest of program not updated
Diffstat (limited to 'lua')
-rw-r--r--lua/99/init.lua7
-rw-r--r--lua/99/prompt.lua31
-rw-r--r--lua/99/state.lua94
-rw-r--r--lua/99/state/tracking.lua96
-rw-r--r--lua/99/test/prompt_spec.lua27
-rw-r--r--lua/99/utils.lua17
6 files changed, 181 insertions, 91 deletions
diff --git a/lua/99/init.lua b/lua/99/init.lua
index 2310784..895b369 100644
--- a/lua/99/init.lua
+++ b/lua/99/init.lua
@@ -431,13 +431,6 @@ function _99.stop_all_requests()
end
end
-function _99.clear_all_marks()
- for _, mark in ipairs(_99_state.__active_marks or {}) do
- mark:delete()
- end
- _99_state.__active_marks = {}
-end
-
function _99.clear_previous_requests()
_99_state:clear_history()
end
diff --git a/lua/99/prompt.lua b/lua/99/prompt.lua
index 23b2309..4f9fe38 100644
--- a/lua/99/prompt.lua
+++ b/lua/99/prompt.lua
@@ -28,6 +28,10 @@ local filetype_map = {
--- @alias _99.Prompt.State "ready" | "requesting" | _99.Prompt.EndingState
--- @alias _99.Prompt.Cleanup fun(): nil
+--- @class _99.Prompt.Serialized
+--- @field data _99.Prompt.Data
+--- @field user_prompt string
+
--- @class _99.Prompt.Data.Search
--- @field type "search"
--- @field qfix_items _99.Search.Result[]
@@ -51,7 +55,7 @@ local filetype_map = {
--- @field xid number TODO: we should probably get rid of this. The request pattern is not quite correct
--- @field tutorial string[]
---- @class _99.Prompt.Properties
+--- @class _99.Prompt
--- @field md_file_names string[]
--- @field model string
--- @field user_prompt string
@@ -65,8 +69,6 @@ local filetype_map = {
--- @field marks table<string, _99.Mark>
--- @field logger _99.Logger
--- @field xid number
-
---- @class _99.Prompt : _99.Prompt.Properties
--- @field clean_ups (fun(): nil)[]
--- @field _99 _99.State
---@diagnostic disable-next-line: undefined-doc-name
@@ -108,6 +110,29 @@ function Prompt.todo(_99)
end
--- @param _99 _99.State
+--- @param data _99.Prompt.Serialized
+--- @return _99.Prompt
+function Prompt.deserialize(_99, data)
+ local prompt = setmetatable({
+ _99 = _99,
+ data = data.data,
+ operation = data.data.type,
+ user_prompt = data.user_prompt,
+ xid = get_id(),
+ }, Prompt)
+ assert(prompt:valid(), "prompt is not valid from data")
+ return prompt
+end
+
+--- @return _99.Prompt.Serialized
+function Prompt:serialize()
+ return {
+ data = self.data,
+ user_prompt = self.user_prompt,
+ }
+end
+
+--- @param _99 _99.State
--- @return _99.Prompt
function Prompt.vibe(_99)
_99:refresh_rules()
diff --git a/lua/99/state.lua b/lua/99/state.lua
index 050a5c4..b3ff5fe 100644
--- a/lua/99/state.lua
+++ b/lua/99/state.lua
@@ -1,7 +1,9 @@
local utils = require("99.utils")
local Agents = require("99.extensions.agents")
local Extensions = require("99.extensions")
+local Tracking = require("99.state.tracking")
+local _99_STATE_FILE = "99-state"
local function default_completion()
return { source = nil, custom_rules = {} }
end
@@ -28,9 +30,7 @@ end
--- @field display_errors boolean
--- @field provider_override _99.Providers.BaseProvider?
--- @field rules _99.Agents.Rules
---- @field __request_history _99.Prompt[]
---- @field __request_by_id table<number, _99.Prompt>
---- @field __active_marks _99.Mark[]
+--- @field tracking _99.State.Tracking
--- @field __tmp_dir string | nil
local State = {}
State.__index = State
@@ -43,8 +43,6 @@ local function create()
ai_stdout_rows = 3,
display_errors = false,
provider_override = nil,
- __request_history = {},
- __request_by_id = {},
tmp_dir = nil,
}
end
@@ -63,21 +61,16 @@ end
--- @param opts _99.Options
--- @return _99.StateProps | nil
local function read_state_from_tmp(opts)
- local state_file = utils.named_tmp_file(get_tmp_dir(opts), "99-state")
- local fd = vim.uv.fs_open(state_file, "r", 438)
- if not fd then
- return nil
- end
+ local state_file = utils.named_tmp_file(get_tmp_dir(opts), _99_STATE_FILE)
return utils.read_file_json_safe(state_file) --[[@as _99.StateProps]]
end
--- @param opts _99.Options
--- @return _99.State
function State.new(opts)
- local props = read_state_from_tmp(opts) or create()
+ local props = create()
local _99_state = setmetatable(props, State) --[[@as _99.State]]
- _99_state.in_flight_options = opts.in_flight_options or { enable = true }
_99_state.provider_override = opts.provider
_99_state.completion = opts.completion or default_completion()
_99_state.completion.custom_rules = _99_state.completion.custom_rules or {}
@@ -86,10 +79,18 @@ function State.new(opts)
--- TODO: Prompt overrides would be a great thing, we just have to get there
--- for now, i am going to have this as just a hardcoded ... thing
_99_state.prompts = require("99.prompt-settings")
+ _99_state.tracking = Tracking.new(_99_state)
return _99_state
end
+function State:sync()
+ local tracking = self.tracking:serialize()
+ local tmp = self:tmp_dir()
+ local file = utils.named_tmp_file(tmp, _99_STATE_FILE)
+ utils.write_file_json_safe(tracking, file)
+end
+
--- @return string
function State:tmp_dir()
return get_tmp_dir(self)
@@ -109,73 +110,4 @@ function State:refresh_rules()
Extensions.refresh(self)
end
---- @param context _99.Prompt
-function State:track_prompt_request(context)
- assert(context:valid(), "context is not valid")
- table.insert(self.__request_history, context)
- self.__request_by_id[context.xid] = context
-end
-
---- @return number
-function State:completed_prompts()
- local count = 0
- for _, entry in ipairs(self.__request_history) do
- if entry.state ~= "requesting" then
- count = count + 1
- end
- end
- return count
-end
-
---- @return _99.Prompt[]
-function State:requests()
- return self.__request_history
-end
-
-function State:clear_history()
- local keep = {}
- for _, entry in ipairs(self.__request_history) do
- if entry.state == "requesting" then
- table.insert(keep, entry)
- else
- self.__request_by_id[entry.xid] = nil
- end
- end
- self.__request_history = keep
-end
-
---- @param mark _99.Mark
-function State:add_mark(mark)
- table.insert(self.__active_marks, mark)
-end
-
-function State:clear_marks()
- for _, active_mark in ipairs(self.__active_marks or {}) do
- active_mark:delete()
- end
- self.__active_marks = {}
-end
-
-function State:active_request_count()
- local count = 0
- for _, r in pairs(self.__request_history) do
- if r.state == "requesting" then
- count = count + 1
- end
- end
- return count
-end
-
---- @param type "search" | "visual" | "tutorial"
---- @return _99.Prompt[]
-function State:request_by_type(type)
- local out = {} --[[ @as _99.Prompt[] ]]
- for _, r in ipairs(self.__request_history) do
- if r.operation == type then
- table.insert(out, r)
- end
- end
- return out
-end
-
return State
diff --git a/lua/99/state/tracking.lua b/lua/99/state/tracking.lua
new file mode 100644
index 0000000..a0420e8
--- /dev/null
+++ b/lua/99/state/tracking.lua
@@ -0,0 +1,96 @@
+local Prompt = require("99.prompt")
+
+--- @class _99.State.Tracking.Serialized
+--- @field requests _99.Prompt.Serialized[]
+
+--- @class _99.State.Tracking
+--- @field history _99.Prompt[]
+--- @field id_to_request table<number, _99.Prompt>
+local Tracking = {}
+
+--- @param _99 _99.State
+--- @param previous_state _99.State.Tracking.Serialized | nil
+--- @return _99.State.Tracking
+function Tracking.new(_99, previous_state)
+ local tracking = setmetatable({}, Tracking) --[[ @as _99.State.Tracking]]
+
+ tracking.history = {}
+ tracking.id_to_request = {}
+
+ if not previous_state then
+ return tracking
+ end
+
+ for _, d in ipairs(previous_state.requests or {}) do
+ local prompt = Prompt.deserialize(_99, d)
+ table.insert(tracking.history, prompt)
+ tracking.id_to_request[prompt.xid] = prompt
+ end
+
+ return tracking
+end
+
+--- @param context _99.Prompt
+function Tracking:track_prompt_request(context)
+ assert(context:valid(), "context is not valid")
+ table.insert(self.history, context)
+ self.id_to_request[context.xid] = context
+end
+
+--- @return number
+function Tracking:completed_requests()
+ local count = 0
+ for _, entry in ipairs(self.history) do
+ if entry.state ~= "requesting" then
+ count = count + 1
+ end
+ end
+ return count
+end
+
+function Tracking:clear_history()
+ local keep = {}
+ for _, entry in ipairs(self.history) do
+ if entry.state == "requesting" then
+ table.insert(keep, entry)
+ else
+ self.id_to_request[entry.xid] = nil
+ end
+ end
+ self.history = keep
+end
+
+function Tracking:active_request_count()
+ local count = 0
+ for _, r in pairs(self.history) do
+ if r.state == "requesting" then
+ count = count + 1
+ end
+ end
+ return count
+end
+
+--- @param type "search" | "visual" | "tutorial"
+--- @return _99.Prompt[]
+function Tracking:request_by_type(type)
+ local out = {} --[[ @as _99.Prompt[] ]]
+ for _, r in ipairs(self.history) do
+ if r.operation == type then
+ table.insert(out, r)
+ end
+ end
+ return out
+end
+
+--- @return _99.State.Tracking.Serialized
+function Tracking:serialize()
+ local requests = {}
+ for _, r in ipairs(self.history) do
+ table.insert(requests, r:serialize())
+ end
+ return {
+ requests = requests,
+ }
+end
+
+return Tracking
diff --git a/lua/99/test/prompt_spec.lua b/lua/99/test/prompt_spec.lua
new file mode 100644
index 0000000..09de83e
--- /dev/null
+++ b/lua/99/test/prompt_spec.lua
@@ -0,0 +1,27 @@
+-- luacheck: globals describe it assert
+local _99 = require("99")
+local test_utils = require("99.test.test_utils")
+local Prompt = require("99.prompt")
+local eq = assert.are.same
+
+describe("prompt", function()
+ it("should deserialize a serialized prompt", function()
+ local provider = test_utils.TestProvider.new()
+ _99.setup(test_utils.get_test_setup_options({}, provider))
+
+ local state = _99.__get_state()
+ local prompt = Prompt.deserialize(state, {
+ user_prompt = "find important changes",
+ data = {
+ type = "search",
+ qfix_items = {},
+ response = "",
+ },
+ })
+
+ eq("search", prompt.operation)
+ eq("search", prompt.data.type)
+ eq("find important changes", prompt.user_prompt)
+ eq("search: find important changes", prompt:summary())
+ end)
+end)
diff --git a/lua/99/utils.lua b/lua/99/utils.lua
index 2a51793..88abcea 100644
--- a/lua/99/utils.lua
+++ b/lua/99/utils.lua
@@ -57,4 +57,21 @@ function M.read_file_json_safe(path)
end
end
+--- @param obj table
+---@param path string
+function M.write_file_json_safe(obj, path)
+ local ok, fh = pcall(io.open, path, "w")
+ if not ok or not fh then
+ return
+ end
+
+ local obj_str = ""
+ ok, obj_str = pcall(vim.json.encode, obj)
+ if not ok then
+ return
+ end
+
+ pcall(fh.write, fh, obj_str)
+end
+
return M