summaryrefslogtreecommitdiff
path: root/lua/99/ops/clean-up.lua
blob: ce7cc747efac731878d63cd7bea9e9956ebaed67 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
local M = {}

--- @alias _99.Providers.on_complete fun(status: _99.Prompt.EndingState, response: string): nil
--- @class _99.Providers.PartialObserver
--- @field on_complete _99.Providers.on_complete
--- @field on_stdout? fun(line: string): nil
--- @field on_stderr? fun(line: string): nil
--- @field on_start? fun(): nil

--- Wrap a caller-supplied partial observer, or a bare completion callback,
--- into an observer table on which every hook is safe to call.
---
--- * `on_start`, `on_stdout` and `on_stderr` forward to the caller's hook when
---   one was supplied and are no-ops otherwise.
--- * `on_complete` runs the caller's callback under `pcall`, so an error raised
---   inside it cannot prevent the scheduled `context:stop()` and
---   `context._99:sync()` from running. A caught error is reported through
---   `vim.notify` rather than being silently dropped.
---
--- @param context _99.Prompt
--- @param obs_or_fn _99.Providers.PartialObserver | _99.Providers.on_complete
--- @return _99.Providers.Observer
M.make_observer = function(context, obs_or_fn)
    -- A table is used as the observer directly; anything else is treated as
    -- the completion callback itself.
    --- @type _99.Providers.PartialObserver
    local obs = type(obs_or_fn) == "table" and obs_or_fn
        or {
            on_complete = obs_or_fn,
        }
    return {
        on_start = function()
            if obs.on_start then
                obs.on_start()
            end
        end,
        on_complete = function(status, res)
            -- Guard the user callback: a failure in it must not skip the
            -- stop/sync below. An absent callback is simply not called.
            local ok, err = true, nil
            if obs.on_complete ~= nil then
                ok, err = pcall(obs.on_complete, status, res)
            end
            vim.schedule(function()
                context:stop()
                context._99:sync()
                -- Report after cleanup so that a problem while notifying can
                -- never leave the context running. Done inside the scheduled
                -- callback, where the cleanup itself is already deferred to.
                if not ok then
                    vim.notify(
                        "99: on_complete callback failed: " .. tostring(err),
                        vim.log.levels.ERROR
                    )
                end
            end)
        end,
        on_stderr = function(line)
            if obs.on_stderr then
                obs.on_stderr(line)
            end
        end,
        on_stdout = function(line)
            if obs.on_stdout then
                obs.on_stdout(line)
            end
        end,
    } --[[@as _99.Providers.Observer ]]
end

--- Build a run-once wrapper around `clean_up_fn`.
---
--- The first call to the returned function invokes `clean_up_fn`; every later
--- call does nothing. The wrapper is marked as spent *before* `clean_up_fn`
--- runs, so a re-entrant call made from inside `clean_up_fn`, or a call made
--- after `clean_up_fn` raised an error, is ignored as well.
---@param clean_up_fn fun(): nil
---@return fun(): nil
M.make_clean_up = function(clean_up_fn)
    local pending = true
    return function()
        if pending then
            pending = false
            clean_up_fn()
        end
    end
end

return M