summaryrefslogtreecommitdiff
path: root/lua/99/ops/over-range.lua
blob: c1fca2af36e02294c95f6aa281d9c3a439cd570e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
-- TODO: This file still needs fixing.
-- It is an important file, so review changes here carefully.

local RequestStatus = require("99.ops.request_status")
local Mark = require("99.ops.marks")
local BaseProvider = require("99.providers")
local geo = require("99.geo")
local make_prompt = require("99.ops.make-prompt")
local CleanUp = require("99.ops.clean-up")

local make_clean_up = CleanUp.make_clean_up
local make_observer = CleanUp.make_observer

local Range = geo.Range
local Point = geo.Point

--- Extract the cleaned assistant text from a raw provider response.
--- Decoding is done with `vim.json.decode`. Parsing and cleaning are
--- delegated to the BareMetalProvider helpers. Failures are returned as
--- `nil, reason` instead of raising inside the async completion callback.
--- @param response string|nil raw response handed to `on_complete`
--- @return string|nil content cleaned message content, or nil on failure
--- @return string|nil err why the content could not be extracted
local function extract_response_content(response)
    if type(response) ~= "string" then
        return nil, "response was not a string"
    end

    local ok, decoded = pcall(vim.json.decode, response)
    if not ok then
        return nil, "failed to decode response json: " .. tostring(decoded)
    end

    local dt =
        BaseProvider.BareMetalProvider._parse_openai_response(
            decoded
        ) --[[ @as _A4.OpenAIResponse ]]

    local message = dt and dt.completion and dt.completion.message
    if not message or type(message.content) ~= "string" then
        return nil, "response did not contain completion message content"
    end

    local content =
        BaseProvider.BareMetalProvider._clean_response(message.content)
    if type(content) ~= "string" then
        return nil, "cleaning the response did not produce a string"
    end

    return content
end

--- Ask the AI to rewrite the current visual selection and replace the
--- selection with the result when the request succeeds.
---
--- A mark is placed above the selection (`Mark.mark_above_range`) and one at
--- its end point. The range is rebuilt from these marks when the
--- asynchronous request completes. Two `RequestStatus` indicators are
--- attached to the marks while the request is in flight, and the clean-up
--- registered on the context stops them.
--- @param context _99.Prompt
--- @param opts? _99.ops.Opts
local function over_range(context, opts)
    opts = opts or {}
    local logger = context.logger:set_area("visual")

    local data = context:visual_data()
    local range = data.range
    local top_mark = Mark.mark_above_range(range)
    local bottom_mark = Mark.mark_point(range.buffer, range.end_)
    context.marks.top_mark = top_mark
    context.marks.bottom_mark = bottom_mark

    logger:debug(
        "visual request start",
        "start",
        Point.from_mark(top_mark),
        "end",
        Point.from_mark(bottom_mark)
    )

    -- ai_stdout_rows may be unset. Default it once so the comparison cannot
    -- fail on nil and the status height and the streaming switch agree.
    local stdout_rows = context._99.ai_stdout_rows or 1
    local display_ai_status = stdout_rows > 1

    -- NOTE(review): 250 looks like a status refresh interval; confirm
    -- against RequestStatus.new.
    local top_status =
        RequestStatus.new(250, stdout_rows, "Implementing", top_mark)
    local bottom_status = RequestStatus.new(250, 1, "Implementing", bottom_mark)
    local clean_up = make_clean_up(function()
        top_status:stop()
        bottom_status:stop()
    end)

    local system_cmd = context._99.prompts.prompts.visual_selection(range)
    local prompt = make_prompt(context, system_cmd, opts)

    context:add_prompt_content(prompt)
    context:add_clean_up(clean_up)

    top_status:start()
    bottom_status:start()
    context:start_request(make_observer(context, {
        on_complete = function(status, response)
            if status == "cancelled" then
                logger:debug(
                    "request cancelled for visual selection, removing marks"
                )
            elseif status == "failed" then
                logger:error(
                    "request failed for visual selection",
                    "error response",
                    response or "no response provided"
                )
            elseif status == "success" then
                local valid = top_mark:is_valid() and bottom_mark:is_valid()
                if not valid then
                    logger:fatal(
                        -- luacheck: ignore 631
                        "the original visual_selection has been destroyed.  You cannot delete the original visual selection during a request"
                    )
                    return
                end

                -- Malformed or unexpected responses are reported here
                -- instead of raising inside this callback.
                local content, err = extract_response_content(response)
                if not content then
                    logger:error(
                        "could not read response for visual selection",
                        "error",
                        err
                    )
                    return
                end

                if vim.trim(content) == "" then
                    logger:debug(
                        "response was empty, visual replacement aborted"
                    )
                    return
                end

                context.req_response = content

                local new_range = Range.from_marks(top_mark, bottom_mark)
                local lines = vim.split(content, "\n")

                --- HACK: i am adding a new line here because above range will add a mark to the line above.
                --- that way this appears to be added to "the same line" as the visual selection was
                --- originally take from
                table.insert(lines, 1, "")

                new_range:replace_text(lines)
                context._99:sync()
            end
        end,
        on_stdout = function(line)
            -- Only stream raw model output when the top status has room
            -- for more than one row.
            if display_ai_status then
                top_status:push(line)
            end
        end,
    }))
end

return over_range