diff options
Diffstat (limited to 'lua/99/state/tracking.lua')
| -rw-r--r-- | lua/99/state/tracking.lua | 69 |
1 files changed, 66 insertions, 3 deletions
diff --git a/lua/99/state/tracking.lua b/lua/99/state/tracking.lua index 0e3da08..83b23a0 100644 --- a/lua/99/state/tracking.lua +++ b/lua/99/state/tracking.lua @@ -3,10 +3,25 @@ local Prompt = require("99.prompt") --- @class _99.State.Tracking.Serialized --- @field requests _99.Prompt.Serialized[] +--- @class _99.State.Tracking.Config.Options.Counts +--- @field vibe number | nil +--- @field search number | nil +--- @field tutorial number | nil +--- @field visual number | nil +--- +--- @class _99.State.Tracking.Config.Options +--- @field serialize_counts _99.State.Tracking.Config.Options.Counts | nil + +--- @class _99.State.Tracking.Config +--- @field serialize_counts table<_99.Prompt.Operation, number> + --- @class _99.State.Tracking +--- @docs base --- @field history _99.Prompt[] --- @field id_to_request table<number, _99.Prompt> +--- @field setup fun(opts: _99.State.Tracking.Config.Options): nil local Tracking = {} +Tracking.__index = Tracking --- @param _99 _99.State --- @param previous_state _99.State.Tracking.Serialized | nil @@ -60,7 +75,7 @@ function Tracking:clear_history() self.history = keep end -function Tracking:active_request_count() +function Tracking:active() local count = 0 for _, r in pairs(self.history) do if r.state == "requesting" then @@ -84,13 +99,61 @@ end --- @return _99.State.Tracking.Serialized function Tracking:serialize() - local requests = {} + local sc = Tracking.__config.serialize_count + + --- @type table<_99.Prompt.Operation, _99.Prompt[]> + local all_requests = {} for _, r in ipairs(self.history) do - table.insert(requests, r:serialize()) + local op = r.operation + all_requests[op] = all_requests[op] or {} + if r.state == "success" and sc[op] > 0 then + table.insert(all_requests[op], r) + end + end + for op, _ in pairs(sc) do + local r = all_requests[op] + table.sort(r, function(a, b) + return a.started_at > b.started_at + end) + end + + local requests = {} + for op, max in pairs(sc) do + local count = 0 + for _, request in ipairs(all_requests[op] or {}) do + if count >= max then + break + end + table.insert(requests, request:serialize()) + count = count + 1 + end end + return { requests = requests, } end +Tracking.__config = { + serialize_count = { + vibe = 1, + search = 1, + tutorial = 3, + visual = 0, + }, +} + +--- @param opts _99.State.Tracking.Config.Options +function Tracking.setup(opts) + local config = Tracking.__config + local opts_sa = opts.serialize_counts + if opts_sa then + local sa = config.serialize_count + sa.vibe = opts_sa.vibe or sa.vibe + sa.search = opts_sa.search or sa.search + sa.tutorial = opts_sa.tutorial or sa.tutorial + sa.visual = opts_sa.visual or sa.visual + end +end + return Tracking |
