summaryrefslogtreecommitdiff
path: root/lua/99/extensions/completions.lua
blob: 742f6e3ee72beca6b906a446123b605881483fd1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
--- A provider for completion tokens (#rules, @files) used in the prompt.
--- Providers are keyed by their trigger: registering a provider whose trigger
--- matches an existing one replaces that earlier provider.
--- @class _99.CompletionProvider
--- @field trigger string Prefix that introduces this provider's tokens in prompt text
--- @field name string Provider name (not read by this module)
--- @field get_items fun(): CompletionItem[] Produces the completion items returned for this trigger
--- @field is_valid fun(token: string): boolean Receives the text after the trigger; a falsy result makes parse skip the token
--- @field resolve fun(token: string): string|nil Maps an accepted token to its content; a nil result makes parse skip the token

--- One resolved prompt token, as collected by M.parse.
--- @class _99.Reference
--- @field content string The value returned by the provider's resolve function

--- Registered providers in registration order; private to this module.
--- @type _99.CompletionProvider[]
local providers = {}

--- Module table returned at the end of this file.
local M = {}

--- Escape the Lua pattern magic characters in a string so that the result
--- matches the original text literally when used inside a pattern.
--- The characters ^ $ ( ) % . [ ] * + - ? are each prefixed with %.
--- @param str string The string to escape
--- @return string The escaped string safe for use in Lua patterns
function M.escape_pattern(str)
    -- string.gsub returns two values (the new string and the substitution
    -- count). The outer parentheses truncate the result to the string alone so
    -- the count cannot leak into a caller's argument list or table constructor.
    return (str:gsub("([%%%^%$%(%)%.%[%]%*%+%-%?])", "%%%1"))
end

--- Add a provider to the registry. If a provider with the same trigger is
--- already registered it is replaced in place (keeping its position);
--- otherwise the new provider is appended.
--- @param provider _99.CompletionProvider
function M.register(provider)
    local slot
    for index = 1, #providers do
        if providers[index].trigger == provider.trigger then
            slot = index
            break
        end
    end

    if slot then
        providers[slot] = provider
    else
        providers[#providers + 1] = provider
    end
end

--- Collect the trigger of every registered provider, in registration order.
--- A fresh table is built on each call.
--- @return string[]
function M.get_trigger_characters()
    local collected = {}
    for index = 1, #providers do
        collected[#collected + 1] = providers[index].trigger
    end
    return collected
end

--- Build a pattern string of the form `[<triggers>]\k*`: a bracket collection
--- holding every registered trigger, followed by `\k*`.
--- Triggers are concatenated as-is, without any escaping.
--- @return string
function M.get_keyword_pattern()
    local pieces = {}
    for index = 1, #providers do
        pieces[#pieces + 1] = providers[index].trigger
    end
    local collection = table.concat(pieces)
    return "[" .. collection .. "]\\k*"
end

--- Scan prompt text for provider tokens and resolve them into references.
--- For each provider (in registration order), every run of non-whitespace
--- characters that follows the provider's trigger is taken as a token. Tokens
--- the provider accepts via is_valid are passed to resolve, and each truthy
--- result becomes one reference. Non-string input yields an empty list.
--- @param prompt_text string|boolean|nil
--- @return _99.Reference[]
function M.parse(prompt_text)
    local references = {}

    if type(prompt_text) == "string" then
        for index = 1, #providers do
            local provider = providers[index]
            -- The trigger is escaped so it is matched literally.
            local token_pattern = M.escape_pattern(provider.trigger) .. "%S+"

            for match in prompt_text:gmatch(token_pattern) do
                -- Strip the leading trigger to get the bare token.
                local token = match:sub(#provider.trigger + 1)
                -- resolve is only called for tokens the provider accepts.
                local resolved = provider.is_valid(token)
                    and provider.resolve(token)
                if resolved then
                    references[#references + 1] = { content = resolved }
                end
            end
        end
    end

    return references
end

--- Return the completion items of the first registered provider whose trigger
--- equals the given character, or an empty table when no provider matches.
--- @param trigger_char string
--- @return CompletionItem[]
function M.get_completions(trigger_char)
    local matched
    for index = 1, #providers do
        local candidate = providers[index]
        if candidate.trigger == trigger_char then
            matched = candidate
            break
        end
    end

    if matched == nil then
        return {}
    end
    return matched.get_items()
end

--- Discard every registered provider by replacing the module's provider list
--- with a fresh empty table.
--- NOTE(review): the leading underscore suggests this is an internal hook,
--- presumably for resetting state between tests — confirm against callers.
function M._reset()
    providers = {}
end

return M