summaryrefslogtreecommitdiff
path: root/lua/99/state/tracking.lua
blob: bfeb430ce97dcd812d5d89097a3b006405d9b3b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
local Prompt = require("99.prompt")

--- @class _99.State.Tracking.Serialized
--- @field requests _99.Prompt.Serialized[]

--- @class _99.State.Tracking.Config.Options.Counts
--- @field vibe number | nil
--- @field search number | nil
--- @field tutorial number | nil
--- @field visual number | nil
---
--- @class _99.State.Tracking.Config.Options
--- @field serialize_counts _99.State.Tracking.Config.Options.Counts | nil

--- @class _99.State.Tracking.Config
--- @field serialize_count table<_99.Prompt.Operation, number>

-- Prompt bookkeeping: an ordered `history` list plus an `id_to_request`
-- table keyed by each prompt's `xid`.
--- @class _99.State.Tracking
--- @docs base
--- @field history _99.Prompt[]
--- @field id_to_request table<number, _99.Prompt>
--- @field setup fun(opts: _99.State.Tracking.Config.Options): nil
local Tracking = {}
-- Instances built with setmetatable(..., Tracking) resolve their methods
-- through this table.
Tracking.__index = Tracking

--- Builds a tracker, optionally refilling it from a previously serialized
--- snapshot. Each entry of `previous_state.requests` is passed through
--- `Prompt.deserialize` and then appended to the history and indexed by
--- its `xid`.
--- @param _99 _99.State
--- @param previous_state _99.State.Tracking.Serialized | nil
--- @return _99.State.Tracking
function Tracking.new(_99, previous_state)
  local self = setmetatable({
    history = {},
    id_to_request = {},
  }, Tracking) --[[ @as _99.State.Tracking]]

  -- A missing snapshot, or one without a `requests` list, leaves the
  -- tracker empty.
  local saved = previous_state and previous_state.requests or {}
  for _, raw in ipairs(saved) do
    local restored = Prompt.deserialize(_99, raw)
    self.history[#self.history + 1] = restored
    self.id_to_request[restored.xid] = restored
  end

  return self
end

--- Appends a prompt to the history and indexes it by its `xid`.
--- Raises "context is not valid" when `context:valid()` is falsy.
--- @param context _99.Prompt
function Tracking:track(context)
  assert(context:valid(), "context is not valid")

  local history = self.history
  history[#history + 1] = context
  self.id_to_request[context.xid] = context
end

--- Counts the history entries whose state is anything other than
--- "requesting".
--- @return number
function Tracking:completed()
  local finished = 0
  for _, prompt in ipairs(self.history) do
    local in_flight = prompt.state == "requesting"
    if not in_flight then
      finished = finished + 1
    end
  end
  return finished
end

--- Forgets every prompt that is no longer "requesting": it is removed from
--- `id_to_request`, and `history` is replaced by a new list holding only
--- the prompts that are still in flight.
function Tracking:clear_history()
  local in_flight = {}
  for _, prompt in ipairs(self.history) do
    if prompt.state ~= "requesting" then
      self.id_to_request[prompt.xid] = nil
    else
      in_flight[#in_flight + 1] = prompt
    end
  end
  self.history = in_flight
end

--- Calls `stop()` on every prompt currently in the "requesting" state.
function Tracking:stop_all_requests()
  -- active() returns a freshly built list, so this loop walks a snapshot
  -- rather than self.history itself.
  local in_flight = self:active()
  for _, prompt in pairs(in_flight) do
    prompt:stop()
  end
end

--- Collects the prompts whose state is "requesting" into a new list,
--- preserving their order in the history.
--- @return _99.Prompt[]
function Tracking:active()
  local out = {}
  -- ipairs (not pairs) so the result order is the history order; pairs
  -- makes no ordering guarantee.
  for _, r in ipairs(self.history) do
    if r.state == "requesting" then
      out[#out + 1] = r
    end
  end
  return out
end

--- Counts the history entries whose state is "requesting".
--- @return number
function Tracking:active_count()
  local count = 0
  -- Walk the history with ipairs, the same way completed() does, so both
  -- counters look at exactly the same entries.
  for _, r in ipairs(self.history) do
    if r.state == "requesting" then
      count = count + 1
    end
  end
  return count
end

--- Returns, in history order, every prompt whose `operation` equals the
--- given type.
--- @param type "search" | "visual" | "tutorial"
--- @return _99.Prompt[]
function Tracking:request_by_type(type)
  local matches = {} --[[ @as _99.Prompt[] ]]
  for _, prompt in ipairs(self.history) do
    local same_operation = prompt.operation == type
    if same_operation then
      matches[#matches + 1] = prompt
    end
  end
  return matches
end

--- Returns, in history order, every prompt whose state is "success".
--- @return _99.Prompt[]
function Tracking:successful()
  local succeeded = {}
  for _, prompt in ipairs(self.history) do
    local ok = prompt.state == "success"
    if ok then
      succeeded[#succeeded + 1] = prompt
    end
  end
  return succeeded
end

--- Produces the persistable snapshot of this tracker.
---
--- Only prompts whose state is "success" are considered. For each
--- operation, at most `Tracking.__config.serialize_count[operation]` of the
--- most recently started prompts are kept. The combined result is ordered
--- by `started_at`, newest first, and each prompt is converted with its own
--- `serialize()` method.
--- @return _99.State.Tracking.Serialized
function Tracking:serialize()
  local limits = Tracking.__config.serialize_count

  -- Bucket the successful prompts by operation. A missing limit is treated
  -- as 0, so an operation without a configured count (or a prompt without
  -- an operation) is skipped instead of raising on a nil comparison or a
  -- nil table index.
  --- @type table<_99.Prompt.Operation, _99.Prompt[]>
  local by_operation = {}
  for _, prompt in ipairs(self.history) do
    local op = prompt.operation
    if prompt.state == "success" and (limits[op] or 0) > 0 then
      local bucket = by_operation[op]
      if not bucket then
        bucket = {}
        by_operation[op] = bucket
      end
      bucket[#bucket + 1] = prompt
    end
  end

  local function newest_first(a, b)
    return a.started_at > b.started_at
  end

  -- Keep the newest `max` prompts of every configured operation.
  --- @type _99.Prompt[]
  local requests = {}
  for op, max in pairs(limits) do
    local bucket = by_operation[op]
    if bucket then
      table.sort(bucket, newest_first)
      local taken = 0
      for _, prompt in ipairs(bucket) do
        if taken >= max then
          break
        end
        requests[#requests + 1] = prompt
        taken = taken + 1
      end
    end
  end

  -- pairs() visits operations in no particular order, so re-sort the
  -- merged list to get a single newest-first sequence.
  table.sort(requests, newest_first)

  local serialized = {}
  for _, prompt in ipairs(requests) do
    serialized[#serialized + 1] = prompt:serialize()
  end

  return {
    requests = serialized,
  }
end

-- Module-level configuration stored on the Tracking table itself;
-- Tracking.setup overwrites these values in place.
-- serialize_count: per operation, the maximum number of successful prompts
-- that Tracking:serialize keeps. A value of 0 keeps none for that operation.
Tracking.__config = {
  serialize_count = {
    vibe = 1,
    search = 1,
    tutorial = 3,
    visual = 0,
  },
}

--- Overrides the per-operation serialization limits held in
--- `Tracking.__config.serialize_count`. Operations omitted from
--- `opts.serialize_counts` keep their current value. Calling this with no
--- options, or without `serialize_counts`, changes nothing.
--- @param opts _99.State.Tracking.Config.Options
function Tracking.setup(opts)
  -- Tolerate a bare setup() call instead of failing on a nil index.
  local overrides = opts and opts.serialize_counts
  if not overrides then
    return
  end

  --- note to agents: in Lua `0 or 1` evaluates to 0; `or` only falls
  --- through on nil or false. An explicit 0 in the options is therefore
  --- honoured here and is NOT replaced by the current value.
  local limits = Tracking.__config.serialize_count
  limits.vibe = overrides.vibe or limits.vibe
  limits.search = overrides.search or limits.search
  limits.tutorial = overrides.tutorial or limits.tutorial
  limits.visual = overrides.visual or limits.visual
end

return Tracking