summaryrefslogtreecommitdiff
path: root/lua/99/state/tracking.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/99/state/tracking.lua')
-rw-r--r--lua/99/state/tracking.lua69
1 files changed, 66 insertions, 3 deletions
diff --git a/lua/99/state/tracking.lua b/lua/99/state/tracking.lua
index 0e3da08..83b23a0 100644
--- a/lua/99/state/tracking.lua
+++ b/lua/99/state/tracking.lua
@@ -3,10 +3,25 @@ local Prompt = require("99.prompt")
--- @class _99.State.Tracking.Serialized
--- @field requests _99.Prompt.Serialized[]
+--- @class _99.State.Tracking.Config.Options.Counts
+--- @field vibe number | nil
+--- @field search number | nil
+--- @field tutorial number | nil
+--- @field visual number | nil
+---
+--- @class _99.State.Tracking.Config.Options
+--- @field serialize_counts _99.State.Tracking.Config.Options.Counts | nil
+
+--- @class _99.State.Tracking.Config
+--- @field serialize_counts table<_99.Prompt.Operation, number>
+
--- @class _99.State.Tracking
+--- @docs base
--- @field history _99.Prompt[]
--- @field id_to_request table<number, _99.Prompt>
+--- @field setup fun(opts: _99.State.Tracking.Config.Options): nil
local Tracking = {}
+Tracking.__index = Tracking
--- @param _99 _99.State
--- @param previous_state _99.State.Tracking.Serialized | nil
@@ -60,7 +75,7 @@ function Tracking:clear_history()
self.history = keep
end
-function Tracking:active_request_count()
+function Tracking:active()
local count = 0
for _, r in pairs(self.history) do
if r.state == "requesting" then
@@ -84,13 +99,61 @@ end
--- @return _99.State.Tracking.Serialized
function Tracking:serialize()
- local requests = {}
+ local sc = Tracking.__config.serialize_count
+
+ --- @type table<_99.Prompt.Operation, _99.Prompt[]>
+ local all_requests = {}
for _, r in ipairs(self.history) do
- table.insert(requests, r:serialize())
+ local op = r.operation
+ all_requests[op] = all_requests[op] or {}
+ if r.state == "success" and sc[op] > 0 then
+ table.insert(all_requests[op], r)
+ end
+ end
+ for op, _ in pairs(sc) do
+ local r = all_requests[op]
+ table.sort(r, function(a, b)
+ return a.started_at > b.started_at
+ end)
+ end
+
+ local requests = {}
+ for op, max in pairs(sc) do
+ local count = 0
+ for _, request in ipairs(all_requests[op] or {}) do
+ if count >= max then
+ break
+ end
+ table.insert(requests, request:serialize())
+ count = count + 1
+ end
end
+
return {
requests = requests,
}
end
+Tracking.__config = {
+ serialize_count = {
+ vibe = 1,
+ search = 1,
+ tutorial = 3,
+ visual = 0,
+ },
+}
+
+--- @param opts _99.State.Tracking.Config.Options
+function Tracking.setup(opts)
+ local config = Tracking.__config
+ local opts_sa = opts.serialize_counts
+ if opts_sa then
+ local sa = config.serialize_count
+ sa.vibe = opts_sa.vibe or sa.vibe
+ sa.search = opts_sa.search or sa.search
+ sa.tutorial = opts_sa.tutorial or sa.tutorial
+ sa.visual = opts_sa.visual or sa.visual
+ end
+end
+
return Tracking