diff options
Diffstat (limited to 'lua/99')
| -rw-r--r-- | lua/99/init.lua | 23 | ||||
| -rw-r--r-- | lua/99/prompt.lua | 35 | ||||
| -rw-r--r-- | lua/99/state/tracking.lua | 45 | ||||
| -rw-r--r-- | lua/99/test/open_spec.lua | 5 | ||||
| -rw-r--r-- | lua/99/test/request_spec.lua | 6 | ||||
| -rw-r--r-- | lua/99/test/test_utils.lua | 1 | ||||
| -rw-r--r-- | lua/99/test/tracking_spec.lua | 116 | ||||
| -rw-r--r-- | lua/99/test/visual_spec.lua | 5 | ||||
| -rw-r--r-- | lua/99/utils.lua | 2 | ||||
| -rw-r--r-- | lua/99/window/status-window.lua | 6 |
10 files changed, 132 insertions, 112 deletions
diff --git a/lua/99/init.lua b/lua/99/init.lua index 743bc75..3935dd4 100644 --- a/lua/99/init.lua +++ b/lua/99/init.lua @@ -283,16 +283,12 @@ function _99.open_tutorial(context) end function _99.open() - local requests = _99_state.tracking.history + local requests = _99_state.tracking:successful() local str_requests = {} - for _, r in ipairs(requests) do - if r.state == "success" then - table.insert(str_requests, r:summary()) - end - end - for i = 1, #requests do - str_requests[i] = string.format("%d: %s", i, requests[i]:summary()) + for i, r in ipairs(requests) do + table.insert(str_requests, string.format("%d: %s", i, r:summary())) end + Window.capture_select_input("99", { content = str_requests, cb = function(success, result) @@ -301,10 +297,13 @@ function _99.open() end local idx = tonumber(vim.fn.matchstr(result, "^\\d\\+")) + if idx == nil then + return + end local r = requests[idx] if not r then print( - "somehow we have had a successful callback, but no request context... i honestly have no idea how we got here" + "request not found... potentially report bug: " .. vim.inspect(idx) ) return end @@ -426,11 +425,7 @@ function _99.open_qfix_for_request(request) end function _99.stop_all_requests() - for _, c in ipairs(_99_state.tracking.history) do - if c.state == "requesting" then - c:stop() - end - end + _99_state.tracking:stop_all_requests() end function _99.clear_previous_requests() diff --git a/lua/99/prompt.lua b/lua/99/prompt.lua index 7691244..96d4579 100644 --- a/lua/99/prompt.lua +++ b/lua/99/prompt.lua @@ -113,24 +113,24 @@ end --- @param data _99.Prompt.Serialized --- @return _99.Prompt function Prompt.deserialize(_99, data) - local prompt = setmetatable({ - _99 = _99, - data = data.data, - operation = data.data.type, - user_prompt = data.user_prompt, - started_at = Time.now(), - xid = get_id(), - }, Prompt) - assert(prompt:valid(), "prompt is not valid from data") - return prompt + local prompt = setmetatable({ + _99 = _99, + data = data.data, + operation = data.data.type, + user_prompt = data.user_prompt, + started_at = Time.now(), + xid = get_id(), + }, Prompt) + assert(prompt:valid(), "prompt is not valid from data") + return prompt end --- @return _99.Prompt.Serialized function Prompt:serialize() - return { - data = self.data, - user_prompt = self.user_prompt, - } + return { + data = self.data, + user_prompt = self.user_prompt, + } end --- @param _99 _99.State @@ -304,6 +304,10 @@ function Prompt:is_cancelled() return self.state == "cancelled" end +function Prompt:is_completed() + return self.state == "success" or self.state == "failed" +end + ---@diagnostic disable-next-line: undefined-doc-name --- @param proc vim.SystemObj? function Prompt:_set_process(proc) @@ -311,11 +315,10 @@ function Prompt:_set_process(proc) end function Prompt:cancel() - if self:is_cancelled() then + if self:is_cancelled() or self:is_completed() then return end - self.logger:debug("cancel") self.state = "cancelled" local proc = self._proc ---@diagnostic disable-next-line: undefined-field diff --git a/lua/99/state/tracking.lua b/lua/99/state/tracking.lua index 83b23a0..bfeb430 100644 --- a/lua/99/state/tracking.lua +++ b/lua/99/state/tracking.lua @@ -75,7 +75,24 @@ function Tracking:clear_history() self.history = keep end +function Tracking:stop_all_requests() + for _, r in pairs(self:active()) do + r:stop() + end +end + +--- @return _99.Prompt[] function Tracking:active() + local out = {} + for _, r in pairs(self.history) do + if r.state == "requesting" then + table.insert(out, r) + end + end + return out +end + +function Tracking:active_count() local count = 0 for _, r in pairs(self.history) do if r.state == "requesting" then @@ -97,6 +114,17 @@ function Tracking:request_by_type(type) return out end +--- @return _99.Prompt[] +function Tracking:successful() + local out = {} + for _, r in ipairs(self.history) do + if r.state == "success" then + table.insert(out, r) + end + end + return out +end + --- @return _99.State.Tracking.Serialized function Tracking:serialize() local sc = Tracking.__config.serialize_count @@ -111,12 +139,14 @@ function Tracking:serialize() end end for op, _ in pairs(sc) do + all_requests[op] = all_requests[op] or {} local r = all_requests[op] table.sort(r, function(a, b) return a.started_at > b.started_at end) end + --- @type _99.Prompt[] local requests = {} for op, max in pairs(sc) do local count = 0 @@ -124,13 +154,21 @@ function Tracking:serialize() if count >= max then break end - table.insert(requests, request:serialize()) + table.insert(requests, request) count = count + 1 end end + table.sort(requests, function(a, b) + return a.started_at > b.started_at + end) + local serialized = {} + for _, r in ipairs(requests) do + table.insert(serialized, r:serialize()) + end + return { - requests = requests, + requests = serialized, } end @@ -148,6 +186,9 @@ function Tracking.setup(opts) local config = Tracking.__config local opts_sa = opts.serialize_counts if opts_sa then + --- note to Agents: in lua 0 or 1 will result in 0. or only operates on nil or false + --- note to prime: i hate llms sometimes... like seriously.. know the language ya dork + --- so opts_sa.vibe or sa.vibe if opts_sa.vibe = 0, will yield 0. this is how lua works local sa = config.serialize_count sa.vibe = opts_sa.vibe or sa.vibe sa.search = opts_sa.search or sa.search diff --git a/lua/99/test/open_spec.lua b/lua/99/test/open_spec.lua index a65d8e8..2dff972 100644 --- a/lua/99/test/open_spec.lua +++ b/lua/99/test/open_spec.lua @@ -75,6 +75,7 @@ describe("open", function() local function select_content(idx) Window.capture_select_input = function(_, opts) + print("capture_select_input", vim.inspect(opts.content), idx) opts.cb(true, opts.content[idx]) end end @@ -84,6 +85,10 @@ describe("open", function() local v = vibe() local t = tutorial() + local history = _99:__get_state().tracking.history + for _, r in ipairs(history) do + print("history", r.state, r:summary()) + end select_content(1) _99.open() eq(QFixHelpers.create_qfix_entries(s), qfix_items()) diff --git a/lua/99/test/request_spec.lua b/lua/99/test/request_spec.lua index 975a570..4bb4142 100644 --- a/lua/99/test/request_spec.lua +++ b/lua/99/test/request_spec.lua @@ -23,7 +23,7 @@ describe("request test", function() eq("ready", context.state) - eq(0, state:active_request_count()) + eq(0, state.tracking:active_count()) context:start_request({ on_start = function() print("on_start") @@ -36,14 +36,14 @@ describe("request test", function() on_stderr = function() end, }) test_utils.next_frame() - eq(1, state:active_request_count()) + eq(1, state.tracking:active_count()) eq("requesting", context.state) p:resolve("success", " return 'implemented!'") assert.is_true(finished_called) - eq(0, state:active_request_count()) + eq(0, state.tracking:active_count()) eq("success", context.state) eq("success", finished_status) end) diff --git a/lua/99/test/test_utils.lua b/lua/99/test/test_utils.lua index 28fe1e4..554d54b 100644 --- a/lua/99/test/test_utils.lua +++ b/lua/99/test/test_utils.lua @@ -113,6 +113,7 @@ end --- @return _99.Options function M.get_test_setup_options(opts, provider) opts = opts or {} + opts.tmp_dir = opts.tmp_dir or vim.fn.tempname() opts.provider = provider opts.logger = { error_cache_level = Levels.ERROR, diff --git a/lua/99/test/tracking_spec.lua b/lua/99/test/tracking_spec.lua index 535f203..8d64660 100644 --- a/lua/99/test/tracking_spec.lua +++ b/lua/99/test/tracking_spec.lua @@ -1,86 +1,60 @@ -- luacheck: globals describe it assert -local Prompt = require("99.prompt") +local _99 = require("99") local Tracking = require("99.state.tracking") +local test_utils = require("99.test.test_utils") local eq = assert.are.same -local function data_for(operation) - if operation == "tutorial" then - return { - type = "tutorial", - xid = 1, - buffer = 0, - window = 0, - tutorial = {}, - } - end - - if operation == "visual" then - return { - type = "visual", - buffer = 0, - file_type = "lua", - range = { - start_row = 1, - start_col = 1, - end_row = 1, - end_col = 1, - }, - } - end - - if operation == "vibe" then - return { - type = "vibe", - response = "", - qfix_items = {}, - } - end - - return { - type = "search", - response = "", - qfix_items = {}, - } -end - -local function track_request(state, operation, started_at, status) - local prompt = Prompt.deserialize(state, { - user_prompt = string.format("%s-%d", operation, started_at), - data = data_for(operation), - }) - prompt.started_at = started_at - prompt.state = status - state.tracking:track(prompt) +local function run(provider, operation, status, prompt) + _99[operation]({ additional_prompt = prompt }) + provider:resolve(status, "result") end describe("tracking", function() - it("serialize respects Tracking.serialized_counts", function() - local state = {} - state.tracking = Tracking.new(state, nil) + it("serializes requests based on configured counts", function() + local previous_counts = vim.deepcopy(Tracking.__config.serialize_count) + Tracking.setup({ + serialize_counts = { + vibe = 1, + search = 1, + tutorial = 3, + visual = 0, + }, + }) - local expected_total = 0 - local started_at = 0 - for operation, count in pairs(Tracking.serialize_counts) do - expected_total = expected_total + count - for _ = 1, count + 2 do - started_at = started_at + 1 - track_request(state, operation, started_at, "success") - end + local provider = test_utils.TestProvider.new() + _99.setup(test_utils.get_test_setup_options({ + in_flight_options = { enable = false }, + }, provider)) + test_utils.create_file({ "local value = 1" }, "lua", 1, 0) - started_at = started_at + 1 - track_request(state, operation, started_at, "failed") - end + run(provider, "search", "success", "search one") + run(provider, "search", "success", "search two") + run(provider, "vibe", "success", "vibe one") + run(provider, "vibe", "success", "vibe two") + run(provider, "tutorial", "success", "tutorial one") + run(provider, "tutorial", "success", "tutorial two") + run(provider, "tutorial", "success", "tutorial three") + run(provider, "tutorial", "success", "tutorial four") + run(provider, "search", "failed", "search failed") + + local serialized = _99.__get_state().tracking:serialize() + local actual_counts = { + search = 0, + vibe = 0, + tutorial = 0, + visual = 0, + } - local serialized = state.tracking:serialize() - local actual_counts = {} for _, request in ipairs(serialized.requests) do - local operation = request.data.type - actual_counts[operation] = (actual_counts[operation] or 0) + 1 + actual_counts[request.data.type] = actual_counts[request.data.type] + 1 end - eq(expected_total, #serialized.requests) - for operation, count in pairs(Tracking.serialize_counts) do - eq(count, actual_counts[operation] or 0) - end + eq(1, actual_counts.search) + eq(1, actual_counts.vibe) + eq(3, actual_counts.tutorial) + eq(0, actual_counts.visual) + eq(5, #serialized.requests) + + Tracking.__config.serialize_count = previous_counts end) end) diff --git a/lua/99/test/visual_spec.lua b/lua/99/test/visual_spec.lua index 498c47f..5dab250 100644 --- a/lua/99/test/visual_spec.lua +++ b/lua/99/test/visual_spec.lua @@ -61,7 +61,7 @@ describe("visual", function() visual_call_with_range(context, range) - eq(1, state:active_request_count()) + eq(1, state.tracking:active_count()) eq(content, r(buffer)) p:resolve("success", " return 'implemented!'") @@ -74,6 +74,7 @@ describe("visual", function() } eq(expected_state, r(buffer)) -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision + -- TODO: validate if this is true.. end) it("should handle multi-line replacement", function() @@ -90,7 +91,7 @@ describe("visual", function() visual_call_with_range(context, range) - eq(1, state:active_request_count()) + eq(1, state.tracking:active_count()) eq(multi_line_content, r(buffer)) p:resolve("success", " local x = 1\n local y = 2\n return x + y") diff --git a/lua/99/utils.lua b/lua/99/utils.lua index 88abcea..1dcd09e 100644 --- a/lua/99/utils.lua +++ b/lua/99/utils.lua @@ -65,7 +65,7 @@ function M.write_file_json_safe(obj, path) return end - local obj_str = "" + local obj_str ok, obj_str = pcall(vim.json.encode, obj) if not ok then return diff --git a/lua/99/window/status-window.lua b/lua/99/window/status-window.lua index bea4857..c815e4a 100644 --- a/lua/99/window/status-window.lua +++ b/lua/99/window/status-window.lua @@ -80,7 +80,7 @@ function StatusWindow:_run_loop() local active_window = Window.has_active_status_window() local active_other_window = Window.has_active_windows() - local active_requests = self._99.tracking:active_request_count() + local active_requests = self._99.tracking:active_count() if active_window == false and active_other_window or active_window and active_requests > 0 @@ -98,7 +98,7 @@ function StatusWindow:_run_loop() end local throb = Throbber.new(function(throb) - local count = self._99.tracking:active_request_count() + local count = self._99.tracking:active_count() local win_valid = Window.valid(win) if count == 0 or not win_valid then @@ -110,7 +110,7 @@ function StatusWindow:_run_loop() throb .. " requests(" .. tostring(count) .. ") " .. throb, } - for _, c in ipairs(self._99.tracking.history) do + for _, c in ipairs(self._99.tracking:active()) do if c.state == "requesting" then table.insert(lines, c.operation) end |
