summaryrefslogtreecommitdiff
path: root/lua/99/extensions/agents/init.lua
blob: b0af7f86f2e856b03ae5fe2001b447d341553aa6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
local helpers = require("99.extensions.agents.helpers")
local M = {}

--- A rule file that can be referenced from a prompt with a `#` token.
--- @class _99.Agents.Rule
--- @docs included
--- @field name string Name matched against `#name` words in a prompt.
--- @field path string Path inserted by completion and matched by `get_rule_by_path`.
--- @field absolute_path string? When set, preferred over `path` when reading the file.

--- The set of rules known to the agents extension.
--- @class _99.Agents.Rules
--- @field custom _99.Agents.Rule[] Rules gathered from the `completion.custom_rules` locations.
--- @field by_name table<string, _99.Agents.Rule[]> Rules grouped by `name`; several rules may share one name.

--- Agent-level state exposing its rule set.
--- @class _99.Agents.Agent
--- @field rules _99.Agents.Rules

--- Append each rule to the list stored under its name in `map`, creating the
--- list on first use. Existing entries in `map` are kept and extended in order.
--- @param map table<string, _99.Agents.Rule[]>
--- @param rules _99.Agents.Rule[]
local function add_rule_by_name(map, rules)
  for _, rule in ipairs(rules) do
    local bucket = map[rule.name]
    if bucket == nil then
      bucket = {}
      map[rule.name] = bucket
    end
    bucket[#bucket + 1] = rule
  end
end

--- Load every custom rule listed by `helpers.ls` for each entry in
--- `_99.completion.custom_rules`, and index the result by rule name.
--- @param _99 _99.State
--- @return _99.Agents.Rules
function M.rules(_99)
  local collected = {}
  local sources = _99.completion.custom_rules or {}
  for _, source in ipairs(sources) do
    for _, rule in ipairs(helpers.ls(source)) do
      collected[#collected + 1] = rule
    end
  end

  local index = {}
  add_rule_by_name(index, collected)
  return { by_name = index, custom = collected }
end

--- Return a shallow copy of the custom rule list. An empty list is returned
--- when `rules.custom` is absent.
--- @param rules _99.Agents.Rules
--- @return _99.Agents.Rule[]
function M.rules_to_items(rules)
  local copy = {}
  for position, rule in ipairs(rules.custom or {}) do
    copy[position] = rule
  end
  return copy
end

--- Find the first custom rule whose `path` equals `path`.
--- @param rules _99.Agents.Rules
--- @param path string
--- @return _99.Agents.Rule | nil
function M.get_rule_by_path(rules, path)
  local match = nil
  for _, candidate in ipairs(rules.custom or {}) do
    if candidate.path == path then
      match = candidate
      break
    end
  end
  return match
end

--- Report whether `token` names a known custom rule, matching either the
--- rule's `path` or its `name`.
--- @param rules _99.Agents.Rules
--- @param token string
--- @return boolean
function M.is_rule(rules, token)
  for _, candidate in ipairs(rules.custom or {}) do
    local hit = candidate.path == token or candidate.name == token
    if hit then
      return true
    end
  end
  return false
end

--- Collect the rules referenced as `#name` words in a prompt.
---
--- The prompt is split on whitespace. Each word starting with `#` is looked up,
--- without the leading `#`, in `rules.by_name`. Each matching name is reported
--- once, in order of first appearance. Every rule registered under that name is
--- appended to the flat `rules` list.
--- @param rules _99.Agents.Rules
--- @param prompt string
--- @return {names: string[], rules: _99.Agents.Rule[]}
function M.by_name(rules, prompt)
  -- Tolerate a Rules table without an index, matching the `custom or {}`
  -- guard used by the other accessors in this module.
  local index = rules.by_name or {}

  --- @type table<string, boolean>
  local found = {}

  --- @type string[]
  local names = {}

  --- @type _99.Agents.Rule[]
  local out_rules = {}
  for word in prompt:gmatch("%S+") do
    if word:sub(1, 1) == "#" then
      local w = word:sub(2)
      local rules_by_name = index[w]
      -- Skip names already emitted so repeated `#name` tokens add rules once.
      if rules_by_name and found[w] == nil then
        for _, r in ipairs(rules_by_name) do
          table.insert(out_rules, r)
        end
        table.insert(names, w)
        found[w] = true
      end
    end
  end

  return {
    names = names,
    rules = out_rules,
  }
end

--- Read a rule file and wrap its contents in `<name>...</name>` tags.
---
--- The file at `rule.absolute_path` (or `rule.path` when that is unset) is
--- read in full. Returns nil when the file cannot be opened or read; for
--- example, opening a directory may succeed while reading it fails.
--- @param rule _99.Agents.Rule
--- @return string | nil
function M.get_rule_content(rule)
  local file_path = rule.absolute_path or rule.path
  local ok, file = pcall(io.open, file_path, "r")
  if not ok or not file then
    return nil
  end

  local ok_read, content = pcall(file.read, file, "*a")
  -- Release the handle on every path, including a failed read.
  pcall(file.close, file)

  -- file:read reports some failures by returning nil rather than raising.
  if not ok_read or content == nil then
    return nil
  end
  return string.format("<%s>\n%s\n</%s>", rule.name, content, rule.name)
end

--- Build the `#`-triggered completion source that offers rule files.
--- Every callback reads `_99.rules` when it runs, so it always sees the
--- current rule set.
--- @param _99 _99.State
--- @return _99.CompletionProvider
function M.completion_provider(_99)
  --- Turn every known rule into an LSP-style completion item.
  local function collect_items()
    local result = {}
    for _, rule in ipairs(M.rules_to_items(_99.rules)) do
      -- The preview text comes from the rule file, preferring its absolute path.
      local preview = helpers.head(rule.absolute_path or rule.path)
      result[#result + 1] = {
        label = rule.name,
        insertText = "#" .. rule.path,
        filterText = "#" .. rule.name,
        kind = 12, -- LSP CompletionItemKind.Value
        documentation = { kind = "markdown", value = preview },
        detail = "Rule: " .. rule.path,
      }
    end
    return result
  end

  --- Expand a completed token (a rule path) into the rule's wrapped content.
  local function resolve_token(token)
    local rule = M.get_rule_by_path(_99.rules, token)
    if rule then
      return M.get_rule_content(rule)
    end
    return nil
  end

  return {
    trigger = "#",
    name = "rules",
    get_items = collect_items,
    -- NOTE(review): is_valid accepts a rule name or path, but resolve only
    -- matches by path, so a bare-name token validates yet resolves to nil.
    is_valid = function(token)
      return M.is_rule(_99.rules, token)
    end,
    resolve = resolve_token,
  }
end

-- Export the module table.
return M