summaryrefslogtreecommitdiff
path: root/lua/99
diff options
context:
space:
mode:
Diffstat (limited to 'lua/99')
-rw-r--r--lua/99/init.lua23
-rw-r--r--lua/99/prompt.lua35
-rw-r--r--lua/99/state/tracking.lua45
-rw-r--r--lua/99/test/open_spec.lua5
-rw-r--r--lua/99/test/request_spec.lua6
-rw-r--r--lua/99/test/test_utils.lua1
-rw-r--r--lua/99/test/tracking_spec.lua116
-rw-r--r--lua/99/test/visual_spec.lua5
-rw-r--r--lua/99/utils.lua2
-rw-r--r--lua/99/window/status-window.lua6
10 files changed, 132 insertions, 112 deletions
diff --git a/lua/99/init.lua b/lua/99/init.lua
index 743bc75..3935dd4 100644
--- a/lua/99/init.lua
+++ b/lua/99/init.lua
@@ -283,16 +283,12 @@ function _99.open_tutorial(context)
end
function _99.open()
- local requests = _99_state.tracking.history
+ local requests = _99_state.tracking:successful()
local str_requests = {}
- for _, r in ipairs(requests) do
- if r.state == "success" then
- table.insert(str_requests, r:summary())
- end
- end
- for i = 1, #requests do
- str_requests[i] = string.format("%d: %s", i, requests[i]:summary())
+ for i, r in ipairs(requests) do
+ table.insert(str_requests, string.format("%d: %s", i, r:summary()))
end
+
Window.capture_select_input("99", {
content = str_requests,
cb = function(success, result)
@@ -301,10 +297,13 @@ function _99.open()
end
local idx = tonumber(vim.fn.matchstr(result, "^\\d\\+"))
+ if idx == nil then
+ return
+ end
local r = requests[idx]
if not r then
print(
- "somehow we have had a successful callback, but no request context... i honestly have no idea how we got here"
+ "request not found... potentially report bug: " .. vim.inspect(idx)
)
return
end
@@ -426,11 +425,7 @@ function _99.open_qfix_for_request(request)
end
function _99.stop_all_requests()
- for _, c in ipairs(_99_state.tracking.history) do
- if c.state == "requesting" then
- c:stop()
- end
- end
+ _99_state.tracking:stop_all_requests()
end
function _99.clear_previous_requests()
diff --git a/lua/99/prompt.lua b/lua/99/prompt.lua
index 7691244..96d4579 100644
--- a/lua/99/prompt.lua
+++ b/lua/99/prompt.lua
@@ -113,24 +113,24 @@ end
--- @param data _99.Prompt.Serialized
--- @return _99.Prompt
function Prompt.deserialize(_99, data)
- local prompt = setmetatable({
- _99 = _99,
- data = data.data,
- operation = data.data.type,
- user_prompt = data.user_prompt,
- started_at = Time.now(),
- xid = get_id(),
- }, Prompt)
- assert(prompt:valid(), "prompt is not valid from data")
- return prompt
+ local prompt = setmetatable({
+ _99 = _99,
+ data = data.data,
+ operation = data.data.type,
+ user_prompt = data.user_prompt,
+ started_at = Time.now(),
+ xid = get_id(),
+ }, Prompt)
+ assert(prompt:valid(), "prompt is not valid from data")
+ return prompt
end
--- @return _99.Prompt.Serialized
function Prompt:serialize()
- return {
- data = self.data,
- user_prompt = self.user_prompt,
- }
+ return {
+ data = self.data,
+ user_prompt = self.user_prompt,
+ }
end
--- @param _99 _99.State
@@ -304,6 +304,10 @@ function Prompt:is_cancelled()
return self.state == "cancelled"
end
+function Prompt:is_completed()
+ return self.state == "success" or self.state == "failed"
+end
+
---@diagnostic disable-next-line: undefined-doc-name
--- @param proc vim.SystemObj?
function Prompt:_set_process(proc)
@@ -311,11 +315,10 @@ function Prompt:_set_process(proc)
end
function Prompt:cancel()
- if self:is_cancelled() then
+ if self:is_cancelled() or self:is_completed() then
return
end
- self.logger:debug("cancel")
self.state = "cancelled"
local proc = self._proc
---@diagnostic disable-next-line: undefined-field
diff --git a/lua/99/state/tracking.lua b/lua/99/state/tracking.lua
index 83b23a0..bfeb430 100644
--- a/lua/99/state/tracking.lua
+++ b/lua/99/state/tracking.lua
@@ -75,7 +75,24 @@ function Tracking:clear_history()
self.history = keep
end
+function Tracking:stop_all_requests()
+ for _, r in pairs(self:active()) do
+ r:stop()
+ end
+end
+
+--- @return _99.Prompt[]
function Tracking:active()
+ local out = {}
+ for _, r in pairs(self.history) do
+ if r.state == "requesting" then
+ table.insert(out, r)
+ end
+ end
+ return out
+end
+
+function Tracking:active_count()
local count = 0
for _, r in pairs(self.history) do
if r.state == "requesting" then
@@ -97,6 +114,17 @@ function Tracking:request_by_type(type)
return out
end
+--- @return _99.Prompt[]
+function Tracking:successful()
+ local out = {}
+ for _, r in ipairs(self.history) do
+ if r.state == "success" then
+ table.insert(out, r)
+ end
+ end
+ return out
+end
+
--- @return _99.State.Tracking.Serialized
function Tracking:serialize()
local sc = Tracking.__config.serialize_count
@@ -111,12 +139,14 @@ function Tracking:serialize()
end
end
for op, _ in pairs(sc) do
+ all_requests[op] = all_requests[op] or {}
local r = all_requests[op]
table.sort(r, function(a, b)
return a.started_at > b.started_at
end)
end
+ --- @type _99.Prompt[]
local requests = {}
for op, max in pairs(sc) do
local count = 0
@@ -124,13 +154,21 @@ function Tracking:serialize()
if count >= max then
break
end
- table.insert(requests, request:serialize())
+ table.insert(requests, request)
count = count + 1
end
end
+ table.sort(requests, function(a, b)
+ return a.started_at > b.started_at
+ end)
+ local serialized = {}
+ for _, r in ipairs(requests) do
+ table.insert(serialized, r:serialize())
+ end
+
return {
- requests = requests,
+ requests = serialized,
}
end
@@ -148,6 +186,9 @@ function Tracking.setup(opts)
local config = Tracking.__config
local opts_sa = opts.serialize_counts
if opts_sa then
+ --- note to Agents: in lua 0 or 1 will result in 0. or only operates on nil or false
+ --- note to prime: i hate llms sometimes... like seriously.. know the language ya dork
+ --- so opts_sa.vibe or sa.vibe if opts_sa.vibe = 0, will yield 0. this is how lua works
local sa = config.serialize_count
sa.vibe = opts_sa.vibe or sa.vibe
sa.search = opts_sa.search or sa.search
diff --git a/lua/99/test/open_spec.lua b/lua/99/test/open_spec.lua
index a65d8e8..2dff972 100644
--- a/lua/99/test/open_spec.lua
+++ b/lua/99/test/open_spec.lua
@@ -75,6 +75,7 @@ describe("open", function()
local function select_content(idx)
Window.capture_select_input = function(_, opts)
+ print("capture_select_input", vim.inspect(opts.content), idx)
opts.cb(true, opts.content[idx])
end
end
@@ -84,6 +85,10 @@ describe("open", function()
local v = vibe()
local t = tutorial()
+ local history = _99:__get_state().tracking.history
+ for _, r in ipairs(history) do
+ print("history", r.state, r:summary())
+ end
select_content(1)
_99.open()
eq(QFixHelpers.create_qfix_entries(s), qfix_items())
diff --git a/lua/99/test/request_spec.lua b/lua/99/test/request_spec.lua
index 975a570..4bb4142 100644
--- a/lua/99/test/request_spec.lua
+++ b/lua/99/test/request_spec.lua
@@ -23,7 +23,7 @@ describe("request test", function()
eq("ready", context.state)
- eq(0, state:active_request_count())
+ eq(0, state.tracking:active_count())
context:start_request({
on_start = function()
print("on_start")
@@ -36,14 +36,14 @@ describe("request test", function()
on_stderr = function() end,
})
test_utils.next_frame()
- eq(1, state:active_request_count())
+ eq(1, state.tracking:active_count())
eq("requesting", context.state)
p:resolve("success", " return 'implemented!'")
assert.is_true(finished_called)
- eq(0, state:active_request_count())
+ eq(0, state.tracking:active_count())
eq("success", context.state)
eq("success", finished_status)
end)
diff --git a/lua/99/test/test_utils.lua b/lua/99/test/test_utils.lua
index 28fe1e4..554d54b 100644
--- a/lua/99/test/test_utils.lua
+++ b/lua/99/test/test_utils.lua
@@ -113,6 +113,7 @@ end
--- @return _99.Options
function M.get_test_setup_options(opts, provider)
opts = opts or {}
+ opts.tmp_dir = opts.tmp_dir or vim.fn.tempname()
opts.provider = provider
opts.logger = {
error_cache_level = Levels.ERROR,
diff --git a/lua/99/test/tracking_spec.lua b/lua/99/test/tracking_spec.lua
index 535f203..8d64660 100644
--- a/lua/99/test/tracking_spec.lua
+++ b/lua/99/test/tracking_spec.lua
@@ -1,86 +1,60 @@
-- luacheck: globals describe it assert
-local Prompt = require("99.prompt")
+local _99 = require("99")
local Tracking = require("99.state.tracking")
+local test_utils = require("99.test.test_utils")
local eq = assert.are.same
-local function data_for(operation)
- if operation == "tutorial" then
- return {
- type = "tutorial",
- xid = 1,
- buffer = 0,
- window = 0,
- tutorial = {},
- }
- end
-
- if operation == "visual" then
- return {
- type = "visual",
- buffer = 0,
- file_type = "lua",
- range = {
- start_row = 1,
- start_col = 1,
- end_row = 1,
- end_col = 1,
- },
- }
- end
-
- if operation == "vibe" then
- return {
- type = "vibe",
- response = "",
- qfix_items = {},
- }
- end
-
- return {
- type = "search",
- response = "",
- qfix_items = {},
- }
-end
-
-local function track_request(state, operation, started_at, status)
- local prompt = Prompt.deserialize(state, {
- user_prompt = string.format("%s-%d", operation, started_at),
- data = data_for(operation),
- })
- prompt.started_at = started_at
- prompt.state = status
- state.tracking:track(prompt)
+local function run(provider, operation, status, prompt)
+ _99[operation]({ additional_prompt = prompt })
+ provider:resolve(status, "result")
end
describe("tracking", function()
- it("serialize respects Tracking.serialized_counts", function()
- local state = {}
- state.tracking = Tracking.new(state, nil)
+ it("serializes requests based on configured counts", function()
+ local previous_counts = vim.deepcopy(Tracking.__config.serialize_count)
+ Tracking.setup({
+ serialize_counts = {
+ vibe = 1,
+ search = 1,
+ tutorial = 3,
+ visual = 0,
+ },
+ })
- local expected_total = 0
- local started_at = 0
- for operation, count in pairs(Tracking.serialize_counts) do
- expected_total = expected_total + count
- for _ = 1, count + 2 do
- started_at = started_at + 1
- track_request(state, operation, started_at, "success")
- end
+ local provider = test_utils.TestProvider.new()
+ _99.setup(test_utils.get_test_setup_options({
+ in_flight_options = { enable = false },
+ }, provider))
+ test_utils.create_file({ "local value = 1" }, "lua", 1, 0)
- started_at = started_at + 1
- track_request(state, operation, started_at, "failed")
- end
+ run(provider, "search", "success", "search one")
+ run(provider, "search", "success", "search two")
+ run(provider, "vibe", "success", "vibe one")
+ run(provider, "vibe", "success", "vibe two")
+ run(provider, "tutorial", "success", "tutorial one")
+ run(provider, "tutorial", "success", "tutorial two")
+ run(provider, "tutorial", "success", "tutorial three")
+ run(provider, "tutorial", "success", "tutorial four")
+ run(provider, "search", "failed", "search failed")
+
+ local serialized = _99.__get_state().tracking:serialize()
+ local actual_counts = {
+ search = 0,
+ vibe = 0,
+ tutorial = 0,
+ visual = 0,
+ }
- local serialized = state.tracking:serialize()
- local actual_counts = {}
for _, request in ipairs(serialized.requests) do
- local operation = request.data.type
- actual_counts[operation] = (actual_counts[operation] or 0) + 1
+ actual_counts[request.data.type] = actual_counts[request.data.type] + 1
end
- eq(expected_total, #serialized.requests)
- for operation, count in pairs(Tracking.serialize_counts) do
- eq(count, actual_counts[operation] or 0)
- end
+ eq(1, actual_counts.search)
+ eq(1, actual_counts.vibe)
+ eq(3, actual_counts.tutorial)
+ eq(0, actual_counts.visual)
+ eq(5, #serialized.requests)
+
+ Tracking.__config.serialize_count = previous_counts
end)
end)
diff --git a/lua/99/test/visual_spec.lua b/lua/99/test/visual_spec.lua
index 498c47f..5dab250 100644
--- a/lua/99/test/visual_spec.lua
+++ b/lua/99/test/visual_spec.lua
@@ -61,7 +61,7 @@ describe("visual", function()
visual_call_with_range(context, range)
- eq(1, state:active_request_count())
+ eq(1, state.tracking:active_count())
eq(content, r(buffer))
p:resolve("success", " return 'implemented!'")
@@ -74,6 +74,7 @@ describe("visual", function()
}
eq(expected_state, r(buffer))
-- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision
+ -- TODO: validate if this is true..
end)
it("should handle multi-line replacement", function()
@@ -90,7 +91,7 @@ describe("visual", function()
visual_call_with_range(context, range)
- eq(1, state:active_request_count())
+ eq(1, state.tracking:active_count())
eq(multi_line_content, r(buffer))
p:resolve("success", " local x = 1\n local y = 2\n return x + y")
diff --git a/lua/99/utils.lua b/lua/99/utils.lua
index 88abcea..1dcd09e 100644
--- a/lua/99/utils.lua
+++ b/lua/99/utils.lua
@@ -65,7 +65,7 @@ function M.write_file_json_safe(obj, path)
return
end
- local obj_str = ""
+ local obj_str
ok, obj_str = pcall(vim.json.encode, obj)
if not ok then
return
diff --git a/lua/99/window/status-window.lua b/lua/99/window/status-window.lua
index bea4857..c815e4a 100644
--- a/lua/99/window/status-window.lua
+++ b/lua/99/window/status-window.lua
@@ -80,7 +80,7 @@ function StatusWindow:_run_loop()
local active_window = Window.has_active_status_window()
local active_other_window = Window.has_active_windows()
- local active_requests = self._99.tracking:active_request_count()
+ local active_requests = self._99.tracking:active_count()
if
active_window == false and active_other_window
or active_window and active_requests > 0
@@ -98,7 +98,7 @@ function StatusWindow:_run_loop()
end
local throb = Throbber.new(function(throb)
- local count = self._99.tracking:active_request_count()
+ local count = self._99.tracking:active_count()
local win_valid = Window.valid(win)
if count == 0 or not win_valid then
@@ -110,7 +110,7 @@ function StatusWindow:_run_loop()
throb .. " requests(" .. tostring(count) .. ") " .. throb,
}
- for _, c in ipairs(self._99.tracking.history) do
+ for _, c in ipairs(self._99.tracking:active()) do
if c.state == "requesting" then
table.insert(lines, c.operation)
end