summaryrefslogtreecommitdiff
path: root/lua/99/test/visual_spec.lua
blob: 498c47f0e8d0c318c2167975a841f2180c0553e2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
-- luacheck: globals describe it assert
local _99 = require("99")
local test_utils = require("99.test.test_utils")
local eq = assert.are.same
local Levels = require("99.logger.level")
local Range = require("99.geo").Range
local Point = require("99.geo").Point
local visual_fn = require("99.ops.over-range")
local Prompt = require("99.prompt")

--- Configure the plugin with a fresh test provider, create a test file from
--- `content`, and build the range describing the visual selection.
---
--- Row/column arguments are 1-based: they are passed to `Point:from_1_based`.
--- @param content string[] lines handed to `test_utils.create_file`
--- @param start_row number row where the selection starts
--- @param start_col number column where the selection starts
--- @param end_row number row where the selection ends
--- @param end_col number column where the selection ends
--- @return _99.test.Provider, number, _99.Range
local function setup(content, start_row, start_col, end_row, end_col)
  local provider = test_utils.TestProvider.new()
  _99.setup({
    provider = provider,
    logger = { error_cache_level = Levels.ERROR },
  })

  local bufnr = test_utils.create_file(content, "lua", start_row, start_col)

  -- The visual selection is modelled as a Range between two points.
  local selection = Range:new(
    bufnr,
    Point:from_1_based(start_row, start_col),
    Point:from_1_based(end_row, end_col)
  )

  return provider, bufnr, selection
end

--- Read every line of a buffer.
--- @param buffer number buffer handle
--- @return string[] all lines of the buffer
local function r(buffer)
  -- 0 .. -1 spans the whole buffer; strict indexing is disabled.
  local first, last, strict = 0, -1, false
  return vim.api.nvim_buf_get_lines(buffer, first, last, strict)
end

-- Shared buffer fixture: a small Lua function whose body (line 2) is a single
-- TODO comment. Tests that use it select that line via
-- setup(content, 2, 1, 2, 23) and compare the buffer against this table to
-- check whether anything was modified.
local content = {
  "local function foo()",
  "    -- TODO: implement",
  "end",
}

--- Store `range` on the prompt context and run the over-range operation with
--- a fixed additional prompt.
--- @param context _99.Prompt
--- @param range _99.Range
local function visual_call_with_range(context, range)
  local opts = { additional_prompt = "test prompt" }
  context.data.range = range
  visual_fn(context, opts)
end

describe("visual", function()
  it("should replace visual selection with AI response", function()
    local provider, bufnr, selection = setup(content, 2, 1, 2, 23)
    local state = _99.__get_state()

    visual_call_with_range(Prompt.visual(state), selection)

    -- One request is pending and the buffer has not been touched yet.
    eq(1, state:active_request_count())
    eq(content, r(bufnr))

    provider:resolve("success", "    return 'implemented!'")
    test_utils.next_frame()

    -- The selected line (row 2) now holds the provider's response.
    eq({
      "local function foo()",
      "    return 'implemented!'",
      "end",
    }, r(bufnr))
    -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision
  end)

  it("should handle multi-line replacement", function()
    local original_lines = {
      "local function bar()",
      "    -- TODO: implement",
      "    -- more comments",
      "    -- even more",
      "end",
    }
    local provider, bufnr, selection = setup(original_lines, 2, 1, 4, 17)
    local state = _99.__get_state()

    visual_call_with_range(Prompt.visual(state), selection)

    -- One request is pending and the buffer has not been touched yet.
    eq(1, state:active_request_count())
    eq(original_lines, r(bufnr))

    local response = "    local x = 1\n    local y = 2\n    return x + y"
    provider:resolve("success", response)
    test_utils.next_frame()

    -- Rows 2-4 (the selection) are replaced by the three response lines.
    eq({
      "local function bar()",
      "    local x = 1",
      "    local y = 2",
      "    return x + y",
      "end",
    }, r(bufnr))
    -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision
  end)

  it("should cancel request when stop_all_requests is called", function()
    local p, buffer, range = setup(content, 2, 1, 2, 23)
    local state = _99.__get_state()
    local context = Prompt.visual(state)

    visual_call_with_range(context, range)

    -- Starting the request must not modify the buffer.
    eq(content, r(buffer))

    -- Check that the request objects exist *before* dereferencing them, so a
    -- missing request fails with a clear assertion rather than a nil-index
    -- error.
    assert.is_not_nil(p.request)
    assert.is_not_nil(p.request.prompt)
    assert.is_false(p.request.prompt:is_cancelled())

    _99.stop_all_requests()
    test_utils.next_frame()

    assert.is_true(p.request.prompt:is_cancelled())

    -- A response that arrives after cancellation must be ignored.
    p:resolve("success", "    return 'should not appear'")
    test_utils.next_frame()

    -- Buffer should remain unchanged after cancellation
    eq(content, r(buffer))
  end)

  it("should handle error cases with graceful failures", function()
    local provider, bufnr, selection = setup(content, 2, 1, 2, 23)
    local prompt = Prompt.visual(_99.__get_state())

    visual_call_with_range(prompt, selection)
    eq(content, r(bufnr))

    -- Resolve the request as failed; the buffer must still match the fixture.
    provider:resolve("failed", "Something went wrong")
    test_utils.next_frame()
    eq(content, r(bufnr))
  end)

  it("should handle cancelled status gracefully", function()
    local p, buffer, range = setup(content, 2, 1, 2, 23)
    local state = _99.__get_state()
    -- Use the file-level `Prompt` module binding, like the other tests,
    -- instead of re-requiring "99.prompt" inline.
    local context = Prompt.visual(state)

    visual_call_with_range(context, range)

    -- Starting the request must not modify the buffer.
    eq(content, r(buffer))

    -- Manually cancel and resolve as cancelled
    p.request.prompt:cancel()
    p:resolve("cancelled", "Request was cancelled")
    test_utils.next_frame()

    -- Buffer should remain unchanged on cancellation
    eq(content, r(buffer))
  end)
end)