summaryrefslogtreecommitdiff
path: root/lua/99/test/test_utils.lua
blob: 554d54b4a521510a660fa9e615b36496f1fd7c6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
local Levels = require("99.logger.level")
local M = {}

-- Shared callback for every hook of the null observer; accepts and ignores
-- any arguments and returns nothing.
local function ignore() end

--- Observer whose callbacks all do nothing. `TestProvider:make_request`
--- falls back to it when no observer is passed.
--- @type _99.Providers.Observer
local DevNullObserver = {
  on_start = ignore,
  on_stdout = ignore,
  on_stderr = ignore,
  on_complete = ignore,
}

--- Let the event loop run: schedules a callback with `vim.schedule` and waits
--- (polling via `vim.wait`) until that callback has executed or the timeout
--- elapses.
--- @param timeout_ms integer? maximum time to wait in milliseconds (default 1000)
--- @return boolean ran true if the scheduled callback ran before the timeout
--- @return integer? code `vim.wait` failure code (-1 timeout, -2 interrupted) when `ran` is false
function M.next_frame(timeout_ms)
  -- Named `ticked` rather than `next` so the builtin `next` is not shadowed.
  local ticked = false
  vim.schedule(function()
    ticked = true
  end)

  -- Return vim.wait's result so callers can detect a timeout if they care;
  -- existing callers that ignore the return value behave as before.
  return vim.wait(timeout_ms or 1000, function()
    return ticked
  end)
end

-- Buffer handles created by M.create_file; M.clean_files deletes them and
-- resets this list.
M.created_files = {}

--- Snapshot of the arguments passed to `TestProvider:make_request`.
--- @class _99.test.ProviderRequest
--- @field query string
--- @field prompt _99.Prompt
--- @field observer _99.Providers.Observer
--- @field logger _99.Logger

--- Fake provider driven manually by tests: it records the latest request and
--- lets the test push stdout/stderr lines and completion to the observer.
--- @class _99.test.Provider : _99.Providers.BaseProvider
--- @field request _99.test.ProviderRequest?
local TestProvider = {}
TestProvider.__index = TestProvider

--- Construct a provider with no pending request.
--- @return _99.test.Provider
function TestProvider.new()
  local instance = {}
  setmetatable(instance, TestProvider)
  return instance
end

--- Begin a fake request. Logs a debug entry, fires the observer's `on_start`,
--- then stores the request so `resolve`, `stdout` and `stderr` can drive it.
--- A previously stored request is overwritten. Without an observer, a
--- do-nothing observer is used.
--- @param query string
--- @param prompt _99.Prompt
--- @param observer _99.Providers.Observer?
function TestProvider:make_request(query, prompt, observer)
  local area_logger = prompt.logger:set_area("TestProvider")
  area_logger:debug("make_request", "tmp_file", prompt.tmp_file)

  local active_observer = observer or DevNullObserver
  local pending = {
    query = query,
    prompt = prompt,
    observer = active_observer,
    logger = area_logger,
  }

  active_observer.on_start()
  self.request = pending
end

--- Finish the pending request by calling its observer's `on_complete`.
--- When the request's prompt reports it was cancelled, "cancelled" is passed
--- instead of `status`. Errors if no request is pending.
--- @param status _99.Prompt.EndingState
--- @param result string
function TestProvider:resolve(status, result)
  local request = self.request
  assert(request, "you cannot call resolve until make_request is called")

  -- Clear the pending request *before* notifying the observer. on_complete
  -- may start a follow-up request via make_request; clearing afterwards
  -- would silently discard that new request.
  self.request = nil

  local final_status = status
  if request.prompt:is_cancelled() then
    final_status = "cancelled"
  end
  request.observer.on_complete(final_status, result)
end

--- Forward one line of standard output to the pending request's observer.
--- Errors if no request is pending.
--- @param line string
function TestProvider:stdout(line)
  local pending = self.request
  assert(pending, "you cannot call stdout until make_request is called")
  local observer = pending.observer
  observer.on_stdout(line)
end

--- Forward one line of standard error to the pending request's observer.
--- Errors if no request is pending.
--- @param line string
function TestProvider:stderr(line)
  local pending = self.request
  assert(pending, "you cannot call stderr until make_request is called")
  local observer = pending.observer
  observer.on_stderr(line)
end

-- Export the provider class; M.test_setup constructs its provider through
-- this field.
M.TestProvider = TestProvider

--- Force-delete every buffer recorded in `M.created_files`, then reset the list.
--- Handles that are no longer valid (e.g. already wiped by the test) are skipped.
function M.clean_files()
  for _, bufnr in ipairs(M.created_files) do
    -- nvim_buf_delete raises on an invalid handle; that would abort cleanup
    -- before the list is reset and leak stale handles into later tests.
    if vim.api.nvim_buf_is_valid(bufnr) then
      vim.api.nvim_buf_delete(bufnr, { force = true })
    end
  end
  M.created_files = {}
end

--- Create a new buffer holding `contents`, make it the current buffer, set its
--- filetype and place the cursor. The handle is recorded in `M.created_files`
--- so `M.clean_files` can remove it.
---@param contents string[] lines to write into the buffer
---@param file_type string? filetype to assign (default "lua")
---@param row number? cursor row, 1-based as in nvim_win_set_cursor (default 1)
---@param col number? cursor column, 0-based as in nvim_win_set_cursor (default 0)
---@return integer bufnr handle of the created buffer
function M.create_file(contents, file_type, row, col)
  assert(type(contents) == "table", "contents must be a table of strings")
  file_type = file_type or "lua"
  local bufnr = vim.api.nvim_create_buf(false, false)
  -- Register the buffer immediately: if a later step raises (for example an
  -- out-of-range cursor position), clean_files can still remove it.
  table.insert(M.created_files, bufnr)

  vim.api.nvim_set_current_buf(bufnr)
  vim.bo[bufnr].ft = file_type
  vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, contents)
  vim.api.nvim_win_set_cursor(0, { row or 1, col or 0 })

  return bufnr
end

--- Fill `opts` (or a new table when nil) with settings for tests and return it.
--- The passed table itself is modified.
---  * `tmp_dir` defaults to a fresh `vim.fn.tempname()` path.
---  * `provider` and `logger` are always overwritten.
---  * `in_flight_options` gets short test timings only when not already set.
--- @param opts _99.Options | nil
--- @param provider _99.Providers.BaseProvider
--- @return _99.Options
function M.get_test_setup_options(opts, provider)
  local options = opts or {}

  if not options.tmp_dir then
    options.tmp_dir = vim.fn.tempname()
  end

  options.provider = provider

  -- Replaces any caller-supplied logger config with a "print" logger.
  options.logger = {
    error_cache_level = Levels.ERROR,
    type = "print",
  }

  if not options.in_flight_options then
    local throbber = {
      tick_time = 10,
      throb_time = 1000,
      cooldown_time = 500,
    }
    options.in_flight_options = {
      throbber_opts = throbber,
      in_flight_interval = 10,
      enable = true,
    }
  end

  return options
end

--- Run `require("99").setup` with test options bound to a fresh TestProvider,
--- then open a buffer containing `content` with the cursor at (`row`, `col`).
--- @param content string[]
--- @param row number
--- @param col number
--- @param lang string? filetype for the buffer (default "lua")
--- @param opts _99.Options | nil
--- @return _99.test.Provider, number
function M.test_setup(content, row, col, lang, opts)
  local provider = M.TestProvider.new()
  local setup_opts = M.get_test_setup_options(opts, provider)
  require("99").setup(setup_opts)

  local filetype = lang or "lua"
  local buffer = M.create_file(content, filetype, row, col)
  return provider, buffer
end

--- Read all lines of `buffer` (short name for use in test assertions).
--- @param buffer number
--- @return string[]
function M.r(buffer)
  local first_line, last_line, strict_indexing = 0, -1, false
  return vim.api.nvim_buf_get_lines(buffer, first_line, last_line, strict_indexing)
end

return M