diff options
55 files changed, 3386 insertions, 2586 deletions
diff --git a/.github/workflows/makefile.yml b/.github/workflows/makefile.yml index 9d2b733..116ae1a 100644 --- a/.github/workflows/makefile.yml +++ b/.github/workflows/makefile.yml @@ -11,14 +11,20 @@ jobs: runs-on: ubuntu-latest + env: + NVIM_VERSION: v0.11.5 + NVIM_TREESITTER_SHA: 2ba5ec184609a96b513bf4c53a20512d64e27f39 + XDG_DATA_HOME: ${{ github.workspace }}/.xdg/data + XDG_CACHE_HOME: ${{ github.workspace }}/.xdg/cache + steps: - uses: actions/checkout@v4 - name: Install Neovim run: | - wget -q https://github.com/neovim/neovim/releases/download/v0.10.3/nvim-linux64.tar.gz - tar xzf nvim-linux64.tar.gz - sudo mv nvim-linux64 /opt/nvim + wget -q https://github.com/neovim/neovim/releases/download/${NVIM_VERSION}/nvim-linux-x86_64.tar.gz + tar xzf nvim-linux-x86_64.tar.gz + sudo mv nvim-linux-x86_64 /opt/nvim sudo ln -s /opt/nvim/bin/nvim /usr/local/bin/nvim - name: Install luarocks @@ -36,9 +42,34 @@ jobs: chmod +x stylua sudo mv stylua /usr/local/bin/ + - name: Cache Neovim data (tree-sitter) + uses: actions/cache@v4 + with: + path: .xdg + key: ${{ runner.os }}-nvim-${{ env.NVIM_VERSION }}-ts-${{ env.NVIM_TREESITTER_SHA }} + + - name: Install tree-sitter CLI + run: | + node --version + npm --version + sudo npm install -g tree-sitter-cli + tree-sitter --version + - name: Install plenary.nvim run: | git clone https://github.com/nvim-lua/plenary.nvim.git ../plenary.nvim + - name: Install nvim-treesitter (pinned) + run: | + git clone https://github.com/nvim-treesitter/nvim-treesitter.git ../nvim-treesitter + git -C ../nvim-treesitter checkout ${NVIM_TREESITTER_SHA} + + - name: Install treesitter parsers + run: | + nvim --headless -u NONE -i NONE \ + -c "set rtp+=../nvim-treesitter" \ + -c "lua dofile('scripts/ci/install_treesitter_parsers.lua')" \ + -c "qa" + - name: Run pr_ready run: make pr_ready diff --git a/.stylua.toml b/.stylua.toml index a1cdc12..de815a2 100644 --- a/.stylua.toml +++ b/.stylua.toml @@ -1,6 +1,6 @@ column_width = 80 line_endings = "Unix" indent_type = "Spaces" -indent_width = 4 +indent_width = 2 quote_style = "AutoPreferDouble" @@ -31,6 +31,22 @@ I make the assumption you are using Lazy print_on_error = true, }, + --- A new feature that is centered around tags + completion = { + --- Defaults to .cursor/rules + cursor_rules = "<custom path to cursor rules>" + + --- A list of folders where you have your own agents + custom_rules = { + "scratch/custom_rules/", + }, + + --- What autocomplete do you use. We currently only + --- support cmp right now + source = "cmp", + + } + --- WARNING: if you change cwd then this is likely broken --- ill likely fix this in a later change --- @@ -64,10 +80,23 @@ I make the assumption you are using Lazy vim.keymap.set("v", "<leader>9s", function() _99.stop_all_requests() end) + + --- Example: Using rules + actions for custom behaviors + --- Create a rule file like ~/.rules/debug.md that defines custom behavior. + --- For instance, a "debug" rule could automatically add printf statements + --- throughout a function to help debug its execution flow. + vim.keymap.set("n", "<leader>9fd", function() + _99.fill_in_function() + end) end, }, ``` +## Completion +When prompting, if you have cmp installed as your autocomplete you can use an autocomplete for rule inclusion in your prompt. + +You can also specify a directory for rule inclusion, and this plugin auto looks into `.cursor/rules` unless you override it + ## API You can see the full api at [99 API](./lua/99/init.lua) diff --git a/lua/99/editor/init.lua b/lua/99/editor/init.lua index 70c24a3..ff0e8ab 100644 --- a/lua/99/editor/init.lua +++ b/lua/99/editor/init.lua @@ -1,4 +1,4 @@ return { - treesitter = require("99.editor.treesitter"), - -- lsp = require("99.editor.lsp"), + treesitter = require("99.editor.treesitter"), + -- lsp = require("99.editor.lsp"), } diff --git a/lua/99/editor/lsp.lua b/lua/99/editor/lsp.lua index 7f8e7d3..caf4b94 100644 --- a/lua/99/editor/lsp.lua +++ b/lua/99/editor/lsp.lua @@ -23,8 +23,8 @@ --- @param node _99.treesitter.Node The treesitter node to convert --- @return LspPosition The LSP-compatible position (0-based line and character) local function ts_node_to_lsp_position(node) - local start_row, start_col, _, _ = node:range() - return { line = start_row, character = start_col } + local start_row, start_col, _, _ = node:range() + return { line = start_row, character = start_col } end --- Makes an LSP textDocument/definition request for a given position. @@ -33,17 +33,17 @@ end --- @param position LspPosition The position in the document to get definitions for --- @param cb fun(res: LspDefinitionResult[] | nil): nil Callback receiving the definition results local function get_lsp_definitions(buffer, position, cb) - local params = vim.lsp.util.make_position_params() - params.position = position + local params = vim.lsp.util.make_position_params() + params.position = position - vim.lsp.buf_request( - buffer, - "textDocument/definition", - params, - function(_, result, _, _) - cb(result) - end - ) + vim.lsp.buf_request( + buffer, + "textDocument/definition", + params, + function(_, result, _, _) + cb(result) + end + ) end --- Resolves a Lua require path to an absolute file path using Neovim's runtime. @@ -51,22 +51,22 @@ end --- @param require_path string The Lua require path (e.g., "99.logger.logger") --- @return string|nil The absolute file path, or nil if it can't be resolved local function resolve_require_path(require_path) - local relative_path = "lua/" .. require_path:gsub("%.", "/") .. ".lua" - local results = vim.api.nvim_get_runtime_file(relative_path, false) + local relative_path = "lua/" .. require_path:gsub("%.", "/") .. ".lua" + local results = vim.api.nvim_get_runtime_file(relative_path, false) - if results and #results > 0 then - return results[1] - end + if results and #results > 0 then + return results[1] + end - -- Also try init.lua for module directories - local init_path = "lua/" .. require_path:gsub("%.", "/") .. "/init.lua" - results = vim.api.nvim_get_runtime_file(init_path, false) + -- Also try init.lua for module directories + local init_path = "lua/" .. require_path:gsub("%.", "/") .. "/init.lua" + results = vim.api.nvim_get_runtime_file(init_path, false) - if results and #results > 0 then - return results[1] - end + if results and #results > 0 then + return results[1] + end - return nil + return nil end --- Ensures a buffer is loaded and has LSP attached, then calls the callback. @@ -74,23 +74,23 @@ end --- @param filepath string The file path to load --- @param cb fun(bufnr: number|nil, err: string|nil): nil Callback with buffer number or error local function ensure_buffer_with_lsp(filepath, cb) - local bufnr = vim.fn.bufnr(filepath) - if bufnr == -1 then - bufnr = vim.fn.bufadd(filepath) - end + local bufnr = vim.fn.bufnr(filepath) + if bufnr == -1 then + bufnr = vim.fn.bufadd(filepath) + end - if not vim.api.nvim_buf_is_loaded(bufnr) then - vim.fn.bufload(bufnr) - end + if not vim.api.nvim_buf_is_loaded(bufnr) then + vim.fn.bufload(bufnr) + end - vim.schedule(function() - local clients = vim.lsp.get_clients({ bufnr = bufnr }) - if #clients == 0 then - cb(nil, "No LSP client attached to buffer for: " .. filepath) - return - end - cb(bufnr, nil) - end) + vim.schedule(function() + local clients = vim.lsp.get_clients({ bufnr = bufnr }) + if #clients == 0 then + cb(nil, "No LSP client attached to buffer for: " .. filepath) + return + end + cb(bufnr, nil) + end) end --- Makes an LSP textDocument/hover request for a given position. @@ -99,23 +99,23 @@ end --- @param position LspPosition The position to hover at --- @param cb fun(result: table|nil, err: string|nil): nil Callback with hover result local function get_lsp_hover(bufnr, position, cb) - local params = { - textDocument = { uri = vim.uri_from_bufnr(bufnr) }, - position = position, - } + local params = { + textDocument = { uri = vim.uri_from_bufnr(bufnr) }, + position = position, + } - vim.lsp.buf_request( - bufnr, - "textDocument/hover", - params, - function(err, result, _, _) - if err then - cb(nil, vim.inspect(err)) - return - end - cb(result, nil) - end - ) + vim.lsp.buf_request( + bufnr, + "textDocument/hover", + params, + function(err, result, _, _) + if err then + cb(nil, vim.inspect(err)) + return + end + cb(result, nil) + end + ) end --- Finds the return statement in a Lua file and extracts the exported keys. @@ -123,51 +123,51 @@ end --- @param bufnr number The buffer number --- @return { name: string, line: number, col: number }[] List of exported names with positions local function find_export_keys(bufnr) - local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) - local exports = {} + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + local exports = {} - -- Find the last return statement - local return_line_idx = nil - for i = #lines, 1, -1 do - if lines[i]:match("^%s*return%s+") then - return_line_idx = i - break - end + -- Find the last return statement + local return_line_idx = nil + for i = #lines, 1, -1 do + if lines[i]:match("^%s*return%s+") then + return_line_idx = i + break end + end - if not return_line_idx then - return exports - end + if not return_line_idx then + return exports + end - -- Check if it's a simple `return M` style - local simple_return = - lines[return_line_idx]:match("^%s*return%s+([%w_]+)%s*$") - if simple_return then - local col = lines[return_line_idx]:find(simple_return) + -- Check if it's a simple `return M` style + local simple_return = + lines[return_line_idx]:match("^%s*return%s+([%w_]+)%s*$") + if simple_return then + local col = lines[return_line_idx]:find(simple_return) + table.insert(exports, { + name = simple_return, + line = return_line_idx - 1, + col = col - 1, + }) + return exports + end + + -- Parse `return { Key = Value, ... }` style + for i = return_line_idx, #lines do + local line = lines[i] + for key, col_start in line:gmatch("()([%w_]+)%s*=") do + key, col_start = col_start, key + if key ~= "" and not key:match("^%d") then table.insert(exports, { - name = simple_return, - line = return_line_idx - 1, - col = col - 1, + name = key, + line = i - 1, + col = col_start - 1, }) - return exports - end - - -- Parse `return { Key = Value, ... }` style - for i = return_line_idx, #lines do - local line = lines[i] - for key, col_start in line:gmatch("()([%w_]+)%s*=") do - key, col_start = col_start, key - if key ~= "" and not key:match("^%d") then - table.insert(exports, { - name = key, - line = i - 1, - col = col_start - 1, - }) - end - end + end end + end - return exports + return exports end --- Gets the hover information for each exported symbol using LSP. @@ -176,59 +176,55 @@ end --- @param export_keys { name: string, line: number, col: number }[] The export positions --- @param cb fun(results: table<string, string>): nil Callback with name -> hover info map local function get_exports_hover_info(bufnr, export_keys, cb) - if #export_keys == 0 then - cb({}) - return - end + if #export_keys == 0 then + cb({}) + return + end - local results = {} - local pending = #export_keys + local results = {} + local pending = #export_keys - for _, export in ipairs(export_keys) do - local line_text = vim.api.nvim_buf_get_lines( - bufnr, - export.line, - export.line + 1, - false - )[1] + for _, export in ipairs(export_keys) do + local line_text = + vim.api.nvim_buf_get_lines(bufnr, export.line, export.line + 1, false)[1] - local pattern = export.name .. "%s*=%s*()" - local value_start = line_text:match(pattern) - local hover_col = value_start and (value_start - 1) or export.col - local position = { line = export.line, character = hover_col } + local pattern = export.name .. "%s*=%s*()" + local value_start = line_text:match(pattern) + local hover_col = value_start and (value_start - 1) or export.col + local position = { line = export.line, character = hover_col } - get_lsp_hover(bufnr, position, function(result, _) - if result and result.contents then - local content = result.contents - if type(content) == "table" then - if content.value then - results[export.name] = content.value - elseif content.kind == "markdown" then - results[export.name] = content.value - else - local parts = {} - for _, part in ipairs(content) do - if type(part) == "string" then - table.insert(parts, part) - elseif part.value then - table.insert(parts, part.value) - end - end - results[export.name] = table.concat(parts, "\n") - end - else - results[export.name] = tostring(content) - end - else - results[export.name] = "unknown" + get_lsp_hover(bufnr, position, function(result, _) + if result and result.contents then + local content = result.contents + if type(content) == "table" then + if content.value then + results[export.name] = content.value + elseif content.kind == "markdown" then + results[export.name] = content.value + else + local parts = {} + for _, part in ipairs(content) do + if type(part) == "string" then + table.insert(parts, part) + elseif part.value then + table.insert(parts, part.value) + end end + results[export.name] = table.concat(parts, "\n") + end + else + results[export.name] = tostring(content) + end + else + results[export.name] = "unknown" + end - pending = pending - 1 - if pending == 0 then - cb(results) - end - end) - end + pending = pending - 1 + if pending == 0 then + cb(results) + end + end) + end end --- Finds all method/field definitions for a class in the source file. @@ -237,32 +233,32 @@ end --- @param class_name string The name of the class (e.g., "Lsp") --- @return { name: string, line: number, col: number }[] List of member positions local function find_class_member_positions(file_lines, class_name) - local members = {} + local members = {} - for i, line in ipairs(file_lines) do - local method_name = - line:match("^%s*function%s+" .. class_name .. "[%.:]([%w_]+)%s*%(") - if method_name then - local col = line:find(method_name, 1, true) - table.insert(members, { - name = method_name, - line = i - 1, - col = col and (col - 1) or 0, - }) - end + for i, line in ipairs(file_lines) do + local method_name = + line:match("^%s*function%s+" .. class_name .. "[%.:]([%w_]+)%s*%(") + if method_name then + local col = line:find(method_name, 1, true) + table.insert(members, { + name = method_name, + line = i - 1, + col = col and (col - 1) or 0, + }) + end - local field_name = line:match("^%s*" .. class_name .. "%.([%w_]+)%s*=") - if field_name and not line:match("^%s*function") then - local col = line:find(field_name, 1, true) - table.insert(members, { - name = field_name, - line = i - 1, - col = col and (col - 1) or 0, - }) - end + local field_name = line:match("^%s*" .. class_name .. "%.([%w_]+)%s*=") + if field_name and not line:match("^%s*function") then + local col = line:find(field_name, 1, true) + table.insert(members, { + name = field_name, + line = i - 1, + col = col and (col - 1) or 0, + }) end + end - return members + return members end --- Gets hover information for each class member using LSP. @@ -271,52 +267,52 @@ end --- @param member_positions { name: string, line: number, col: number }[] Member positions --- @param cb fun(results: table<string, string>): nil Callback with name -> type info map local function get_class_members_hover(bufnr, member_positions, cb) - if #member_positions == 0 then - cb({}) - return - end + if #member_positions == 0 then + cb({}) + return + end - local results = {} - local pending = #member_positions + local results = {} + local pending = #member_positions - for _, member in ipairs(member_positions) do - local position = { line = member.line, character = member.col } + for _, member in ipairs(member_positions) do + local position = { line = member.line, character = member.col } - get_lsp_hover(bufnr, position, function(result, _) - local hover_text = "unknown" + get_lsp_hover(bufnr, position, function(result, _) + local hover_text = "unknown" - if result and result.contents then - local content = result.contents + if result and result.contents then + local content = result.contents - if type(content) == "table" then - if content.value then - hover_text = content.value - elseif content.kind then - hover_text = content.value or "" - else - local parts = {} - for _, part in ipairs(content) do - if type(part) == "string" then - table.insert(parts, part) - elseif part.value then - table.insert(parts, part.value) - end - end - hover_text = table.concat(parts, "\n") - end - else - hover_text = tostring(content) - end + if type(content) == "table" then + if content.value then + hover_text = content.value + elseif content.kind then + hover_text = content.value or "" + else + local parts = {} + for _, part in ipairs(content) do + if type(part) == "string" then + table.insert(parts, part) + elseif part.value then + table.insert(parts, part.value) + end end + hover_text = table.concat(parts, "\n") + end + else + hover_text = tostring(content) + end + end - results[member.name] = hover_text + results[member.name] = hover_text - pending = pending - 1 - if pending == 0 then - cb(results) - end - end) - end + pending = pending - 1 + if pending == 0 then + cb(results) + end + end) + end end --- Removes markdown fencing and cleans hover output. @@ -324,24 +320,24 @@ end --- @param hover_text string The raw hover text from LSP --- @return string The cleaned type information local function format_hover_output(hover_text) - if not hover_text or hover_text == "unknown" then - return "unknown" - end + if not hover_text or hover_text == "unknown" then + return "unknown" + end - local lines = {} + local lines = {} - for line in hover_text:gmatch("[^\n]+") do - if not line:match("^```") then - local cleaned = line - cleaned = cleaned:gsub("^local%s+", "") - cleaned = cleaned:gsub("^[%w_]+:%s*", "") - if cleaned ~= "" then - table.insert(lines, cleaned) - end - end + for line in hover_text:gmatch("[^\n]+") do + if not line:match("^```") then + local cleaned = line + cleaned = cleaned:gsub("^local%s+", "") + cleaned = cleaned:gsub("^[%w_]+:%s*", "") + if cleaned ~= "" then + table.insert(lines, cleaned) + end end + end - return table.concat(lines, "\n") + return table.concat(lines, "\n") end --- Formats a function hover result into TypeScript-style signature. @@ -349,21 +345,21 @@ end --- @param hover_text string The hover text from LSP --- @return string The formatted signature like "(a: number, b: string): boolean" local function format_function_signature(hover_text) - local clean = hover_text:gsub("```%w*\n?", ""):gsub("```", "") - clean = clean:gsub("^%s*", ""):gsub("%s*$", "") + local clean = hover_text:gsub("```%w*\n?", ""):gsub("```", "") + clean = clean:gsub("^%s*", ""):gsub("%s*$", "") - local params, ret = - clean:match("function%s*[%w_%.%:]*%((.-)%)%s*:%s*([^\n]+)") + local params, ret = + clean:match("function%s*[%w_%.%:]*%((.-)%)%s*:%s*([^\n]+)") + if params then + return string.format("(%s): %s", params, ret or "nil") + else + params = clean:match("function%s*[%w_%.%:]*%((.-)%)") if params then - return string.format("(%s): %s", params, ret or "nil") - else - params = clean:match("function%s*[%w_%.%:]*%((.-)%)") - if params then - return string.format("(%s): nil", params) - end + return string.format("(%s): nil", params) end + end - return clean + return clean end --- Extracts all enum values from source (not truncated like hover). @@ -372,34 +368,34 @@ end --- @param symbol_name string The name of the enum symbol --- @return string[] Array of enum entries like "Key = value" local function expand_enum_values(file_lines, symbol_name) - local values = {} - - for i, line in ipairs(file_lines) do - if - line:match("local%s+" .. symbol_name .. "%s*=") - or line:match(symbol_name .. "%s*=%s*{") - then - local j = i - while j <= #file_lines do - local enum_line = file_lines[j] + local values = {} - if enum_line:match("^%s*}") then - break - end + for i, line in ipairs(file_lines) do + if + line:match("local%s+" .. symbol_name .. "%s*=") + or line:match(symbol_name .. "%s*=%s*{") + then + local j = i + while j <= #file_lines do + local enum_line = file_lines[j] - local key, value = enum_line:match("^%s*([%w_]+)%s*=%s*([^,]+)") - if key and value then - value = value:match("^%s*(.-)%s*,?%s*$") - table.insert(values, key .. " = " .. value) - end + if enum_line:match("^%s*}") then + break + end - j = j + 1 - end - break + local key, value = enum_line:match("^%s*([%w_]+)%s*=%s*([^,]+)") + if key and value then + value = value:match("^%s*(.-)%s*,?%s*$") + table.insert(values, key .. " = " .. value) end + + j = j + 1 + end + break end + end - return values + return values end -------------------------------------------------------------------------------- @@ -416,9 +412,9 @@ Lsp.__index = Lsp --- @param config _99.Options The configuration options --- @return Lsp A new Lsp instance function Lsp.new(config) - return setmetatable({ - config = config, - }, Lsp) + return setmetatable({ + config = config, + }, Lsp) end -------------------------------------------------------------------------------- @@ -431,97 +427,97 @@ end --- @param require_path string The Lua require path (e.g., "99", "99.logger.logger") --- @param cb fun(result: string, err: string|nil): nil Callback with formatted string or error function Lsp.stringify_module_exports(require_path, cb) - local resolved_path = resolve_require_path(require_path) + local resolved_path = resolve_require_path(require_path) - if not resolved_path then - cb( - "", - "Could not resolve module path: " - .. require_path - .. ". The module may not be in runtimepath." - ) - return - end + if not resolved_path then + cb( + "", + "Could not resolve module path: " + .. require_path + .. ". The module may not be in runtimepath." + ) + return + end - local uri = vim.uri_from_fname(resolved_path) + local uri = vim.uri_from_fname(resolved_path) - ensure_buffer_with_lsp(resolved_path, function(bufnr, err) - if err then - cb("", err) - return - end + ensure_buffer_with_lsp(resolved_path, function(bufnr, err) + if err then + cb("", err) + return + end - local export_keys = find_export_keys(bufnr) + local export_keys = find_export_keys(bufnr) - if #export_keys == 0 then - cb("", "No exports found in return statement") - return - end + if #export_keys == 0 then + cb("", "No exports found in return statement") + return + end - get_exports_hover_info(bufnr, export_keys, function(hover_results) - local file_lines = vim.fn.readfile(resolved_path) + get_exports_hover_info(bufnr, export_keys, function(hover_results) + local file_lines = vim.fn.readfile(resolved_path) - -- Collect classes that need member expansion - local classes_to_expand = {} - for _, export in ipairs(export_keys) do - local hover = hover_results[export.name] or "unknown" - local is_class = hover:match("__index") ~= nil - or hover:match(":%s*[%w_]+%s*{") ~= nil + -- Collect classes that need member expansion + local classes_to_expand = {} + for _, export in ipairs(export_keys) do + local hover = hover_results[export.name] or "unknown" + local is_class = hover:match("__index") ~= nil + or hover:match(":%s*[%w_]+%s*{") ~= nil - if is_class then - local member_positions = - find_class_member_positions(file_lines, export.name) - if #member_positions > 0 then - table.insert(classes_to_expand, { - name = export.name, - positions = member_positions, - }) - end - end - end + if is_class then + local member_positions = + find_class_member_positions(file_lines, export.name) + if #member_positions > 0 then + table.insert(classes_to_expand, { + name = export.name, + positions = member_positions, + }) + end + end + end - -- If no classes, format immediately - if #classes_to_expand == 0 then - local result = Lsp._format_exports( - require_path, - uri, - export_keys, - hover_results, - file_lines, - {} - ) - cb(result, nil) - return - end + -- If no classes, format immediately + if #classes_to_expand == 0 then + local result = Lsp._format_exports( + require_path, + uri, + export_keys, + hover_results, + file_lines, + {} + ) + cb(result, nil) + return + end - -- Get hover for class members - local pending = #classes_to_expand - local all_member_hovers = {} + -- Get hover for class members + local pending = #classes_to_expand + local all_member_hovers = {} - for _, class_info in ipairs(classes_to_expand) do - get_class_members_hover( - bufnr, - class_info.positions, - function(member_hovers) - all_member_hovers[class_info.name] = member_hovers - pending = pending - 1 + for _, class_info in ipairs(classes_to_expand) do + get_class_members_hover( + bufnr, + class_info.positions, + function(member_hovers) + all_member_hovers[class_info.name] = member_hovers + pending = pending - 1 - if pending == 0 then - local result = Lsp._format_exports( - require_path, - uri, - export_keys, - hover_results, - file_lines, - all_member_hovers - ) - cb(result, nil) - end - end - ) + if pending == 0 then + local result = Lsp._format_exports( + require_path, + uri, + export_keys, + hover_results, + file_lines, + all_member_hovers + ) + cb(result, nil) end - end) + end + ) + end end) + end) end --- Internal function to format exports into a string. @@ -534,88 +530,84 @@ end --- @param class_member_hovers table<string, table<string, string>> Class name -> member hovers --- @return string The formatted export string function Lsp._format_exports( - module_path, - uri, - export_keys, - hover_results, - file_lines, - class_member_hovers + module_path, + uri, + export_keys, + hover_results, + file_lines, + class_member_hovers ) - local out = {} + local out = {} - table.insert(out, "Module: " .. module_path) - table.insert(out, "URI: " .. uri) - table.insert(out, string.rep("-", 60)) + table.insert(out, "Module: " .. module_path) + table.insert(out, "URI: " .. uri) + table.insert(out, string.rep("-", 60)) - for _, export in ipairs(export_keys) do - table.insert(out, "") - - local hover = hover_results[export.name] or "unknown" + for _, export in ipairs(export_keys) do + table.insert(out, "") - local is_enum = hover:match("enum%s+") ~= nil - local is_class = hover:match("__index") ~= nil - or hover:match(":%s*[%w_]+%s*{") ~= nil + local hover = hover_results[export.name] or "unknown" - if is_enum then - local values = expand_enum_values(file_lines, export.name) - if #values > 0 then - table.insert(out, export.name .. " = {") - for _, v in ipairs(values) do - table.insert(out, " " .. v) - end - table.insert(out, "}") - else - table.insert( - out, - export.name .. ": " .. format_hover_output(hover) - ) - end - elseif is_class then - local member_hovers = class_member_hovers[export.name] or {} - table.insert(out, export.name .. " {") + local is_enum = hover:match("enum%s+") ~= nil + local is_class = hover:match("__index") ~= nil + or hover:match(":%s*[%w_]+%s*{") ~= nil - -- Extract fields from class hover - local class_fields = {} - for line in hover:gmatch("[^\n]+") do - local field_name, field_type = - line:match("^%s*([%w_]+):%s*([^,}]+)") - if field_name and field_type then - field_type = field_type:match("^%s*(.-)%s*,?$") - if field_type ~= "function" then - class_fields[field_name] = field_type - end - end - end + if is_enum then + local values = expand_enum_values(file_lines, export.name) + if #values > 0 then + table.insert(out, export.name .. " = {") + for _, v in ipairs(values) do + table.insert(out, " " .. v) + end + table.insert(out, "}") + else + table.insert(out, export.name .. ": " .. format_hover_output(hover)) + end + elseif is_class then + local member_hovers = class_member_hovers[export.name] or {} + table.insert(out, export.name .. " {") - -- Print fields - for field_name, field_type in pairs(class_fields) do - if field_name ~= "__index" then - table.insert(out, " " .. field_name .. ": " .. field_type) - end - end + -- Extract fields from class hover + local class_fields = {} + for line in hover:gmatch("[^\n]+") do + local field_name, field_type = line:match("^%s*([%w_]+):%s*([^,}]+)") + if field_name and field_type then + field_type = field_type:match("^%s*(.-)%s*,?$") + if field_type ~= "function" then + class_fields[field_name] = field_type + end + end + end - -- Print methods with full signatures - for method_name, method_hover in pairs(member_hovers) do - if method_name ~= "__index" then - local sig = format_function_signature(method_hover) - table.insert(out, " " .. method_name .. sig) - end - end + -- Print fields + for field_name, field_type in pairs(class_fields) do + if field_name ~= "__index" then + table.insert(out, " " .. field_name .. ": " .. field_type) + end + end - table.insert(out, "}") - else - local formatted = format_hover_output(hover) - table.insert(out, export.name .. ": " .. formatted) + -- Print methods with full signatures + for method_name, method_hover in pairs(member_hovers) do + if method_name ~= "__index" then + local sig = format_function_signature(method_hover) + table.insert(out, " " .. method_name .. sig) end + end + + table.insert(out, "}") + else + local formatted = format_hover_output(hover) + table.insert(out, export.name .. ": " .. formatted) end + end - return table.concat(out, "\n") + return table.concat(out, "\n") end Lsp.stringify_module_exports("99.editor.lsp", function(res) - print(res) + print(res) end) return { - Lsp = Lsp, + Lsp = Lsp, } diff --git a/lua/99/editor/treesitter.lua b/lua/99/editor/treesitter.lua index b758266..d219ffb 100644 --- a/lua/99/editor/treesitter.lua +++ b/lua/99/editor/treesitter.lua @@ -22,69 +22,69 @@ local fn_call_query = "99-fn-call" --- @param buffer number ---@param lang string local function tree_root(buffer, lang) - -- Load the parser and the query. - local ok, parser = pcall(vim.treesitter.get_parser, buffer, lang) - if not ok then - return nil - end + -- Load the parser and the query. + local ok, parser = pcall(vim.treesitter.get_parser, buffer, lang) + if not ok then + return nil + end - local tree = parser:parse()[1] - return tree:root() + local tree = parser:parse()[1] + return tree:root() end --- @param context _99.RequestContext --- @param cursor _99.Point --- @return _99.treesitter.TSNode | nil function M.fn_call(context, cursor) - local buffer = context.buffer - local lang = context.file_type - local logger = context.logger:set_area("treesitter") - local root = tree_root(buffer, lang) - if not root then - Logger:error( - "unable to find treeroot, this should never happen", - "buffer", - buffer, - "lang", - lang - ) - return nil - end + local buffer = context.buffer + local lang = context.file_type + local logger = context.logger:set_area("treesitter") + local root = tree_root(buffer, lang) + if not root then + Logger:error( + "unable to find treeroot, this should never happen", + "buffer", + buffer, + "lang", + lang + ) + return nil + end - local ok, query = pcall(vim.treesitter.query.get, lang, fn_call_query) - if not ok or query == nil then - logger:error( - "unable to get the fn_call_query", - "lang", - lang, - "buffer", - buffer, - "ok", - type(ok), - "query", - type(query) - ) - return nil - end + local ok, query = pcall(vim.treesitter.query.get, lang, fn_call_query) + if not ok or query == nil then + logger:error( + "unable to get the fn_call_query", + "lang", + lang, + "buffer", + buffer, + "ok", + type(ok), + "query", + type(query) + ) + return nil + end - --- likely something that needs to be done with treesitter#get_node - local found = nil - for _, match, _ in query:iter_matches(root, buffer, 0, -1, { all = true }) do - for _, nodes in pairs(match) do - for _, node in ipairs(nodes) do - local range = Range:from_ts_node(node, buffer) - if range:contains(cursor) then - found = node - goto end_of_loops - end - end + --- likely something that needs to be done with treesitter#get_node + local found = nil + for _, match, _ in query:iter_matches(root, buffer, 0, -1, { all = true }) do + for _, nodes in pairs(match) do + for _, node in ipairs(nodes) do + local range = Range:from_ts_node(node, buffer) + if range:contains(cursor) then + found = node + goto end_of_loops end + end end - ::end_of_loops:: + end + ::end_of_loops:: - logger:debug("treesitter#fn_call", "found", found ~= nil) + logger:debug("treesitter#fn_call", "found", found ~= nil) - return found + return found end --- @class _99.treesitter.Function @@ -99,7 +99,7 @@ Function.__index = Function --- to replace at the exact function begin / end --- @param replace_with string[] function Function:replace_text(replace_with) - self.function_range:replace_text(replace_with) + self.function_range:replace_text(replace_with) end --- @param ts_node _99.treesitter.TSNode @@ -107,149 +107,149 @@ end ---@param context _99.RequestContext ---@return _99.treesitter.Function function Function.from_ts_node(ts_node, cursor, context) - local ok, query = - pcall(vim.treesitter.query.get, context.file_type, function_query) - local logger = context.logger:set_area("Function") - if not ok or query == nil then - logger:fatal("not query or not ok") - error("failed") - end + local ok, query = + pcall(vim.treesitter.query.get, context.file_type, function_query) + local logger = context.logger:set_area("Function") + if not ok or query == nil then + logger:fatal("not query or not ok") + error("failed") + end - local func = {} - for id, node, _ in - query:iter_captures(ts_node, context.buffer, 0, -1, { all = true }) - do - local range = Range:from_ts_node(node, context.buffer) - local name = query.captures[id] - if range:contains(cursor) then - if name == "context.function" then - func.function_node = node - func.function_range = range - elseif name == "context.body" then - func.body_node = node - func.body_range = range - end - end + local func = {} + for id, node, _ in + query:iter_captures(ts_node, context.buffer, 0, -1, { all = true }) + do + local range = Range:from_ts_node(node, context.buffer) + local name = query.captures[id] + if range:contains(cursor) then + if name == "context.function" then + func.function_node = node + func.function_range = range + elseif name == "context.body" then + func.body_node = node + func.body_range = range + end end + end - --- NOTE: not all functions have bodies... (lua: local function foo() end) - logger:assert(func.function_node ~= nil, "function_node not found") - logger:assert(func.function_range ~= nil, "function_range not found") + --- NOTE: not all functions have bodies... (lua: local function foo() end) + logger:assert(func.function_node ~= nil, "function_node not found") + logger:assert(func.function_range ~= nil, "function_range not found") - return setmetatable(func, Function) + return setmetatable(func, Function) end --- @param context _99.RequestContext --- @param cursor _99.Point --- @return _99.treesitter.Function? function M.containing_function(context, cursor) - local buffer = context.buffer - local lang = context.file_type - local logger = context and context.logger:set_area("treesitter") or Logger - - logger:error("loading lang", "buffer", buffer, "lang", lang) - local root = tree_root(buffer, lang) - if not root then - logger:debug("LSP: could not find tree root") - return nil - end + local buffer = context.buffer + local lang = context.file_type + local logger = context and context.logger:set_area("treesitter") or Logger - local ok, query = pcall(vim.treesitter.query.get, lang, function_query) - if not ok or query == nil then - logger:debug( - "LSP: not ok or query", - "query", - vim.inspect(query), - "lang", - lang, - "ok", - vim.inspect(ok) - ) - return nil - end - - --- @type _99.Range - local found_range = nil - --- @type _99.treesitter.TSNode - local found_node = nil - for id, node, _ in query:iter_captures(root, buffer, 0, -1, { all = true }) do - local range = Range:from_ts_node(node, buffer) - local name = query.captures[id] - if name == "context.function" and range:contains(cursor) then - if not found_range then - found_range = range - found_node = node - elseif found_range:area() > range:area() then - found_range = range - found_node = node - end - end - end + logger:error("loading lang", "buffer", buffer, "lang", lang) + local root = tree_root(buffer, lang) + if not root then + logger:debug("LSP: could not find tree root") + return nil + end + local ok, query = pcall(vim.treesitter.query.get, lang, function_query) + if not ok or query == nil then logger:debug( - "treesitter#containing_function", - "found_range", - found_range and found_range:to_string() or "found_range is nil" + "LSP: not ok or query", + "query", + vim.inspect(query), + "lang", + lang, + "ok", + vim.inspect(ok) ) + return nil + end - if not found_range then - return nil + --- @type _99.Range + local found_range = nil + --- @type _99.treesitter.TSNode + local found_node = nil + for id, node, _ in query:iter_captures(root, buffer, 0, -1, { all = true }) do + local range = Range:from_ts_node(node, buffer) + local name = query.captures[id] + if name == "context.function" and range:contains(cursor) then + if not found_range then + found_range = range + found_node = node + elseif found_range:area() > range:area() then + found_range = range + found_node = node + end end - logger:assert( - found_node, - "INVARIANT: found_range is not nil but found node is" - ) + end - ok, query = pcall(vim.treesitter.query.get, lang, function_query) - if not ok or query == nil then - logger:fatal("INVARIANT: found_range ", "range", found_range:to_text()) - return - end + logger:debug( + "treesitter#containing_function", + "found_range", + found_range and found_range:to_string() or "found_range is nil" + ) + + if not found_range then + return nil + end + logger:assert( + found_node, + "INVARIANT: found_range is not nil but found node is" + ) + + ok, query = pcall(vim.treesitter.query.get, lang, function_query) + if not ok or query == nil then + logger:fatal("INVARIANT: found_range ", "range", found_range:to_text()) + return + end - --- TODO: we need some language specific things here. - --- that is because comments above the function needs to considered - return Function.from_ts_node(found_node, cursor, context) + --- TODO: we need some language specific things here. + --- that is because comments above the function needs to considered + return Function.from_ts_node(found_node, cursor, context) end --- @param buffer number --- @return _99.treesitter.Node[] function M.imports(buffer) - Logger:assert(false, "not implemented yet", "id", 69420) - local lang = vim.bo[buffer].ft - local root = tree_root(buffer, lang) - if not root then - Logger:debug("imports: could not find tree root") - return {} - end + Logger:assert(false, "not implemented yet", "id", 69420) + local lang = vim.bo[buffer].ft + local root = tree_root(buffer, lang) + if not root then + Logger:debug("imports: could not find tree root") + return {} + end - local ok, query = pcall(vim.treesitter.query.get, lang, imports_query) + local ok, query = pcall(vim.treesitter.query.get, lang, imports_query) - if not ok or query == nil then - Logger:debug( - "imports: not ok or query", - "query", - vim.inspect(query), - "lang", - lang, - "ok", - vim.inspect(ok) - ) - return {} - end + if not ok or query == nil then + Logger:debug( + "imports: not ok or query", + "query", + vim.inspect(query), + "lang", + lang, + "ok", + vim.inspect(ok) + ) + return {} + end - local imports = {} - for _, match, _ in query:iter_matches(root, buffer, 0, -1, { all = true }) do - for id, nodes in pairs(match) do - local name = query.captures[id] - if name == "import.name" then - for _, node in ipairs(nodes) do - table.insert(imports, node) - end - end + local imports = {} + for _, match, _ in query:iter_matches(root, buffer, 0, -1, { all = true }) do + for id, nodes in pairs(match) do + local name = query.captures[id] + if name == "import.name" then + for _, node in ipairs(nodes) do + table.insert(imports, node) end + end end + end - return imports + return imports end return M diff --git a/lua/99/extensions/agents/helpers.lua b/lua/99/extensions/agents/helpers.lua new file mode 100644 index 0000000..963c9f8 --- /dev/null +++ b/lua/99/extensions/agents/helpers.lua @@ -0,0 +1,67 @@ +local M = {} + +--- @param path string +--- @return string +local function normalize_path(path) + if path:sub(1, 1) == "/" then + return path + end + local cwd = vim.fs.joinpath(vim.uv.cwd(), path) + return cwd +end + +--- @param dir string +--- @return _99.Agents.Rule[] +function M.ls(dir) + local current_dir = normalize_path(dir) + local glob = vim.fs.joinpath(current_dir, "/*.{mdc,md}") + local files = vim.fn.glob(glob, false, true) + local rules = {} + + for _, file in ipairs(files) do + local filename = vim.fn.fnamemodify(file, ":t:r") + table.insert(rules, { + name = filename, + path = file, + }) + end + + return rules +end + +--- @param file string +--- @param count? number +--- @return string +function M.head(file, count) + count = count or 5 + local fd = vim.uv.fs_open(file, "r", 438) + if not fd then + return "" + end + + local stat = vim.uv.fs_fstat(fd) + if not stat then + vim.uv.fs_close(fd) + return "" + end + + local data = vim.uv.fs_read(fd, stat.size, 0) + vim.uv.fs_close(fd) + + if not data then + return "" + end + + local lines = {} + for line in data:gmatch("([^\n]*)\n?") do + if count == 0 then + break + end + count = count - 1 + table.insert(lines, line) + end + + return table.concat(lines, "\n") +end + +return M diff --git a/lua/99/extensions/agents/init.lua b/lua/99/extensions/agents/init.lua new file mode 100644 index 0000000..1b4a4f2 --- /dev/null +++ b/lua/99/extensions/agents/init.lua @@ -0,0 +1,97 @@ +local helpers = require("99.extensions.agents.helpers") +local M = {} + +--- @class _99.Agents.Rule +--- @field name string +--- @field path string + +--- @class _99.Agents.Rules +--- @field cursor _99.Agents.Rule[] +--- @field custom _99.Agents.Rule[] + +--- @class _99.Agents.Agent +--- @field rules _99.Agents.Rules + +---@param _99 _99.State +---@return _99.Agents.Rules +function M.rules(_99) + local cursor = helpers.ls(_99.completion.cursor_rules) + local custom = {} + for _, path in ipairs(_99.completion.custom_rules or {}) do + local custom_rule = helpers.ls(path) + for _, c in ipairs(custom_rule) do + table.insert(custom, c) + end + end + return { + cursor = cursor, + custom = custom, + } +end + +--- @param rules _99.Agents.Rules +--- @return _99.Agents.Rule[] +function M.rules_to_items(rules) + local items = {} + for _, rule in ipairs(rules.cursor or {}) do + table.insert(items, rule) + end + for _, rule in ipairs(rules.custom or {}) do + table.insert(items, rule) + end + return items +end + +--- @param rules _99.Agents.Rules +---@param path string +---@return _99.Agents.Rule | nil +function M.get_rule_by_path(rules, path) + for _, rule in ipairs(rules.cursor or {}) do + if rule.path == path then + return rule + end + end + for _, rule in ipairs(rules.custom or {}) do + if rule.path == path then + return rule + end + end + return nil +end + +--- @param rules _99.Agents.Rules +---@param token string +---@return boolean +function M.is_rule(rules, token) + for _, rule in ipairs(rules.cursor or {}) do + if rule.path == token then + return true + end + end + for _, rule in ipairs(rules.custom or {}) do + if rule.path == token then + return true + end + end + return false +end + +--- @param rules _99.Agents.Rules +--- @param haystack string +--- @return _99.Agents.Rule[] +function M.find_rules(rules, haystack) + --- @type _99.Agents.Rule[] + local out = {} + + for word in haystack:gmatch("@%S+") do + local rule_string = word:sub(2) + local rule = M.get_rule_by_path(rules, rule_string) + if rule then + table.insert(out, rule) + end + end + + return out +end + +return M diff --git a/lua/99/extensions/cmp.lua b/lua/99/extensions/cmp.lua new file mode 100644 index 0000000..d43ee8b --- /dev/null +++ b/lua/99/extensions/cmp.lua @@ -0,0 +1,153 @@ +local Agents = require("99.extensions.agents") +local Helpers = require("99.extensions.agents.helpers") +local SOURCE = "99" + +--- @class _99.Extensions.CmpItem +--- @field rule _99.Agents.Rule +--- @field docs string + +--- @param _99 _99.State +--- @return _99.Extensions.CmpItem[] +local function rules(_99) + local agent_rules = Agents.rules_to_items(_99.rules) + local out = {} + for _, rule in ipairs(agent_rules) do + table.insert(out, { + rule = rule, + docs = Helpers.head(rule.path), + }) + end + return out +end + +--- @class CmpSource +--- @field _99 _99.State +--- @field items _99.Extensions.CmpItem[] +local CmpSource = {} +CmpSource.__index = CmpSource + +--- @param _99 _99.State +function CmpSource.new(_99) + return setmetatable({ + _99 = _99, + items = rules(_99), + }, CmpSource) +end + +function CmpSource.is_available() + return true +end + +function CmpSource.get_debug_name() + return SOURCE +end + +function CmpSource.get_keyword_pattern() + return [[@\k\+]] +end + +function CmpSource.get_trigger_characters() + return { "@" } +end + +--- @class CompletionItem +--- @field label string +--- @field kind number kind is optional but gives icons / categories +--- @field documentation string can be a string or markdown table +--- @field detail string detail shows a right-side hint + +--- @class Completion +--- @field items CompletionItem[] +--- @field isIncomplete boolean - +-- true: I might return more if user types more +-- false: this result set is complete +function CmpSource:complete(params, callback) + local before = params.context.cursor_before_line or "" + local items = {} --[[ @as CompletionItem[] ]] + + if #before > 1 and before:sub(#before - 1) ~= " @" then + callback({ + items = {}, + isIncomplete = false, + }) + return + end + + for _, item in ipairs(self.items) do + table.insert(items, { + label = item.rule.name, + insertText = item.rule.path, + filterText = item.rule.name, + kind = 17, -- file + documentation = { + kind = "markdown", + value = item.docs, + }, + detail = item.rule.path, + }) + end + + callback({ + items = items, + isIncomplete = false, + }) +end + +--- TODO: Look into what this could be +function CmpSource.resolve(completion_item, callback) + callback(completion_item) +end + +function CmpSource.execute(completion_item, callback) + callback(completion_item) +end + +--- @type CmpSource | nil +local source = nil + +--- @param _ _99.State +local function init_for_buffer(_) + local cmp = require("cmp") + cmp.setup.buffer({ + sources = { + { name = SOURCE }, + }, + window = { + completion = { + zindex = 1001, + }, + documentation = { + zindex = 1001, + }, + }, + }) +end + +--- @param _99 _99.State +local function init(_99) + assert( + source == nil, + "the source must be nil when calling init on an completer" + ) + + local cmp = require("cmp") + source = CmpSource.new(_99) + source.items = rules(_99) + cmp.register_source(SOURCE, source) +end + +--- @param _99 _99.State +local function refresh_state(_99) + if not source then + return + end + source.items = rules(_99) +end + +--- @type _99.Extensions.Source +local source_wrapper = { + init_for_buffer = init_for_buffer, + init = init, + refresh_state = refresh_state, +} +return source_wrapper diff --git a/lua/99/extensions/init.lua b/lua/99/extensions/init.lua new file mode 100644 index 0000000..a6e2bd6 --- /dev/null +++ b/lua/99/extensions/init.lua @@ -0,0 +1,47 @@ +local cmp = require("99.extensions.cmp") + +--- @class _99.Extensions.Source +--- @field init_for_buffer fun(_99: _99.State): nil +--- @field init fun(_99: _99.State): nil +--- @field refresh_state fun(_99: _99.State): nil + +--- @param completion _99.Completion | nil +--- @return _99.Extensions.Source | nil +local function get_source(completion) + if not completion or not completion.source then + return + end + local source = completion.source + if source == "cmp" then + return cmp + end +end + +return { + --- @param _99 _99.State + init = function(_99) + local source = get_source(_99.completion) + if not source then + return + end + source.init(_99) + end, + + --- @param _99 _99.State + setup_buffer = function(_99) + local source = get_source(_99.completion) + if not source then + return + end + source.init_for_buffer(_99) + end, + + --- @param _99 _99.State + refresh = function(_99) + local source = get_source(_99.completion) + if not source then + return + end + source.refresh_state(_99) + end, +} diff --git a/lua/99/geo.lua b/lua/99/geo.lua index f025659..b4d89ad 100644 --- a/lua/99/geo.lua +++ b/lua/99/geo.lua @@ -4,10 +4,10 @@ local project_row = 100000000 --- @param col number | nil --- @return number local function project(point_or_row, col) - if type(point_or_row) == "number" then - return point_or_row * project_row + col - end - return point_or_row.row * project_row + point_or_row.col + if type(point_or_row) == "number" then + return point_or_row * project_row + col + end + return point_or_row.row * project_row + point_or_row.col end --- stores all values as 1 based @@ -18,27 +18,27 @@ local Point = {} Point.__index = Point function Point:to_string() - return string.format("point(%d,%d)", self.row, self.col) + return string.format("point(%d,%d)", self.row, self.col) end --- @param buffer number --- @return string function Point:get_text_line(buffer) - local r, _ = self:to_vim() - return vim.api.nvim_buf_get_lines(buffer, r, r + 1, true)[1] + local r, _ = self:to_vim() + return vim.api.nvim_buf_get_lines(buffer, r, r + 1, true)[1] end --- @param buffer number --- @param text string function Point:set_text_line(buffer, text) - local r, _ = self:to_vim() - vim.api.nvim_buf_set_lines(buffer, r, r + 1, false, { text }) + local r, _ = self:to_vim() + vim.api.nvim_buf_set_lines(buffer, r, r + 1, false, { text }) end function Point:update_to_end_of_line() - self.col = vim.fn.col("$") + 1 - local r, c = self:to_one_zero_index() - vim.api.nvim_win_set_cursor(0, { r, c }) + self.col = vim.fn.col("$") + 1 + local r, c = self:to_one_zero_index() + vim.api.nvim_win_set_cursor(0, { r, c }) end --- 1 based point @@ -46,138 +46,138 @@ end --- @param col number --- @return _99.Point function Point:new(row, col) - assert(type(row) == "number", "expected row to be a number") - assert(type(col) == "number", "expected col to be a number") - return setmetatable({ - row = row, - col = col, - }, self) + assert(type(row) == "number", "expected row to be a number") + assert(type(col) == "number", "expected col to be a number") + return setmetatable({ + row = row, + col = col, + }, self) end function Point:from_cursor() - local point = setmetatable({ - row = 0, - col = 0, - }, self) + local point = setmetatable({ + row = 0, + col = 0, + }, self) - --- NOTE: win_get_cursor 1, 0 based return - local cursor = vim.api.nvim_win_get_cursor(0) - local cursor_row, cursor_col = cursor[1], cursor[2] - point.row = cursor_row - point.col = cursor_col + 1 - return point + --- NOTE: win_get_cursor 1, 0 based return + local cursor = vim.api.nvim_win_get_cursor(0) + local cursor_row, cursor_col = cursor[1], cursor[2] + point.row = cursor_row + point.col = cursor_col + 1 + return point end --- Point from nvim_buf_get_extmark_by_id returns which is 0 based --- @param mark _99.Mark function Point.from_extmark(mark) - local buffer = mark.buffer - local ns_id = mark.nsid - local mark_id = mark.id - local row, col = vim.api.nvim_buf_get_extmark_by_id(buffer, ns_id, mark_id) - return setmetatable({ - row = row + 1, - col = col + 1, - }, Point) + local buffer = mark.buffer + local ns_id = mark.nsid + local mark_id = mark.id + local row, col = vim.api.nvim_buf_get_extmark_by_id(buffer, ns_id, mark_id) + return setmetatable({ + row = row + 1, + col = col + 1, + }, Point) end --- @param row number ---@param col number --- @return _99.Point function Point:from_ts_point(row, col) - return setmetatable({ - row = row + 1, - col = col + 1, - }, self) + return setmetatable({ + row = row + 1, + col = col + 1, + }, self) end --- stores all 2 points --- @param range _99.Range --- @return boolean function Point:in_ts_range(range) - return range:contains(self) + return range:contains(self) end --- vim.api.nvim_buf_get_text uses 0 based row and col --- @return number, number function Point:to_lua() - return self.row, self.col + return self.row, self.col end --- @return number, number function Point:to_lsp() - return self.row - 1, self.col - 1 + return self.row - 1, self.col - 1 end --- vim.api.nvim_buf_get_text uses 0 based row and col --- @return number, number function Point:to_vim() - return self.row - 1, self.col - 1 + return self.row - 1, self.col - 1 end function Point:to_one_zero_index() - return self.row, self.col - 1 + return self.row, self.col - 1 end --- treesitter uses 0 based row and col --- @return number, number function Point:to_ts() - return self.row - 1, self.col - 1 + return self.row - 1, self.col - 1 end --- @param point _99.Point --- @return boolean function Point:gt(point) - return project(self) > project(point) + return project(self) > project(point) end --- @param point _99.Point --- @return boolean function Point:lt(point) - return project(self) < project(point) + return project(self) < project(point) end --- @param point _99.Point --- @return boolean function Point:lte(point) - return project(self) <= project(point) + return project(self) <= project(point) end --- @param point _99.Point --- @return boolean function Point:gte(point) - return project(self) >= project(point) + return project(self) >= project(point) end --- @param point _99.Point --- @return boolean function Point:eq(point) - return project(self) == project(point) + return project(self) == project(point) end --- @param point _99.Point --- @return _99.Point function Point:add(point) - return Point:new(self.row + point.row, self.col + point.col) + return Point:new(self.row + point.row, self.col + point.col) end --- @param point _99.Point --- @return _99.Point function Point:sub(point) - return Point:new(self.row - point.row, self.col - point.col) + return Point:new(self.row - point.row, self.col - point.col) end --- @param mark _99.Mark --- @return _99.Point function Point.from_mark(mark) - --- buf extmark by id is a 0 based api - local pos = - vim.api.nvim_buf_get_extmark_by_id(mark.buffer, mark.nsid, mark.id, {}) + --- buf extmark by id is a 0 based api + local pos = + vim.api.nvim_buf_get_extmark_by_id(mark.buffer, mark.nsid, mark.id, {}) - return setmetatable({ - row = pos[1] + 1, - col = pos[2] + 1, - }, Point) + return setmetatable({ + row = pos[1] + 1, + col = pos[2] + 1, + }, Point) end --- @return _99.Point @@ -185,8 +185,8 @@ function Point.from_visual_start() end --- @return _99.Point function Point.from_visual_end() - --- make sure you dont allow visual line extend beyond the end of the - --- actual text line + --- make sure you dont allow visual line extend beyond the end of the + --- actual text line end --- @class _99.Range @@ -200,115 +200,114 @@ Range.__index = Range --- @param start _99.Point ---@param end_ _99.Point function Range:new(buffer, start, end_) - return setmetatable({ - start = start, - end_ = end_, - buffer = buffer, - }, self) + return setmetatable({ + start = start, + end_ = end_, + buffer = buffer, + }, self) end function Range.from_visual_selection() - local buffer = vim.api.nvim_get_current_buf() - local start_pos = vim.fn.getpos("'<") - local end_pos = vim.fn.getpos("'>") - local start = Point:new(start_pos[2], start_pos[3]) - local end_ = Point:new(end_pos[2], end_pos[3]) + local buffer = vim.api.nvim_get_current_buf() + local start_pos = vim.fn.getpos("'<") + local end_pos = vim.fn.getpos("'>") + local start = Point:new(start_pos[2], start_pos[3]) + local end_ = Point:new(end_pos[2], end_pos[3]) - --- visual line mode will select the end point for each row to be int max - --- which will cause marks to fail. so we have to correct it to the literal - --- row length - local end_r, _ = end_:to_vim() - local end_line = - vim.api.nvim_buf_get_lines(buffer, end_r, end_r + 1, false)[1] - local actual_end = - Point:new(end_pos[2], math.min(end_pos[3], #end_line + 1)) + --- visual line mode will select the end point for each row to be int max + --- which will cause marks to fail. so we have to correct it to the literal + --- row length + local end_r, _ = end_:to_vim() + local end_line = + vim.api.nvim_buf_get_lines(buffer, end_r, end_r + 1, false)[1] + local actual_end = Point:new(end_pos[2], math.min(end_pos[3], #end_line + 1)) - return Range:new(buffer, start, actual_end) + return Range:new(buffer, start, actual_end) end ---@param node _99.treesitter.TSNode ---@param buffer number ---@return _99.Range function Range:from_ts_node(node, buffer) - -- ts is zero based - local start_row, start_col, _ = node:start() - local end_row, end_col, _ = node:end_() - local range = { - start = Point:from_ts_point(start_row, start_col), - end_ = Point:from_ts_point(end_row, end_col), - buffer = buffer, - } + -- ts is zero based + local start_row, start_col, _ = node:start() + local end_row, end_col, _ = node:end_() + local range = { + start = Point:from_ts_point(start_row, start_col), + end_ = Point:from_ts_point(end_row, end_col), + buffer = buffer, + } - return setmetatable(range, self) + return setmetatable(range, self) end ---@param start _99.Mark ---@param end_ _99.Mark ---@return _99.Range function Range.from_marks(start, end_) - local start_point = Point.from_mark(start) - local end_point = Point.from_mark(end_) - return Range:new(start.buffer, start_point, end_point) + local start_point = Point.from_mark(start) + local end_point = Point.from_mark(end_) + return Range:new(start.buffer, start_point, end_point) end --- @param replace_with string[] function Range:replace_text(replace_with) - local s_row, s_col = self.start:to_vim() - local e_row, e_col = self.end_:to_vim() - vim.api.nvim_buf_set_text( - self.buffer, - s_row, - s_col, - e_row, - e_col, - replace_with - ) + local s_row, s_col = self.start:to_vim() + local e_row, e_col = self.end_:to_vim() + vim.api.nvim_buf_set_text( + self.buffer, + s_row, + s_col, + e_row, + e_col, + replace_with + ) end --- @param point _99.Point --- @return boolean function Range:contains(point) - local start = project(self.start) - local stop = project(self.end_) - local p = project(point) - return start <= p and p <= stop + local start = project(self.start) + local stop = project(self.end_) + local p = project(point) + return start <= p and p <= stop end --- @return string function Range:to_text() - local sr, sc = self.start:to_vim() - local er, ec = self.end_:to_vim() + local sr, sc = self.start:to_vim() + local er, ec = self.end_:to_vim() - --- blank line vis selection - if sr == er and sc == ec then - ec = ec + 1 - end + --- blank line vis selection + if sr == er and sc == ec then + ec = ec + 1 + end - local text = vim.api.nvim_buf_get_text(self.buffer, sr, sc, er, ec, {}) - return table.concat(text, "\n") + local text = vim.api.nvim_buf_get_text(self.buffer, sr, sc, er, ec, {}) + return table.concat(text, "\n") end --- @param range _99.Range --- @return boolean function Range:contains_range(range) - return self.start:lte(range.start) and self.end_:gte(range.end_) + return self.start:lte(range.start) and self.end_:gte(range.end_) end function Range:area() - local start = project(self.start) - local end_ = project(self.end_) - return end_ - start + local start = project(self.start) + local end_ = project(self.end_) + return end_ - start end function Range:to_string() - return string.format( - "range(%s,%s)", - self.start:to_string(), - self.end_:to_string() - ) + return string.format( + "range(%s,%s)", + self.start:to_string(), + self.end_:to_string() + ) end return { - Point = Point, - Range = Range, + Point = Point, + Range = Range, } diff --git a/lua/99/id.lua b/lua/99/id.lua index 08bb729..45a444f 100644 --- a/lua/99/id.lua +++ b/lua/99/id.lua @@ -1,5 +1,5 @@ local _id = 0 return function() - _id = _id + 1 - return _id + _id = _id + 1 + return _id end diff --git a/lua/99/init.lua b/lua/99/init.lua index fce2d36..1446acb 100644 --- a/lua/99/init.lua +++ b/lua/99/init.lua @@ -6,6 +6,30 @@ local Window = require("99.window") local get_id = require("99.id") local RequestContext = require("99.request-context") local Range = require("99.geo").Range +local Extensions = require("99.extensions") +local Agents = require("99.extensions.agents") + +---@param path_or_rule string | _99.Agents.Rule +---@return _99.Agents.Rule | string +local function expand(path_or_rule) + if type(path_or_rule) == "string" then + return vim.fn.expand(path_or_rule) + end + return { + name = path_or_rule.name, + path = vim.fn.expand(path_or_rule.path), + } +end + +--- @param opts _99.ops.Opts? +--- @return _99.ops.Opts +local function process_opts(opts) + opts = opts or {} + for i, rule in ipairs(opts.additional_rules or {}) do + opts.additional_rules[i] = expand(rule) + end + return opts +end --- @alias _99.Cleanup fun(): nil @@ -22,18 +46,23 @@ local Range = require("99.geo").Range --- @return _99.StateProps local function create_99_state() - return { - model = "opencode/claude-sonnet-4-5", - md_files = {}, - prompts = require("99.prompt-settings"), - ai_stdout_rows = 3, - languages = { "lua", "go", "java", "elixir", "cpp" }, - display_errors = false, - __active_requests = {}, - __view_log_idx = 1, - } + return { + model = "opencode/claude-sonnet-4-5", + md_files = {}, + prompts = require("99.prompt-settings"), + ai_stdout_rows = 3, + languages = { "lua", "go", "java", "elixir", "cpp" }, + display_errors = false, + __active_requests = {}, + __view_log_idx = 1, + } end +--- @class _99.Completion +--- @field source "cmp" | nil +--- @field custom_rules string[] +--- @field cursor_rules string | nil defaults to .cursor/rules + --- @class _99.Options --- @field logger _99.Logger.Options? --- @field model string? @@ -41,10 +70,12 @@ end --- @field provider _99.Provider? --- @field debug_log_prefix string? --- @field display_errors? boolean +--- @field completion _99.Completion? --- unanswered question -- will i need to queue messages one at a time or --- just send them all... So to prepare ill be sending around this state object --- @class _99.State +--- @field completion _99.Completion --- @field model string --- @field md_files string[] --- @field prompts _99.Prompts @@ -52,6 +83,7 @@ end --- @field languages string[] --- @field display_errors boolean --- @field provider_override _99.Provider? +--- @field rules _99.Agents.Rules --- @field __active_requests _99.Cleanup[] --- @field __view_log_idx number local _99_State = {} @@ -59,255 +91,305 @@ _99_State.__index = _99_State --- @return _99.State function _99_State.new() - local props = create_99_state() - ---@diagnostic disable-next-line: return-type-mismatch - return setmetatable(props, _99_State) + local props = create_99_state() + ---@diagnostic disable-next-line: return-type-mismatch + return setmetatable(props, _99_State) +end + +--- TODO: This is something to understand. I bet that this is going to need +--- a lot of performance tuning. I am just reading every file, and this could +--- take a decent amount of time if there are lots of rules. +--- +--- Simple perfs: +--- 1. read 4096 bytes at a tiem instead of whole file and parse out lines +--- 2. don't show the docs +--- 3. do the operation once at setup instead of every time. +--- likely not needed to do this all the time. +function _99_State:refresh_rules() + self.rules = Agents.rules(self) + Extensions.refresh(self) end local _active_request_id = 0 ---@param clean_up _99.Cleanup ---@return number function _99_State:add_active_request(clean_up) - _active_request_id = _active_request_id + 1 - Logger:debug("adding active request", "id", _active_request_id) - self.__active_requests[_active_request_id] = clean_up - return _active_request_id + _active_request_id = _active_request_id + 1 + Logger:debug("adding active request", "id", _active_request_id) + self.__active_requests[_active_request_id] = clean_up + return _active_request_id end function _99_State:active_request_count() - local count = 0 - for _ in pairs(self.__active_requests) do - count = count + 1 - end - return count + local count = 0 + for _ in pairs(self.__active_requests) do + count = count + 1 + end + return count end ---@param id number function _99_State:remove_active_request(id) - local logger = Logger:set_id(id) - local r = self.__active_requests[id] - logger:assert( - r, - "there is no active request for id. implementation broken" - ) - logger:debug("removing active request") - self.__active_requests[id] = nil + local logger = Logger:set_id(id) + local r = self.__active_requests[id] + logger:assert(r, "there is no active request for id. implementation broken") + logger:debug("removing active request") + self.__active_requests[id] = nil end local _99_state = _99_State.new() --- @class _99 local _99 = { - DEBUG = Level.DEBUG, - INFO = Level.INFO, - WARN = Level.WARN, - ERROR = Level.ERROR, - FATAL = Level.FATAL, + DEBUG = Level.DEBUG, + INFO = Level.INFO, + WARN = Level.WARN, + ERROR = Level.ERROR, + FATAL = Level.FATAL, } --- you can only set those marks after the visual selection is removed local function set_selection_marks() - vim.api.nvim_feedkeys( - vim.api.nvim_replace_termcodes("<Esc>", true, false, true), - "x", - false - ) + vim.api.nvim_feedkeys( + vim.api.nvim_replace_termcodes("<Esc>", true, false, true), + "x", + false + ) end --- @param operation_name string --- @return _99.RequestContext local function get_context(operation_name) - local trace_id = get_id() - local context = RequestContext.from_current_buffer(_99_state, trace_id) - context.logger:debug("99 Request", "method", operation_name) - return context + _99_state:refresh_rules() + local trace_id = get_id() + local context = RequestContext.from_current_buffer(_99_state, trace_id) + context.logger:debug("99 Request", "method", operation_name) + return context end function _99.info() - local info = {} - table.insert( - info, - string.format("Agent Files: %s", table.concat(_99_state.md_files, ", ")) - ) - table.insert(info, string.format("Model: %s", _99_state.model)) - table.insert( - info, - string.format("AI Stdout Rows: %d", _99_state.ai_stdout_rows) - ) - table.insert( - info, - string.format("Display Errors: %s", tostring(_99_state.display_errors)) - ) - table.insert( - info, - string.format("Active Requests: %d", _99_state:active_request_count()) - ) - Window.display_centered_message(info) + local info = {} + table.insert( + info, + string.format("Agent Files: %s", table.concat(_99_state.md_files, ", ")) + ) + table.insert(info, string.format("Model: %s", _99_state.model)) + table.insert( + info, + string.format("AI Stdout Rows: %d", _99_state.ai_stdout_rows) + ) + table.insert( + info, + string.format("Display Errors: %s", tostring(_99_state.display_errors)) + ) + table.insert( + info, + string.format("Active Requests: %d", _99_state:active_request_count()) + ) + Window.display_centered_message(info) end -function _99.fill_in_function_prompt() - local context = get_context("fill-in-function-with-prompt") - context.logger:debug("start") - Window.capture_input(function(success, response) - context.logger:debug( - "capture_prompt", - "success", - success, - "response", - response - ) - if success then - ops.fill_in_function(context, response) - end - end, {}) +--- @param path string +function _99:rule_from_path(path) + _ = self + path = expand(path) --[[ @as string]] + return Agents.get_rule_by_path(_99_state.rules, path) end -function _99.fill_in_function() - ops.fill_in_function(get_context("fill_in_function")) +--- @param opts? _99.ops.Opts +function _99.fill_in_function_prompt(opts) + opts = process_opts(opts) + local context = get_context("fill-in-function-with-prompt") + + context.logger:debug("start") + Window.capture_input({ + cb = function(success, response) + context.logger:debug( + "capture_prompt", + "success", + success, + "response", + response + ) + if success then + opts.additional_prompt = response + ops.fill_in_function(context, opts) + end + end, + on_load = function() + Extensions.setup_buffer(_99_state) + end, + }) end -function _99.visual_prompt() - local context = get_context("over-range-with-prompt") - context.logger:debug("start") - Window.capture_input(function(success, response) - context.logger:debug( - "capture_prompt", - "success", - success, - "response", - response - ) - if success then - _99.visual(response) - end - end, {}) +--- @param opts? _99.ops.Opts +function _99.fill_in_function(opts) + opts = process_opts(opts) + ops.fill_in_function(get_context("fill_in_function"), opts) +end + +--- @param opts _99.ops.Opts +function _99.visual_prompt(opts) + opts = process_opts(opts) + local context = get_context("over-range-with-prompt") + context.logger:debug("start") + Window.capture_input({ + cb = function(success, response) + context.logger:debug( + "capture_prompt", + "success", + success, + "response", + response + ) + if success then + opts.additional_prompt = response + _99.visual(context, opts) + end + end, + on_load = function() + Extensions.setup_buffer(_99_state) + end, + }) end ---- @param prompt string? --- @param context _99.RequestContext? -function _99.visual(prompt, context) - --- TODO: Talk to teej about this. - --- Visual selection marks are only set in place post visual selection. - --- that means for this function to work i must escape out of visual mode - --- which i dislike very much. because maybe you dont want this - set_selection_marks() +--- @param opts _99.ops.Opts? +function _99.visual(context, opts) + opts = process_opts(opts) + --- TODO: Talk to teej about this. + --- Visual selection marks are only set in place post visual selection. + --- that means for this function to work i must escape out of visual mode + --- which i dislike very much. because maybe you dont want this + set_selection_marks() - context = context or get_context("over-range") - local range = Range.from_visual_selection() - ops.over_range(context, range, prompt) + context = context or get_context("over-range") + local range = Range.from_visual_selection() + ops.over_range(context, range, opts) end --- View all the logs that are currently cached. Cached log count is determined --- by _99.Logger.Options that are passed in. function _99.view_logs() - _99_state.__view_log_idx = 1 - local logs = Logger.logs() - if #logs == 0 then - print("no logs to display") - return - end - Window.display_full_screen_message(logs[1]) + _99_state.__view_log_idx = 1 + local logs = Logger.logs() + if #logs == 0 then + print("no logs to display") + return + end + Window.display_full_screen_message(logs[1]) end function _99.prev_request_logs() - local logs = Logger.logs() - if #logs == 0 then - print("no logs to display") - return - end - _99_state.__view_log_idx = math.min(_99_state.__view_log_idx + 1, #logs) - Window.display_full_screen_message(logs[_99_state.__view_log_idx]) + local logs = Logger.logs() + if #logs == 0 then + print("no logs to display") + return + end + _99_state.__view_log_idx = math.min(_99_state.__view_log_idx + 1, #logs) + Window.display_full_screen_message(logs[_99_state.__view_log_idx]) end function _99.next_request_logs() - local logs = Logger.logs() - if #logs == 0 then - print("no logs to display") - return - end - _99_state.__view_log_idx = math.max(_99_state.__view_log_idx - 1, 1) - Window.display_full_screen_message(logs[_99_state.__view_log_idx]) -end - -function _99.__debug_ident() - ops.debug_ident(_99_state) + local logs = Logger.logs() + if #logs == 0 then + print("no logs to display") + return + end + _99_state.__view_log_idx = math.max(_99_state.__view_log_idx - 1, 1) + Window.display_full_screen_message(logs[_99_state.__view_log_idx]) end function _99.stop_all_requests() - for _, clean_up in pairs(_99_state.__active_requests) do - clean_up() - end - _99_state.__active_requests = {} + for _, clean_up in pairs(_99_state.__active_requests) do + clean_up() + end + _99_state.__active_requests = {} end --- if you touch this function you will be fired --- @return _99.State function _99.__get_state() - return _99_state + return _99_state end --- @param opts _99.Options? function _99.setup(opts) - opts = opts or {} - _99_state = _99_State.new() - _99_state.provider_override = opts.provider + opts = opts or {} + _99_state = _99_State.new() + _99_state.provider_override = opts.provider + _99_state.completion = opts.completion + or { + source = nil, + custom_rules = {}, + } + _99_state.completion.cursor_rules = _99_state.completion.cursor_rules + or ".cursor/rules/" + _99_state.completion.custom_rules = _99_state.completion.custom_rules or {} - vim.api.nvim_create_autocmd("VimLeavePre", { - callback = function() - _99.stop_all_requests() - end, - }) + local crules = _99_state.completion.custom_rules + for i, rule in ipairs(crules) do + crules[i] = expand(rule) + end - Logger:configure(opts.logger) + vim.api.nvim_create_autocmd("VimLeavePre", { + callback = function() + _99.stop_all_requests() + end, + }) - if opts.model then - assert(type(opts.model) == "string", "opts.model is not a string") - _99_state.model = opts.model - end + Logger:configure(opts.logger) + + if opts.model then + assert(type(opts.model) == "string", "opts.model is not a string") + _99_state.model = opts.model + end - if opts.md_files then - assert(type(opts.md_files) == "table", "opts.md_files is not a table") - for _, md in ipairs(opts.md_files) do - _99.add_md_file(md) - end + if opts.md_files then + assert(type(opts.md_files) == "table", "opts.md_files is not a table") + for _, md in ipairs(opts.md_files) do + _99.add_md_file(md) end + end - _99_state.display_errors = opts.display_errors or false + _99_state.display_errors = opts.display_errors or false - Languages.initialize(_99_state) + _99_state:refresh_rules() + Languages.initialize(_99_state) + Extensions.init(_99_state) end --- @param md string --- @return _99 function _99.add_md_file(md) - table.insert(_99_state.md_files, md) - return _99 + table.insert(_99_state.md_files, md) + return _99 end --- @param md string --- @return _99 function _99.rm_md_file(md) - for i, name in ipairs(_99_state.md_files) do - if name == md then - table.remove(_99_state.md_files, i) - break - end + for i, name in ipairs(_99_state.md_files) do + if name == md then + table.remove(_99_state.md_files, i) + break end - return _99 + end + return _99 end --- @param model string --- @return _99 function _99.set_model(model) - _99_state.model = model - return _99 + _99_state.model = model + return _99 end function _99.__debug() - Logger:configure({ - path = nil, - level = Level.DEBUG, - }) + Logger:configure({ + path = nil, + level = Level.DEBUG, + }) end return _99 diff --git a/lua/99/language/cpp.lua b/lua/99/language/cpp.lua index 5c96c8c..e81e219 100644 --- a/lua/99/language/cpp.lua +++ b/lua/99/language/cpp.lua @@ -5,7 +5,7 @@ M.names = {} --- @param item_name string --- @return string function M.log_item(item_name) - return string.format("std::println({}, %s)", item_name) + return string.format("std::println({}, %s)", item_name) end return M diff --git a/lua/99/language/go.lua b/lua/99/language/go.lua index 1f70561..0634b9d 100644 --- a/lua/99/language/go.lua +++ b/lua/99/language/go.lua @@ -5,7 +5,7 @@ M.names = {} --- @param item_name string --- @return string function M.log_item(item_name) - return string.format('fmt.Printf("%%+v\\n", %s)', item_name) + return string.format('fmt.Printf("%%+v\\n", %s)', item_name) end return M diff --git a/lua/99/language/init.lua b/lua/99/language/init.lua index 275ac93..a941abe 100644 --- a/lua/99/language/init.lua +++ b/lua/99/language/init.lua @@ -7,7 +7,7 @@ local Logger = require("99.logger.logger") --- @class _99.Langauges --- @field languages table<string, _99.LanguageOps> local M = { - languages = {}, + languages = {}, } --- @alias _99.langauge.GetLangParam _99.Location | number? @@ -17,37 +17,37 @@ local M = { --- @return string --- @return number local function get_langauge(bufferOrLoc) - if type(bufferOrLoc) == "number" or not bufferOrLoc then - local buffer = bufferOrLoc or vim.api.nvim_get_current_buf() - local file_type = - vim.api.nvim_get_option_value("filetype", { buf = buffer }) - local lang = M.languages[file_type] - if not lang then - Logger:fatal("language currently not supported", "lang", file_type) - end - return lang, file_type, buffer - end - - local file_type = bufferOrLoc.file_type + if type(bufferOrLoc) == "number" or not bufferOrLoc then + local buffer = bufferOrLoc or vim.api.nvim_get_current_buf() + local file_type = + vim.api.nvim_get_option_value("filetype", { buf = buffer }) local lang = M.languages[file_type] if not lang then - Logger:fatal("language currently not supported", "lang", file_type) + Logger:fatal("language currently not supported", "lang", file_type) end - return lang, file_type, bufferOrLoc.buffer + return lang, file_type, buffer + end + + local file_type = bufferOrLoc.file_type + local lang = M.languages[file_type] + if not lang then + Logger:fatal("language currently not supported", "lang", file_type) + end + return lang, file_type, bufferOrLoc.buffer end local function validate_function(fn, file_type) - if type(fn) ~= "function" then - Logger:fatal("language does not support log_item", "lang", file_type) - end + if type(fn) ~= "function" then + Logger:fatal("language does not support log_item", "lang", file_type) + end end --- @param _99 _99.State function M.initialize(_99) - M.languages = {} - for _, lang in ipairs(_99.languages) do - M.languages[lang] = require("99.language." .. lang) - end + M.languages = {} + for _, lang in ipairs(_99.languages) do + M.languages[lang] = require("99.language." .. lang) + end end --- @param _ _99.State @@ -55,10 +55,10 @@ end --- @param buffer number? --- @return string function M.log_item(_, item_name, buffer) - local lang, file_type = get_langauge(buffer) - validate_function(lang.log_item, file_type) + local lang, file_type = get_langauge(buffer) + validate_function(lang.log_item, file_type) - return lang.log_item(item_name) + return lang.log_item(item_name) end --[[ diff --git a/lua/99/language/java.lua b/lua/99/language/java.lua index f60a396..55109ee 100644 --- a/lua/99/language/java.lua +++ b/lua/99/language/java.lua @@ -5,7 +5,7 @@ M.names = {} --- @param item_name string --- @return string function M.log_item(item_name) - return string.format("System.out.println(%s)", item_name) + return string.format("System.out.println(%s)", item_name) end return M diff --git a/lua/99/language/lua.lua b/lua/99/language/lua.lua index 9419434..ad0aa3f 100644 --- a/lua/99/language/lua.lua +++ b/lua/99/language/lua.lua @@ -5,7 +5,7 @@ M.names = {} --- @param item_name string --- @return string function M.log_item(item_name) - return string.format("vim.inspect(%s)", item_name) + return string.format("vim.inspect(%s)", item_name) end return M diff --git a/lua/99/language/typescript.lua b/lua/99/language/typescript.lua index 28bc0fa..09831b9 100644 --- a/lua/99/language/typescript.lua +++ b/lua/99/language/typescript.lua @@ -3,11 +3,11 @@ local M = {} --- @param item_name string --- @return string function M.log_item(item_name) - return item_name + return item_name end M.names = { - body = "body", + body = "body", } return M diff --git a/lua/99/logger/level.lua b/lua/99/logger/level.lua index 1b33caa..9309ac2 100644 --- a/lua/99/logger/level.lua +++ b/lua/99/logger/level.lua @@ -7,26 +7,26 @@ local FATAL = 15 --- @param level number --- @return string local function levelToString(level) - if level == DEBUG then - return "DEBUG" - elseif level == INFO then - return "INFO" - elseif level == WARN then - return "WARN" - elseif level == ERROR then - return "ERROR" - elseif level == FATAL then - return "FATAL" - end - assert(false, "unknown level", level) - return "" + if level == DEBUG then + return "DEBUG" + elseif level == INFO then + return "INFO" + elseif level == WARN then + return "WARN" + elseif level == ERROR then + return "ERROR" + elseif level == FATAL then + return "FATAL" + end + assert(false, "unknown level", level) + return "" end return { - DEBUG = DEBUG, - INFO = INFO, - WARN = WARN, - ERROR = ERROR, - FATAL = FATAL, - levelToString = levelToString, + DEBUG = DEBUG, + INFO = INFO, + WARN = WARN, + ERROR = ERROR, + FATAL = FATAL, + levelToString = levelToString, } diff --git a/lua/99/logger/logger.lua b/lua/99/logger/logger.lua index fdc010a..b2dc37b 100644 --- a/lua/99/logger/logger.lua +++ b/lua/99/logger/logger.lua @@ -17,29 +17,29 @@ local max_requests_in_logger_cache = MAX_REQUEST_DEFAULT --- @param ... any --- @return table<string, any> local function to_args(...) - local count = select("#", ...) - local out = {} - assert( - count % 2 == 0, - "you cannot call logging with an odd number of args. e.g: msg, [k, v]..." - ) - for i = 1, count, 2 do - local key = select(i, ...) - local value = select(i + 1, ...) - assert(type(key) == "string", "keys in logging must be strings") - assert(out[key] == nil, "key collision in logs: " .. key) - out[key] = value - end - return out + local count = select("#", ...) + local out = {} + assert( + count % 2 == 0, + "you cannot call logging with an odd number of args. e.g: msg, [k, v]..." + ) + for i = 1, count, 2 do + local key = select(i, ...) + local value = select(i + 1, ...) + assert(type(key) == "string", "keys in logging must be strings") + assert(out[key] == nil, "key collision in logs: " .. key) + out[key] = value + end + return out end --- @param log_statement table<string, any> --- @param args table<string, any> local function put_args(log_statement, args) - for k, v in pairs(args) do - assert(log_statement[k] == nil, "key collision in logs: " .. k) - log_statement[k] = v - end + for k, v in pairs(args) do + assert(log_statement[k] == nil, "key collision in logs: " .. k) + log_statement[k] = v + end end --- @class LoggerSink @@ -50,12 +50,12 @@ local VoidSink = {} VoidSink.__index = VoidSink function VoidSink.new() - return setmetatable({}, VoidSink) + return setmetatable({}, VoidSink) end --- @param _ string function VoidSink:write_line(_) - _ = self + _ = self end --- @class FileSink : LoggerSink @@ -66,23 +66,23 @@ FileSink.__index = FileSink --- @param path string --- @return LoggerSink function FileSink:new(path) - local fd, err = vim.uv.fs_open(path, "w", 493) - if not fd then - error("unable to file sink", err) - end + local fd, err = vim.uv.fs_open(path, "w", 493) + if not fd then + error("unable to file sink", err) + end - return setmetatable({ - fd = fd, - }, self) + return setmetatable({ + fd = fd, + }, self) end --- @param str string function FileSink:write_line(str) - local success, err = vim.uv.fs_write(self.fd, str .. "\n") - if not success then - error("unable to write to file sink", err) - end - vim.uv.fs_fsync(self.fd) + local success, err = vim.uv.fs_write(self.fd, str .. "\n") + if not success then + error("unable to write to file sink", err) + end + vim.uv.fs_fsync(self.fd) end --- @class PrintSink : LoggerSink @@ -91,13 +91,13 @@ PrintSink.__index = PrintSink --- @return LoggerSink function PrintSink:new() - return setmetatable({}, self) + return setmetatable({}, self) end --- @param str string function PrintSink:write_line(str) - local _ = self - print(str) + local _ = self + print(str) end --- @class _99.Logger.RequestLogs @@ -115,259 +115,258 @@ Logger.__index = Logger --- @param level number? --- @return _99.Logger function Logger:new(level) - level = level or levels.FATAL - return setmetatable({ - sink = VoidSink:new(), - level = level, - print_on_error = false, - extra_params = {}, - }, self) + level = level or levels.FATAL + return setmetatable({ + sink = VoidSink:new(), + level = level, + print_on_error = false, + extra_params = {}, + }, self) end --- @return _99.Logger function Logger:clone() - local params = {} - for k, v in pairs(self.extra_params) do - params[k] = v - end - return setmetatable({ - sink = self.sink, - level = self.level, - print_on_error = self.print_on_error, - extra_params = params, - }, Logger) + local params = {} + for k, v in pairs(self.extra_params) do + params[k] = v + end + return setmetatable({ + sink = self.sink, + level = self.level, + print_on_error = self.print_on_error, + extra_params = params, + }, Logger) end --- @param path string --- @return _99.Logger function Logger:file_sink(path) - self.sink = FileSink:new(path) - return self + self.sink = FileSink:new(path) + return self end --- @return _99.Logger function Logger:void_sink() - self.sink = VoidSink:new() - return self + self.sink = VoidSink:new() + return self end --- @return _99.Logger function Logger:print_sink() - self.sink = PrintSink:new() - return self + self.sink = PrintSink:new() + return self end --- @param area string --- @return _99.Logger function Logger:set_area(area) - local new_logger = self:clone() - new_logger.extra_params["Area"] = area - return new_logger + local new_logger = self:clone() + new_logger.extra_params["Area"] = area + return new_logger end --- @param xid number --- @return _99.Logger function Logger:set_id(xid) - local new_logger = self:clone() - new_logger.extra_params["id"] = xid - return new_logger + local new_logger = self:clone() + new_logger.extra_params["id"] = xid + return new_logger end --- @param level number --- @return _99.Logger function Logger:set_level(level) - self.level = level - return self + self.level = level + return self end --- @return _99.Logger function Logger:on_error_print_message() - self.print_on_error = true - return self + self.print_on_error = true + return self end --- @param opts _99.Logger.Options? function Logger:configure(opts) - if not opts then - return - end + if not opts then + return + end - if opts.level then - self:set_level(opts.level) - end + if opts.level then + self:set_level(opts.level) + end - if opts.type == "print" then - self:print_sink() - elseif opts.type == "file" then - assert( - opts.path, - "if you choose file for logger, you must have a path specified" - ) - self:file_sink(opts.path) - else - self:void_sink() - end + if opts.type == "print" then + self:print_sink() + elseif opts.type == "file" then + assert( + opts.path, + "if you choose file for logger, you must have a path specified" + ) + self:file_sink(opts.path) + else + self:void_sink() + end - if opts.print_on_error then - self:on_error_print_message() - end + if opts.print_on_error then + self:on_error_print_message() + end - max_requests_in_logger_cache = opts.max_requests_cached - or MAX_REQUEST_DEFAULT + max_requests_in_logger_cache = opts.max_requests_cached or MAX_REQUEST_DEFAULT end --- @param line string function Logger:_cache_log(line) - local id = self.extra_params.id - if not id then - return - end + local id = self.extra_params.id + if not id then + return + end - local cache = logger_cache[id] - local new_cache = false - if not cache then - cache = { - last_access = time.now(), - logs = {}, - } - logger_cache[id] = cache - table.insert(logger_list, id) - new_cache = true - end - cache.last_access = time.now() - table.insert(cache.logs, line) - table.sort(logger_list, function(a, b) - assert( - logger_cache[a] and logger_cache[b], - "logger list is out of sync with logger cache: " - .. tostring(a) - .. " and " - .. tostring(b) - ) - local a_time = logger_cache[a].last_access - local b_time = logger_cache[b].last_access - return a_time > b_time - end) + local cache = logger_cache[id] + local new_cache = false + if not cache then + cache = { + last_access = time.now(), + logs = {}, + } + logger_cache[id] = cache + table.insert(logger_list, id) + new_cache = true + end + cache.last_access = time.now() + table.insert(cache.logs, line) + table.sort(logger_list, function(a, b) + assert( + logger_cache[a] and logger_cache[b], + "logger list is out of sync with logger cache: " + .. tostring(a) + .. " and " + .. tostring(b) + ) + local a_time = logger_cache[a].last_access + local b_time = logger_cache[b].last_access + return a_time > b_time + end) - if not new_cache then - return - end + if not new_cache then + return + end - Logger._trim_cache() + Logger._trim_cache() end --- This is a _TEST ONLY_ function. you should not call this function outside --- of unit tests function Logger.reset() - logger_cache = {} - max_requests_in_logger_cache = MAX_REQUEST_DEFAULT + logger_cache = {} + max_requests_in_logger_cache = MAX_REQUEST_DEFAULT end --- @return string[][] function Logger.logs() - local out = {} - for _, id in ipairs(logger_list) do - local request_logs = logger_cache[id] - table.insert(out, request_logs.logs) - end - return out + local out = {} + for _, id in ipairs(logger_list) do + local request_logs = logger_cache[id] + table.insert(out, request_logs.logs) + end + return out end --- @param level number ---@param msg string ---@param ... any function Logger:_log(level, msg, ...) - if self.level > level then - return - end + if self.level > level then + return + end - local log_statement = { - level = levels.levelToString(level), - msg = msg, - } + local log_statement = { + level = levels.levelToString(level), + msg = msg, + } - put_args(log_statement, to_args(...)) - put_args(log_statement, self.extra_params) + put_args(log_statement, to_args(...)) + put_args(log_statement, self.extra_params) - assert(log_statement["id"], "every log must have an id associated with it") + assert(log_statement["id"], "every log must have an id associated with it") - local json_string = vim.json.encode(log_statement) - if self.print_on_error and level == levels.ERROR then - print(json_string) - end + local json_string = vim.json.encode(log_statement) + if self.print_on_error and level == levels.ERROR then + print(json_string) + end - self:_cache_log(json_string) - self.sink:write_line(json_string) + self:_cache_log(json_string) + self.sink:write_line(json_string) end --- @param msg string --- @param ... any function Logger:info(msg, ...) - self:_log(levels.INFO, msg, ...) + self:_log(levels.INFO, msg, ...) end --- @param msg string --- @param ... any function Logger:warn(msg, ...) - self:_log(levels.WARN, msg, ...) + self:_log(levels.WARN, msg, ...) end --- @param msg string --- @param ... any function Logger:debug(msg, ...) - self:_log(levels.DEBUG, msg, ...) + self:_log(levels.DEBUG, msg, ...) end --- @param msg string --- @param ... any function Logger:error(msg, ...) - self:_log(levels.ERROR, msg, ...) + self:_log(levels.ERROR, msg, ...) end --- @param msg string --- @param ... any function Logger:fatal(msg, ...) - self:_log(levels.FATAL, msg, ...) - assert(false, "fatal msg recieved: " .. msg, ...) + self:_log(levels.FATAL, msg, ...) + assert(false, "fatal msg recieved: " .. msg, ...) end --- @param test any ---@param msg string ---@param ... any function Logger:assert(test, msg, ...) - if not test then - self:fatal(msg, ...) - end + if not test then + self:fatal(msg, ...) + end end function Logger._trim_cache() - local count = 0 - local oldest = nil - local oldest_key = nil - for k, log in pairs(logger_cache) do - if oldest == nil or log.last_access < oldest.last_access then - oldest = log - oldest_key = k - end - count = count + 1 + local count = 0 + local oldest = nil + local oldest_key = nil + for k, log in pairs(logger_cache) do + if oldest == nil or log.last_access < oldest.last_access then + oldest = log + oldest_key = k end + count = count + 1 + end - if count > max_requests_in_logger_cache then - assert(oldest_key, "oldest key must exist") - logger_cache[oldest_key] = nil + if count > max_requests_in_logger_cache then + assert(oldest_key, "oldest key must exist") + logger_cache[oldest_key] = nil - for i, id in ipairs(logger_list) do - if id == oldest_key then - table.remove(logger_list, i) - break - end - end + for i, id in ipairs(logger_list) do + if id == oldest_key then + table.remove(logger_list, i) + break + end end + end end function Logger.set_max_cached_requests(count) - max_requests_in_logger_cache = count - Logger._trim_cache() + max_requests_in_logger_cache = count + Logger._trim_cache() end local module_logger = Logger:new(levels.DEBUG) diff --git a/lua/99/ops/clean-up.lua b/lua/99/ops/clean-up.lua index a9545b1..d56b6de 100644 --- a/lua/99/ops/clean-up.lua +++ b/lua/99/ops/clean-up.lua @@ -2,18 +2,18 @@ ---@param clean_up_fn fun(): nil ---@return fun(): nil return function(context, clean_up_fn) - local called = false - local request_id = -1 - local function clean_up() - if called then - return - end - - called = true - clean_up_fn() - context._99:remove_active_request(request_id) + local called = false + local request_id = -1 + local function clean_up() + if called then + return end - request_id = context._99:add_active_request(clean_up) - return clean_up + called = true + clean_up_fn() + context._99:remove_active_request(request_id) + end + request_id = context._99:add_active_request(clean_up) + + return clean_up end diff --git a/lua/99/ops/fill-in-function.lua b/lua/99/ops/fill-in-function.lua index b29c3da..ce5e105 100644 --- a/lua/99/ops/fill-in-function.lua +++ b/lua/99/ops/fill-in-function.lua @@ -6,112 +6,124 @@ local editor = require("99.editor") local RequestStatus = require("99.ops.request_status") local Window = require("99.window") local make_clean_up = require("99.ops.clean-up") +local Agents = require("99.extensions.agents") --- @param context _99.RequestContext --- @param res string local function update_file_with_changes(context, res) - local buffer = context.buffer - local mark = context.marks.function_location - local logger = - context.logger:set_area("fill_in_function#update_file_with_changes") + local buffer = context.buffer + local mark = context.marks.function_location + local logger = + context.logger:set_area("fill_in_function#update_file_with_changes") - logger:assert( - mark and buffer, - "mark and buffer have to be set on the location object" - ) - logger:assert(mark:is_valid(), "mark is no longer valid") + logger:assert( + mark and buffer, + "mark and buffer have to be set on the location object" + ) + logger:assert(mark:is_valid(), "mark is no longer valid") - local func_start = Point.from_mark(mark) - local ts = editor.treesitter - local func = ts.containing_function(context, func_start) + local func_start = Point.from_mark(mark) + local ts = editor.treesitter + local func = ts.containing_function(context, func_start) - logger:assert( - func, - "update_file_with_changes: unable to find function at mark location" - ) + logger:assert( + func, + "update_file_with_changes: unable to find function at mark location" + ) - local lines = vim.split(res, "\n") + local lines = vim.split(res, "\n") - -- lua docs ignore next error, func being tested already in assert - -- TODO: fix this? - func:replace_text(lines) + -- lua docs ignore next error, func being tested already in assert + -- TODO: fix this? + func:replace_text(lines) end --- @param context _99.RequestContext ---- @param additional_prompt string? -local function fill_in_function(context, additional_prompt) - local logger = context.logger:set_area("fill_in_function") - local ts = editor.treesitter - local buffer = vim.api.nvim_get_current_buf() - local cursor = Point:from_cursor() - local func = ts.containing_function(context, cursor) +--- @param opts? _99.ops.Opts +local function fill_in_function(context, opts) + opts = opts or {} + local logger = context.logger:set_area("fill_in_function") + local ts = editor.treesitter + local buffer = vim.api.nvim_get_current_buf() + local cursor = Point:from_cursor() + local func = ts.containing_function(context, cursor) - if not func then - logger:fatal("fill_in_function: unable to find any containing function") - return - end + if not func then + logger:fatal("fill_in_function: unable to find any containing function") + return + end - context.range = func.function_range + context.range = func.function_range - local virt_line_count = context._99.ai_stdout_rows - if virt_line_count >= 0 then - context.marks.function_location = Mark.mark_func_body(buffer, func) - end + logger:debug("fill_in_function", "opts", opts) + local virt_line_count = context._99.ai_stdout_rows + if virt_line_count >= 0 then + context.marks.function_location = Mark.mark_func_body(buffer, func) + end - local request = Request.new(context) - local full_prompt = context._99.prompts.prompts.fill_in_function() - if additional_prompt then - full_prompt = - context._99.prompts.prompts.prompt(additional_prompt, full_prompt) - end - request:add_prompt_content(full_prompt) + local request = Request.new(context) + local full_prompt = context._99.prompts.prompts.fill_in_function() + local additional_prompt = opts.additional_prompt + if additional_prompt then + full_prompt = + context._99.prompts.prompts.prompt(additional_prompt, full_prompt) - local request_status = RequestStatus.new( - 250, - context._99.ai_stdout_rows, - "Loading", - context.marks.function_location - ) - request_status:start() + local rules = Agents.find_rules(context._99.rules, additional_prompt) + logger:debug("found rules", "rules", rules) + context:add_agent_rules(rules) + end - local clean_up = make_clean_up(context, function() - context:clear_marks() - request:cancel() - request_status:stop() - end) + local additional_rules = opts.additional_rules + if additional_rules then + logger:debug("additional_rules", "additional_rules", additional_rules) + context:add_agent_rules(additional_rules) + end - request:start({ - on_stdout = function(line) - request_status:push(line) - end, - on_complete = function(status, response) - logger:info("on_complete", "status", status, "response", response) - vim.schedule(clean_up) + request:add_prompt_content(full_prompt) - if status == "failed" then - if context._99.display_errors then - Window.display_error( - "Error encountered while processing fill_in_function\n" - .. ( - response - or "No Error text provided. Check logs" - ) - ) - end - logger:error( - "unable to fill in function, enable and check logger for more details" - ) - elseif status == "cancelled" then - logger:debug("fill_in_function was cancelled") - -- TODO: small status window here - elseif status == "success" then - update_file_with_changes(context, response) - end - end, - on_stderr = function(line) - logger:debug("fill_in_function#on_stderr", "line", line) - end, - }) + local request_status = RequestStatus.new( + 250, + context._99.ai_stdout_rows, + "Loading", + context.marks.function_location + ) + request_status:start() + + local clean_up = make_clean_up(context, function() + context:clear_marks() + request:cancel() + request_status:stop() + end) + + request:start({ + on_stdout = function(line) + request_status:push(line) + end, + on_complete = function(status, response) + logger:info("on_complete", "status", status, "response", response) + vim.schedule(clean_up) + + if status == "failed" then + if context._99.display_errors then + Window.display_error( + "Error encountered while processing fill_in_function\n" + .. (response or "No Error text provided. Check logs") + ) + end + logger:error( + "unable to fill in function, enable and check logger for more details" + ) + elseif status == "cancelled" then + logger:debug("fill_in_function was cancelled") + -- TODO: small status window here + elseif status == "success" then + update_file_with_changes(context, response) + end + end, + on_stderr = function(line) + logger:debug("fill_in_function#on_stderr", "line", line) + end, + }) end return fill_in_function diff --git a/lua/99/ops/implement-fn.lua b/lua/99/ops/implement-fn.lua index 85cac0f..efec8a3 100644 --- a/lua/99/ops/implement-fn.lua +++ b/lua/99/ops/implement-fn.lua @@ -10,83 +10,83 @@ local make_clean_up = require("99.ops.clean-up") --- @param context _99.RequestContext --- @param response string local function update_code(context, response) - local code_mark = context.marks.code_placement - local logger = context.logger:set_area("implement_fn#update_code") - local point = Point.from_mark(code_mark) + local code_mark = context.marks.code_placement + local logger = context.logger:set_area("implement_fn#update_code") + local point = Point.from_mark(code_mark) - logger:debug("setting text at mark", "Point", point) - code_mark:set_text_at_mark("\n" .. response) + logger:debug("setting text at mark", "Point", point) + code_mark:set_text_at_mark("\n" .. response) end --- @param context _99.RequestContext local function implement_fn(context) - local ts = editor.treesitter - local cursor = Point:from_cursor() - local buffer = vim.api.nvim_get_current_buf() - local fn_call = ts.fn_call(buffer, cursor) - local logger = context.logger:set_area("implement_fn") + local ts = editor.treesitter + local cursor = Point:from_cursor() + local buffer = vim.api.nvim_get_current_buf() + local fn_call = ts.fn_call(buffer, cursor) + local logger = context.logger:set_area("implement_fn") - if not fn_call then - logger:fatal( - "cannot implement function, cursor was not on an identifier that is a function call" - ) - return - end + if not fn_call then + logger:fatal( + "cannot implement function, cursor was not on an identifier that is a function call" + ) + return + end - local range = Range:from_ts_node(fn_call, buffer) - local request = Request.new(context) + local range = Range:from_ts_node(fn_call, buffer) + local request = Request.new(context) - context.marks.end_of_fn_call = Mark.mark_end_of_range(buffer, range) - local func = ts.containing_function(buffer, cursor) - if func then - context.marks.code_placement = Mark.mark_above_func(buffer, func) - else - context.marks.code_placement = Mark.mark_above_range(range) - end + context.marks.end_of_fn_call = Mark.mark_end_of_range(buffer, range) + local func = ts.containing_function(buffer, cursor) + if func then + context.marks.code_placement = Mark.mark_above_func(buffer, func) + else + context.marks.code_placement = Mark.mark_above_range(range) + end - local code_placement = RequestStatus.new( - 250, - context._99.ai_stdout_rows, - "Loading", - context.marks.code_placement - ) - local at_call_site = RequestStatus.new( - 250, - 1, - "Implementing Function", - context.marks.end_of_fn_call - ) + local code_placement = RequestStatus.new( + 250, + context._99.ai_stdout_rows, + "Loading", + context.marks.code_placement + ) + local at_call_site = RequestStatus.new( + 250, + 1, + "Implementing Function", + context.marks.end_of_fn_call + ) - code_placement:start() - at_call_site:start() + code_placement:start() + at_call_site:start() - local clean_up = make_clean_up(context, function() - context:clear_marks() - request:cancel() - code_placement:stop() - at_call_site:stop() - end) + local clean_up = make_clean_up(context, function() + context:clear_marks() + request:cancel() + code_placement:stop() + at_call_site:stop() + end) - request:add_prompt_content(context._99.prompts.prompts.implement_function) - request:start({ - on_stdout = function(line) - code_placement:push(line) - end, - on_complete = function(status, response) - vim.schedule(clean_up) - if status ~= "success" then - logger:fatal( - "unable to implement function, enable and check logger for more details" - ) - end - pcall(update_code, context, response) - end, - on_stderr = function(line) - logger:error("stderr", "line", line) - end, - }) + request:add_prompt_content(context._99.prompts.prompts.implement_function) + request:start({ + on_stdout = function(line) + code_placement:push(line) + end, + on_complete = function(status, response) + vim.schedule(clean_up) + if status ~= "success" then + logger:fatal( + "unable to implement function, enable and check logger for more details" + ) + end + pcall(update_code, context, response) + end, + on_stderr = function(line) + logger:error("stderr", "line", line) + end, + }) - return request + return request end return implement_fn diff --git a/lua/99/ops/init.lua b/lua/99/ops/init.lua index 38368f9..710e35d 100644 --- a/lua/99/ops/init.lua +++ b/lua/99/ops/init.lua @@ -1,5 +1,8 @@ +--- @class _99.ops.Opts +--- @field additional_prompt? string +--- @field additional_rules? _99.Agents.Rule[] return { - fill_in_function = require("99.ops.fill-in-function"), - implement_fn = require("99.ops.implement-fn"), - over_range = require("99.ops.over-range"), + fill_in_function = require("99.ops.fill-in-function"), + implement_fn = require("99.ops.implement-fn"), + over_range = require("99.ops.over-range"), } diff --git a/lua/99/ops/marks.lua b/lua/99/ops/marks.lua index 0f8a837..3571f1d 100644 --- a/lua/99/ops/marks.lua +++ b/lua/99/ops/marks.lua @@ -15,134 +15,132 @@ Mark.__index = Mark --- @param range _99.Range --- @return _99.Mark function Mark.mark_above_range(range) - local buffer = range.buffer - local start = range.start - local line, _ = start:to_vim() - local above = line == 0 and line or line - 1 + local buffer = range.buffer + local start = range.start + local line, _ = start:to_vim() + local above = line == 0 and line or line - 1 - -- luacheck: ignore - local id = nil - if above == line then - id = vim.api.nvim_buf_set_extmark(buffer, nsid, above, 0, {}) - else - local text = - vim.api.nvim_buf_get_lines(buffer, above, above + 1, false)[1] - local ending = #text - id = vim.api.nvim_buf_set_extmark(buffer, nsid, above, ending, {}) - end + -- luacheck: ignore + local id = nil + if above == line then + id = vim.api.nvim_buf_set_extmark(buffer, nsid, above, 0, {}) + else + local text = vim.api.nvim_buf_get_lines(buffer, above, above + 1, false)[1] + local ending = #text + id = vim.api.nvim_buf_set_extmark(buffer, nsid, above, ending, {}) + end - return setmetatable({ - id = id, - buffer = buffer, - nsid = nsid, - }, Mark) + return setmetatable({ + id = id, + buffer = buffer, + nsid = nsid, + }, Mark) end --- @param range _99.Range --- @return _99.Mark --- @return _99.Mark function Mark.mark_range(range) - local buffer = range.buffer - return Mark.mark_point(buffer, range.start), - Mark.mark_point(buffer, range.end_) + local buffer = range.buffer + return Mark.mark_point(buffer, range.start), + Mark.mark_point(buffer, range.end_) end --- @return boolean function Mark:is_valid() - local pos = - vim.api.nvim_buf_get_extmark_by_id(self.buffer, self.nsid, self.id, {}) - return #pos > 0 + local pos = + vim.api.nvim_buf_get_extmark_by_id(self.buffer, self.nsid, self.id, {}) + return #pos > 0 end --- @param buffer number --- @param point _99.Point --- @return _99.Mark function Mark.mark_point(buffer, point) - local line, col = point:to_vim() - local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col, {}) + local line, col = point:to_vim() + local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col, {}) - return setmetatable({ - id = id, - buffer = buffer, - nsid = nsid, - }, Mark) + return setmetatable({ + id = id, + buffer = buffer, + nsid = nsid, + }, Mark) end --- @param buffer number --- @param func _99.treesitter.Function --- @return _99.Mark function Mark.mark_above_func(buffer, func) - local start = func.function_range.start - local line, col = start:to_vim() - local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line - 1, col, {}) + local start = func.function_range.start + local line, col = start:to_vim() + local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line - 1, col, {}) - return setmetatable({ - id = id, - buffer = buffer, - nsid = nsid, - }, Mark) + return setmetatable({ + id = id, + buffer = buffer, + nsid = nsid, + }, Mark) end ---@param buffer number ---@param range _99.Range ---@return _99.Mark function Mark.mark_end_of_range(buffer, range) - local end_ = range.end_ - local line, col = end_:to_vim() - local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col + 1, {}) + local end_ = range.end_ + local line, col = end_:to_vim() + local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col + 1, {}) - return setmetatable({ - id = id, - buffer = buffer, - nsid = nsid, - }, Mark) + return setmetatable({ + id = id, + buffer = buffer, + nsid = nsid, + }, Mark) end --- @param buffer number --- @param func _99.treesitter.Function --- @return _99.Mark function Mark.mark_func_body(buffer, func) - local start = func.function_range.start - local line, col = start:to_vim() - local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col, {}) + local start = func.function_range.start + local line, col = start:to_vim() + local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col, {}) - return setmetatable({ - id = id, - buffer = buffer, - nsid = nsid, - }, Mark) + return setmetatable({ + id = id, + buffer = buffer, + nsid = nsid, + }, Mark) end --- @param lines string[] function Mark:set_virtual_text(lines) - local pos = - vim.api.nvim_buf_get_extmark_by_id(self.buffer, nsid, self.id, {}) - assert(#pos > 0, "extmark is broken. it does not exist") - local row, col = pos[1], pos[2] + local pos = vim.api.nvim_buf_get_extmark_by_id(self.buffer, nsid, self.id, {}) + assert(#pos > 0, "extmark is broken. it does not exist") + local row, col = pos[1], pos[2] - local formatted_lines = {} - for _, line in ipairs(lines) do - table.insert(formatted_lines, { - { line, "Comment" }, - }) - end - - vim.api.nvim_buf_set_extmark(self.buffer, nsid, row, col, { - id = self.id, - virt_lines = formatted_lines, + local formatted_lines = {} + for _, line in ipairs(lines) do + table.insert(formatted_lines, { + { line, "Comment" }, }) + end + + vim.api.nvim_buf_set_extmark(self.buffer, nsid, row, col, { + id = self.id, + virt_lines = formatted_lines, + }) end --- @param text string function Mark:set_text_at_mark(text) - local point = Point.from_mark(self) - local row, col = point:to_vim() - local lines = vim.split(text, "\n") - vim.api.nvim_buf_set_text(self.buffer, row, col, row, col, lines) + local point = Point.from_mark(self) + local row, col = point:to_vim() + local lines = vim.split(text, "\n") + vim.api.nvim_buf_set_text(self.buffer, row, col, row, col, lines) end function Mark:delete() - vim.api.nvim_buf_del_extmark(self.buffer, nsid, self.id) + vim.api.nvim_buf_del_extmark(self.buffer, nsid, self.id) end return Mark diff --git a/lua/99/ops/over-range.lua b/lua/99/ops/over-range.lua index 3dfdcbe..209be9f 100644 --- a/lua/99/ops/over-range.lua +++ b/lua/99/ops/over-range.lua @@ -2,96 +2,107 @@ local Request = require("99.request") local RequestStatus = require("99.ops.request_status") local Mark = require("99.ops.marks") local geo = require("99.geo") +local make_clean_up = require("99.ops.clean-up") +local Agents = require("99.extensions.agents") + local Range = geo.Range local Point = geo.Point -local make_clean_up = require("99.ops.clean-up") --- @param context _99.RequestContext --- @param range _99.Range ---- @param prompt string? -local function over_range(context, range, prompt) - local logger = context.logger:set_area("visual") +--- @param opts? _99.ops.Opts +local function over_range(context, range, opts) + opts = opts or {} + local logger = context.logger:set_area("visual") + + local request = Request.new(context) + local top_mark = Mark.mark_above_range(range) + local bottom_mark = Mark.mark_point(range.buffer, range.end_) + context.marks.top_mark = top_mark + context.marks.bottom_mark = bottom_mark + + logger:debug( + "visual request start", + "start", + Point.from_mark(top_mark), + "end", + Point.from_mark(bottom_mark) + ) - local request = Request.new(context) - local top_mark = Mark.mark_above_range(range) - local bottom_mark = Mark.mark_point(range.buffer, range.end_) - context.marks.top_mark = top_mark - context.marks.bottom_mark = bottom_mark + local display_ai_status = context._99.ai_stdout_rows > 1 + local top_status = RequestStatus.new( + 250, + context._99.ai_stdout_rows or 1, + "Implementing", + top_mark + ) + local bottom_status = RequestStatus.new(250, 1, "Implementing", bottom_mark) + local clean_up = make_clean_up(context, function() + top_status:stop() + bottom_status:stop() + context:clear_marks() + request:cancel() + end) - logger:debug( - "visual request start", - "start", - Point.from_mark(top_mark), - "end", - Point.from_mark(bottom_mark) - ) + local full_prompt = context._99.prompts.prompts.visual_selection(range) + local additional_prompt = opts.additional_prompt + if additional_prompt then + full_prompt = + context._99.prompts.prompts.prompt(additional_prompt, full_prompt) - local display_ai_status = context._99.ai_stdout_rows > 1 - local top_status = RequestStatus.new( - 250, - context._99.ai_stdout_rows or 1, - "Implementing", - top_mark - ) - local bottom_status = RequestStatus.new(250, 1, "Implementing", bottom_mark) - local clean_up = make_clean_up(context, function() - top_status:stop() - bottom_status:stop() - context:clear_marks() - request:cancel() - end) + local rules = Agents.find_rules(context._99.rules, additional_prompt) + context:add_agent_rules(rules) + end - local full_prompt = context._99.prompts.prompts.visual_selection(range) - if prompt then - full_prompt = context._99.prompts.prompts.prompt(prompt, full_prompt) - end + local additional_rules = opts.additional_rules + if additional_rules then + context:add_agent_rules(additional_rules) + end - request:add_prompt_content(full_prompt) - top_status:start() - bottom_status:start() - request:start({ - on_complete = function(status, response) - vim.schedule(clean_up) - if status == "cancelled" then - logger:debug( - "request cancelled for visual selection, removing marks" - ) - elseif status == "failed" then - logger:error( - "request failed for visual_selection", - "error response", - response or "no response provided" - ) - elseif status == "success" then - local valid = top_mark:is_valid() and bottom_mark:is_valid() - if not valid then - logger:fatal( - -- luacheck: ignore 631 - "the original visual_selection has been destroyed. You cannot delete the original visual selection during a request" - ) - return - end + request:add_prompt_content(full_prompt) + top_status:start() + bottom_status:start() + request:start({ + on_complete = function(status, response) + vim.schedule(clean_up) + if status == "cancelled" then + logger:debug("request cancelled for visual selection, removing marks") + elseif status == "failed" then + logger:error( + "request failed for visual_selection", + "error response", + response or "no response provided" + ) + elseif status == "success" then + local valid = top_mark:is_valid() and bottom_mark:is_valid() + if not valid then + logger:fatal( + -- luacheck: ignore 631 + "the original visual_selection has been destroyed. You cannot delete the original visual selection during a request" + ) + return + end - local new_range = Range.from_marks(top_mark, bottom_mark) - local lines = vim.split(response, "\n") + local new_range = Range.from_marks(top_mark, bottom_mark) + local lines = vim.split(response, "\n") - --- HACK: i am adding a new line here because above range will add a mark to the line above. - --- that way this appears to be added to "the same line" as the visual selection was - --- originally take from - table.insert(lines, 1, "") + --- HACK: i am adding a new line here because above range will add a mark to the line above. + --- that way this appears to be added to "the same line" as the visual selection was + --- originally take from + table.insert(lines, 1, "") - new_range:replace_text(lines) - end - end, - on_stdout = function(line) - if display_ai_status then - top_status:push(line) - end - end, - on_stderr = function(line) - logger:debug("visual_selection#on_stderr received", "line", line) - end, - }) + new_range:replace_text(lines) + end + end, + on_stdout = function(line) + if display_ai_status then + top_status:push(line) + end + end, + on_stderr = function(line) + logger:debug("visual_selection#on_stderr received", "line", line) + end, + }) end return over_range diff --git a/lua/99/ops/request_status.lua b/lua/99/ops/request_status.lua index f05121d..9b3dc5e 100644 --- a/lua/99/ops/request_status.lua +++ b/lua/99/ops/request_status.lua @@ -1,5 +1,5 @@ local braille_chars = - { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" } + { "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏" } --- @class _99.StatusLine --- @field index number @@ -10,21 +10,21 @@ StatusLine.__index = StatusLine --- @param title_line string --- @return _99.StatusLine function StatusLine.new(title_line) - local self = setmetatable({}, StatusLine) - self.index = 1 - self.title_line = title_line - return self + local self = setmetatable({}, StatusLine) + self.index = 1 + self.title_line = title_line + return self end function StatusLine:update() - self.index = self.index + 1 + self.index = self.index + 1 end --- @return string function StatusLine:to_string() - return braille_chars[self.index % #braille_chars + 1] - .. " " - .. self.title_line + return braille_chars[self.index % #braille_chars + 1] + .. " " + .. self.title_line end --- @class _99.RequestStatus @@ -43,50 +43,50 @@ RequestStatus.__index = RequestStatus --- @param mark _99.Mark --- @return _99.RequestStatus function RequestStatus.new(update_time, max_lines, title_line, mark) - local self = setmetatable({}, RequestStatus) - self.update_time = update_time - self.max_lines = max_lines - self.status_line = StatusLine.new(title_line) - self.lines = {} - self.running = false - self.mark = mark - return self + local self = setmetatable({}, RequestStatus) + self.update_time = update_time + self.max_lines = max_lines + self.status_line = StatusLine.new(title_line) + self.lines = {} + self.running = false + self.mark = mark + return self end --- @return string[] function RequestStatus:get() - local result = { self.status_line:to_string() } - for _, line in ipairs(self.lines) do - table.insert(result, line) - end - return result + local result = { self.status_line:to_string() } + for _, line in ipairs(self.lines) do + table.insert(result, line) + end + return result end --- @param line string function RequestStatus:push(line) - table.insert(self.lines, line) - if #self.lines > self.max_lines - 1 then - table.remove(self.lines, 1) - end + table.insert(self.lines, line) + if #self.lines > self.max_lines - 1 then + table.remove(self.lines, 1) + end end function RequestStatus:start() - local function update_spinner() - if not self.running then - return - end - - self.status_line:update() - self.mark:set_virtual_text(self:get()) - vim.defer_fn(update_spinner, self.update_time) + local function update_spinner() + if not self.running then + return end - self.running = true + self.status_line:update() + self.mark:set_virtual_text(self:get()) vim.defer_fn(update_spinner, self.update_time) + end + + self.running = true + vim.defer_fn(update_spinner, self.update_time) end function RequestStatus:stop() - self.running = false + self.running = false end return RequestStatus diff --git a/lua/99/prompt-settings.lua b/lua/99/prompt-settings.lua index 0268e7e..523dacd 100644 --- a/lua/99/prompt-settings.lua +++ b/lua/99/prompt-settings.lua @@ -1,19 +1,19 @@ ---@param buffer number ---@return string local function get_file_contents(buffer) - local lines = vim.api.nvim_buf_get_lines(buffer, 0, -1, false) - return table.concat(lines, "\n") + local lines = vim.api.nvim_buf_get_lines(buffer, 0, -1, false) + return table.concat(lines, "\n") end --- @class _99.Prompts.SpecificOperations --- @field visual_selection fun(range: _99.Range): string --- @field fill_in_function fun(): string local prompts = { - role = function() - return [[ You are a software engineering assistant mean to create robust and conanical code ]] - end, - fill_in_function = function() - return [[ + role = function() + return [[ You are a software engineering assistant mean to create robust and conanical code ]] + end, + fill_in_function = function() + return [[ You have been given a function change. Create the contents of the function. If the function already contains contents, use those as context @@ -21,20 +21,20 @@ Check the contents of the file you are in for any helper functions or context if there are DIRECTIONS, follow those when changing this function. Do not deviate ]] - end, - output_file = function() - return [[ + end, + output_file = function() + return [[ NEVER alter any file other than TEMP_FILE. never provide the requested changes as conversational output. ONLY provide requested changes by writing the change to TEMP_FILE ]] - end, - --- @param prompt string - --- @param action string - --- @return string - prompt = function(prompt, action) - return string.format( - [[ + end, + --- @param prompt string + --- @param action string + --- @return string + prompt = function(prompt, action) + return string.format( + [[ <DIRECTIONS> %s </DIRECTIONS> @@ -42,13 +42,13 @@ ONLY provide requested changes by writing the change to TEMP_FILE %s </Context> ]], - prompt, - action - ) - end, - visual_selection = function(range) - return string.format( - [[ + prompt, + action + ) + end, + visual_selection = function(range) + return string.format( + [[ You receive a selection in neovim that you need to replace with new code. The selection's contents may contain notes, incorporate the notes every time if there are some. consider the context of the selection and what you are suppose to be implementing @@ -62,48 +62,48 @@ consider the context of the selection and what you are suppose to be implementin %s </FILE_CONTAINING_SELECTION> ]], - range:to_string(), - range:to_text(), - get_file_contents(range.buffer) - ) - end, - -- luacheck: ignore 631 - read_tmp = "never attempt to read TEMP_FILE. It is purely for output. Previous contents, which may not exist, can be written over without worry", + range:to_string(), + range:to_text(), + get_file_contents(range.buffer) + ) + end, + -- luacheck: ignore 631 + read_tmp = "never attempt to read TEMP_FILE. It is purely for output. Previous contents, which may not exist, can be written over without worry", } --- @class _99.Prompts local prompt_settings = { - prompts = prompts, + prompts = prompts, - --- @param tmp_file string - --- @return string - tmp_file_location = function(tmp_file) - return string.format( - "<MustObey>\n%s\n%s\n</MustObey>\n<TEMP_FILE>%s</TEMP_FILE>", - prompts.output_file(), - prompts.read_tmp, - tmp_file - ) - end, + --- @param tmp_file string + --- @return string + tmp_file_location = function(tmp_file) + return string.format( + "<MustObey>\n%s\n%s\n</MustObey>\n<TEMP_FILE>%s</TEMP_FILE>", + prompts.output_file(), + prompts.read_tmp, + tmp_file + ) + end, - ---@param context _99.RequestContext - ---@return string - get_file_location = function(context) - context.logger:assert( - context.range, - "get_file_location requires range specified" - ) - return string.format( - "<Location><File>%s</File><Function>%s</Function></Location>", - context.full_path, - context.range:to_string() - ) - end, + ---@param context _99.RequestContext + ---@return string + get_file_location = function(context) + context.logger:assert( + context.range, + "get_file_location requires range specified" + ) + return string.format( + "<Location><File>%s</File><Function>%s</Function></Location>", + context.full_path, + context.range:to_string() + ) + end, - --- @param range _99.Range - get_range_text = function(range) - return string.format("<FunctionText>%s</FunctionText>", range:to_text()) - end, + --- @param range _99.Range + get_range_text = function(range) + return string.format("<FunctionText>%s</FunctionText>", range:to_text()) + end, } return prompt_settings diff --git a/lua/99/request-context.lua b/lua/99/request-context.lua index 2d950d5..4765772 100644 --- a/lua/99/request-context.lua +++ b/lua/99/request-context.lua @@ -22,95 +22,125 @@ RequestContext.__index = RequestContext --- @param xid number --- @return _99.RequestContext function RequestContext.from_current_buffer(_99, xid) - local buffer = vim.api.nvim_get_current_buf() - local full_path = vim.api.nvim_buf_get_name(buffer) - local file_type = vim.bo[buffer].ft + local buffer = vim.api.nvim_get_current_buf() + local full_path = vim.api.nvim_buf_get_name(buffer) + local file_type = vim.bo[buffer].ft - if file_type == "typescriptreact" then - file_type = "typescript" - end + if file_type == "typescriptreact" then + file_type = "typescript" + end - local mds = {} - for _, md in ipairs(_99.md_files) do - table.insert(mds, md) - end + local mds = {} + for _, md in ipairs(_99.md_files) do + table.insert(mds, md) + end - return setmetatable({ - _99 = _99, - md_file_names = mds, - ai_context = {}, - tmp_file = random_file(), - buffer = buffer, - full_path = full_path, - file_type = file_type, - logger = Logger:set_id(xid), - xid = xid, - model = _99.model, - marks = {}, - }, RequestContext) + return setmetatable({ + _99 = _99, + md_file_names = mds, + ai_context = {}, + tmp_file = random_file(), + buffer = buffer, + full_path = full_path, + file_type = file_type, + logger = Logger:set_id(xid), + xid = xid, + model = _99.model, + marks = {}, + }, RequestContext) end --- @param md_file_name string --- @return self function RequestContext:add_md_file_name(md_file_name) - table.insert(self.md_file_names, md_file_name) - return self + table.insert(self.md_file_names, md_file_name) + return self end -function RequestContext:_read_md_files() - local cwd = vim.uv.cwd() - local dir = vim.fn.fnamemodify(self.full_path, ":h") +--- TODO: Dedupe any rules that have already been added +--- @param rules (_99.Agents.Rule | string)[] +function RequestContext:add_agent_rules(rules) + for _, rule in ipairs(rules) do + -- Handle both string paths and rule objects + self.logger:debug("adding custom rule to agent", "rule", rule) + local ok, file = pcall(io.open, rule.path, "r") + if ok and file then + local content = file:read("*a") + file:close() + self.logger:info( + "Context#adding agent file to the context", + "agent_path", + rule.path + ) + table.insert( + self.ai_context, + string.format( + [[ +<%s> +%s +</%s>]], + rule.name, + content, + rule.name + ) + ) + else + self.logger:debug("unable to read agent rule", "rule", rule) + end + end +end - while dir:find(cwd, 1, true) == 1 do - for _, md_file_name in ipairs(self.md_file_names) do - local md_path = dir .. "/" .. md_file_name - local file = io.open(md_path, "r") - if file then - local content = file:read("*a") - file:close() - self.logger:info( - "Context#adding md file to the context", - "md_path", - md_path - ) - table.insert(self.ai_context, content) - end - end +function RequestContext:_read_md_files() + local cwd = vim.uv.cwd() + local dir = vim.fn.fnamemodify(self.full_path, ":h") - if dir == cwd then - break - end + while dir:find(cwd, 1, true) == 1 do + for _, md_file_name in ipairs(self.md_file_names) do + local md_path = dir .. "/" .. md_file_name + local file = io.open(md_path, "r") + if file then + local content = file:read("*a") + file:close() + self.logger:info( + "Context#adding md file to the context", + "md_path", + md_path + ) + table.insert(self.ai_context, content) + end + end - dir = vim.fn.fnamemodify(dir, ":h") + if dir == cwd then + break end + + dir = vim.fn.fnamemodify(dir, ":h") + end end --- @return string[] function RequestContext:content() - return self.ai_context + return self.ai_context end --- @return self function RequestContext:finalize() - self:_read_md_files() - if self.range then - table.insert(self.ai_context, self._99.prompts.get_file_location(self)) - table.insert( - self.ai_context, - self._99.prompts.get_range_text(self.range) - ) - end - table.insert( - self.ai_context, - self._99.prompts.tmp_file_location(self.tmp_file) - ) - return self + self:_read_md_files() + if self.range then + table.insert(self.ai_context, self._99.prompts.get_file_location(self)) + table.insert(self.ai_context, self._99.prompts.get_range_text(self.range)) + end + table.insert( + self.ai_context, + self._99.prompts.tmp_file_location(self.tmp_file) + ) + return self end function RequestContext:clear_marks() - for _, mark in pairs(self.marks) do - mark:delete() - end + for _, mark in pairs(self.marks) do + mark:delete() + end end return RequestContext diff --git a/lua/99/request/init.lua b/lua/99/request/init.lua index bc3bc71..53a3696 100644 --- a/lua/99/request/init.lua +++ b/lua/99/request/init.lua @@ -10,10 +10,10 @@ --- @field make_request fun(self: _99.Provider, query: string, request: _99.Request, observer: _99.ProviderObserver) local DevNullObserver = { - name = "DevNullObserver", - on_stdout = function() end, - on_stderr = function() end, - on_complete = function() end, + name = "DevNullObserver", + on_stdout = function() end, + on_stderr = function() end, + on_complete = function() end, } local OpenCodeProvider = {} @@ -21,123 +21,113 @@ local OpenCodeProvider = {} --- @param fn fun(...: any): nil --- @return fun(...: any): nil local function once(fn) - local called = false - return function(...) - if called then - return - end - called = true - fn(...) + local called = false + return function(...) + if called then + return end + called = true + fn(...) + end end --- @param query string ---@param request _99.Request ---@param observer _99.ProviderObserver? function OpenCodeProvider:make_request(query, request, observer) - _ = self - local logger = request.logger:set_area("OpenCodeProvider") - logger:debug("make_request", "tmp_file", request.context.tmp_file) + _ = self + local logger = request.logger:set_area("OpenCodeProvider") + logger:debug("make_request", "tmp_file", request.context.tmp_file) - observer = observer or DevNullObserver - --- @param status _99.Request.ResponseState - ---@param text string - local once_complete = once(function(status, text) - observer.on_complete(status, text) - end) + observer = observer or DevNullObserver + --- @param status _99.Request.ResponseState + ---@param text string + local once_complete = once(function(status, text) + observer.on_complete(status, text) + end) - local command = { "opencode", "run", "-m", request.context.model, query } - logger:debug("make_request", "command", command) - local proc = vim.system( - command, - { - text = true, - stdout = vim.schedule_wrap(function(err, data) - logger:debug("stdout", "data", data) - if request:is_cancelled() then - once_complete("cancelled", "") - return - end - if err and err ~= "" then - logger:debug("stdout#error", "err", err) - end - if not err then - observer.on_stdout(data) - end - end), - stderr = vim.schedule_wrap(function(err, data) - logger:debug("stderr", "data", data) - if request:is_cancelled() then - once_complete("cancelled", "") - return - end - if err and err ~= "" then - logger:debug("stderr#error", "err", err) - end - if not err then - observer.on_stderr(data) - end - end), - }, - vim.schedule_wrap(function(obj) - if request:is_cancelled() then - once_complete("cancelled", "") - logger:debug("on_complete: request has been cancelled") - return - end - if obj.code ~= 0 then - local str = string.format( - "process exit code: %d\n%s", - obj.code, - vim.inspect(obj) - ) - once_complete("failed", str) - logger:fatal( - "opencode make_query failed", - "obj from results", - obj - ) - end - vim.schedule(function() - local ok, res = OpenCodeProvider._retrieve_response(request) - if ok then - once_complete("success", res) - else - once_complete( - "failed", - "unable to retrieve response from llm" - ) - end - end) - end) - ) + local command = { "opencode", "run", "-m", request.context.model, query } + logger:debug("make_request", "command", command) + local proc = vim.system( + command, + { + text = true, + stdout = vim.schedule_wrap(function(err, data) + logger:debug("stdout", "data", data) + if request:is_cancelled() then + once_complete("cancelled", "") + return + end + if err and err ~= "" then + logger:debug("stdout#error", "err", err) + end + if not err then + observer.on_stdout(data) + end + end), + stderr = vim.schedule_wrap(function(err, data) + logger:debug("stderr", "data", data) + if request:is_cancelled() then + once_complete("cancelled", "") + return + end + if err and err ~= "" then + logger:debug("stderr#error", "err", err) + end + if not err then + observer.on_stderr(data) + end + end), + }, + vim.schedule_wrap(function(obj) + if request:is_cancelled() then + once_complete("cancelled", "") + logger:debug("on_complete: request has been cancelled") + return + end + if obj.code ~= 0 then + local str = + string.format("process exit code: %d\n%s", obj.code, vim.inspect(obj)) + once_complete("failed", str) + logger:fatal("opencode make_query failed", "obj from results", obj) + end + vim.schedule(function() + local ok, res = OpenCodeProvider._retrieve_response(request) + if ok then + once_complete("success", res) + else + once_complete("failed", "unable to retrieve response from llm") + end + end) + end) + ) - request:_set_process(proc) + request:_set_process(proc) end --- @param request _99.Request function OpenCodeProvider._retrieve_response(request) - local logger = request.logger:set_area("OpenCodeProvider") - local tmp = request.context.tmp_file - local success, result = pcall(function() - return vim.fn.readfile(tmp) - end) + local logger = request.logger:set_area("OpenCodeProvider") + local tmp = request.context.tmp_file + local success, result = pcall(function() + return vim.fn.readfile(tmp) + end) - if not success then - logger:error( - "retrieve_results: failed to read file", - "tmp_name", - tmp, - "error", - result - ) - return false, "" - end + if not success then + logger:error( + "retrieve_results: failed to read file", + "tmp_name", + tmp, + "error", + result + ) + return false, "" + end - local str = table.concat(result, "\n") - logger:debug("retrieve_results", "results", str) + local str = table.concat(result, "\n") + logger:debug("retrieve_results", "results", str) - return true, str + return true, str end --- @class _99.Request.Opts @@ -159,67 +149,63 @@ end --- @field logger _99.Logger --- @field _content string[] --- @field _proc vim.SystemObj? - local Request = {} Request.__index = Request --- @param context _99.RequestContext --- @return _99.Request function Request.new(context) - local provider = context._99.provider_override or OpenCodeProvider - return setmetatable({ - context = context, - provider = provider, - state = "ready", - logger = context.logger:set_area("Request"), - _content = {}, - _proc = nil, - }, Request) + local provider = context._99.provider_override or OpenCodeProvider + return setmetatable({ + context = context, + provider = provider, + state = "ready", + logger = context.logger:set_area("Request"), + _content = {}, + _proc = nil, + }, Request) end --- @param proc vim.SystemObj? function Request:_set_process(proc) - self._proc = proc + self._proc = proc end function Request:cancel() - self.logger:debug("cancel") - self.state = "cancelled" - if self._proc and self._proc.pid then - pcall(function() - local sigterm = ( - vim.uv - and vim.uv.constants - and vim.uv.constants.SIGTERM - ) or 15 - self._proc:kill(sigterm) - end) - end + self.logger:debug("cancel") + self.state = "cancelled" + if self._proc and self._proc.pid then + pcall(function() + local sigterm = (vim.uv and vim.uv.constants and vim.uv.constants.SIGTERM) + or 15 + self._proc:kill(sigterm) + end) + end end function Request:is_cancelled() - return self.state == "cancelled" + return self.state == "cancelled" end --- @param content string --- @return self function Request:add_prompt_content(content) - table.insert(self._content, content) - return self + table.insert(self._content, content) + return self end --- @param observer _99.ProviderObserver? function Request:start(observer) - self.context:finalize() - for _, content in ipairs(self.context.ai_context) do - self:add_prompt_content(content) - end + self.context:finalize() + for _, content in ipairs(self.context.ai_context) do + self:add_prompt_content(content) + end - local query = table.concat(self._content, "\n") - observer = observer or DevNullObserver + local query = table.concat(self._content, "\n") + observer = observer or DevNullObserver - self.logger:debug("start", "query", query) - self.provider:make_request(query, self, observer) + self.logger:debug("start", "query", query) + self.provider:make_request(query, self, observer) end return Request diff --git a/lua/99/test/agents_spec.lua b/lua/99/test/agents_spec.lua new file mode 100644 index 0000000..0fe652d --- /dev/null +++ b/lua/99/test/agents_spec.lua @@ -0,0 +1,124 @@ +-- luacheck: globals describe it assert +local Agents = require("99.extensions.agents") +local eq = assert.are.same + +local function c(t, item) + return vim.tbl_contains(t, function(v) + return vim.deep_equal(v, item) + end, { predicate = true }) +end + +local function a(p) + return vim.fs.joinpath(vim.uv.cwd(), p) +end + +local cursor_mds = { + { name = "database", path = a("scratch/cursor/rules/database.mdc") }, + { name = "my-proj", path = a("scratch/cursor/rules/my-proj.mdc") }, +} +local custom_mds = { + { name = "back-end", path = a("scratch/custom_rules/back-end.md") }, + { name = "foo", path = a("scratch/custom_rules/foo.md") }, + { name = "front-end", path = a("scratch/custom_rules/front-end.md") }, + { name = "vim.lsp", path = a("scratch/custom_rules/vim.lsp.md") }, + { name = "vim", path = a("scratch/custom_rules/vim.md") }, + { + name = "vim.treesitter", + path = a("scratch/custom_rules/vim.treesitter.md"), + }, +} + +--- @return _99.State +local function r(cursor, custom) + return { + completion = { + cursor_rules = cursor, + custom_rules = { custom }, + }, + } +end + +local function string_rules() + return string.format( + [[ + Here is a long sentense with @%s these types of rules @%s that should be parsed in the correct order + and it should be awesome @%s + ]], + cursor_mds[1].path, + custom_mds[2].path, + custom_mds[4].path + ), + { + cursor_mds[1], + custom_mds[2], + custom_mds[4], + } +end + +--- @param rules _99.Agents.Rules +local function test_cursor(rules) + for _, cursor in ipairs(cursor_mds) do + eq(true, c(rules.cursor, cursor)) + eq(false, c(rules.custom, cursor)) + end +end + +--- @param rules _99.Agents.Rules +local function test_custom(rules) + for _, custom in ipairs(custom_mds) do + eq(true, c(rules.custom, custom)) + eq(false, c(rules.cursor, custom)) + end +end +describe("99.agents.helpers", function() + it("should generate rules from _99 state with completion rules", function() + local _99 = r("scratch/cursor/rules", "scratch/custom_rules/") + local rules = Agents.rules(_99) + test_cursor(rules) + test_custom(rules) + end) + + it("generate without cursor", function() + local _99 = r("foo/bar/bazz", "scratch/custom_rules/") + local rules = Agents.rules(_99) + test_custom(rules) + end) + + it("generate without custom", function() + local _99 = r("scratch/cursor/rules") + local rules = Agents.rules(_99) + test_cursor(rules) + end) + + it( + "should validate that tokens exist, in both custom and cursor, and incorrect tokens", + function() + local _99 = r("scratch/cursor/rules", "scratch/custom_rules/") + local rules = Agents.rules(_99) + + eq(true, Agents.is_rule(rules, a("scratch/cursor/rules/database.mdc"))) + eq(true, Agents.is_rule(rules, a("scratch/cursor/rules/my-proj.mdc"))) + eq(true, Agents.is_rule(rules, a("scratch/custom_rules/back-end.md"))) + eq(true, Agents.is_rule(rules, a("scratch/custom_rules/foo.md"))) + eq(true, Agents.is_rule(rules, a("scratch/custom_rules/front-end.md"))) + eq(true, Agents.is_rule(rules, a("scratch/custom_rules/vim.lsp.md"))) + eq(true, Agents.is_rule(rules, a("scratch/custom_rules/vim.md"))) + eq( + true, + Agents.is_rule(rules, a("scratch/custom_rules/vim.treesitter.md")) + ) + eq(false, Agents.is_rule(rules, "nonexistent")) + eq(false, Agents.is_rule(rules, "invalid-token")) + eq(false, Agents.is_rule(rules, "")) + end + ) + + it("find all the existing rules", function() + local _99 = r("scratch/cursor/rules", "scratch/custom_rules/") + local rules = Agents.rules(_99) + local prompt, expected_rules = string_rules() + local found_rules = Agents.find_rules(rules, prompt) + + eq(expected_rules, found_rules) + end) +end) diff --git a/lua/99/test/fill_in_function.cpp_spec.lua b/lua/99/test/fill_in_function.cpp_spec.lua index 9edf550..3cb67b5 100644 --- a/lua/99/test/fill_in_function.cpp_spec.lua +++ b/lua/99/test/fill_in_function.cpp_spec.lua @@ -12,122 +12,122 @@ local eq = assert.are.same --- @param lang string? --- @return _99.test.Provider, number local function setup(content, row, col, lang) - assert(lang, "lang must be provided") - local provider = test_utils.TestProvider.new() - _99.setup({ - provider = provider, - logger = { - error_cache_level = Levels.ERROR, - type = "print", - }, - }) + assert(lang, "lang must be provided") + local provider = test_utils.TestProvider.new() + _99.setup({ + provider = provider, + logger = { + error_cache_level = Levels.ERROR, + type = "print", + }, + }) - local buffer = test_utils.create_file(content, lang, row, col) - return provider, buffer + local buffer = test_utils.create_file(content, lang, row, col) + return provider, buffer end --- @param buffer number --- @return string[] local function read(buffer) - return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) + return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) end describe("fill_in_function", function() - it("fill in cpp function", function() - local cpp_content = { - "", - "uint32_t test() { }", - } - local provider, buffer = setup(cpp_content, 2, 5, "cpp") - local state = _99.__get_state() + it("fill in cpp function", function() + local cpp_content = { + "", + "uint32_t test() { }", + } + local provider, buffer = setup(cpp_content, 2, 5, "cpp") + local state = _99.__get_state() - _99.fill_in_function() + _99.fill_in_function() - eq(1, state:active_request_count()) - eq(cpp_content, read(buffer)) + eq(1, state:active_request_count()) + eq(cpp_content, read(buffer)) - provider:resolve("success", "uint32_t test() {\n return 42;\n}") - test_utils.next_frame() + provider:resolve("success", "uint32_t test() {\n return 42;\n}") + test_utils.next_frame() - local expected_state = { - "", - "uint32_t test() {", - " return 42;", - "}", - } - eq(expected_state, read(buffer)) - eq(0, state:active_request_count()) - end) + local expected_state = { + "", + "uint32_t test() {", + " return 42;", + "}", + } + eq(expected_state, read(buffer)) + eq(0, state:active_request_count()) + end) - it("fill in cpp concept with requires clause", function() - local cpp_content = { - "", - "template <typename T>", - "concept Callback = requires(T cb) {", - " // Invocation must return an int", - "};", - } + it("fill in cpp concept with requires clause", function() + local cpp_content = { + "", + "template <typename T>", + "concept Callback = requires(T cb) {", + " // Invocation must return an int", + "};", + } - local provider, buffer = setup(cpp_content, 3, 10, "cpp") - local state = _99.__get_state() + local provider, buffer = setup(cpp_content, 3, 10, "cpp") + local state = _99.__get_state() - _99.fill_in_function() + _99.fill_in_function() - eq(1, state:active_request_count()) - eq(cpp_content, read(buffer)) + eq(1, state:active_request_count()) + eq(cpp_content, read(buffer)) - provider:resolve( - "success", - "concept Callback = requires(T cb) {\n { cb() } -> std::same_as<int>;\n};" - ) - test_utils.next_frame() + provider:resolve( + "success", + "concept Callback = requires(T cb) {\n { cb() } -> std::same_as<int>;\n};" + ) + test_utils.next_frame() - local expected_state = { - "", - "template <typename T>", - "concept Callback = requires(T cb) {", - " { cb() } -> std::same_as<int>;", - "};", - } - eq(expected_state, read(buffer)) - eq(0, state:active_request_count()) - end) + local expected_state = { + "", + "template <typename T>", + "concept Callback = requires(T cb) {", + " { cb() } -> std::same_as<int>;", + "};", + } + eq(expected_state, read(buffer)) + eq(0, state:active_request_count()) + end) - it("fill in nested lambda inside a function", function() - local cpp_content = { - "", - "auto test() -> void", - "{", - " const auto say_42 = []() -> int {", - " // TODO: return 42", - " };", - "}", - } + it("fill in nested lambda inside a function", function() + local cpp_content = { + "", + "auto test() -> void", + "{", + " const auto say_42 = []() -> int {", + " // TODO: return 42", + " };", + "}", + } - local provider, buffer = setup(cpp_content, 4, 20, "cpp") - local state = _99.__get_state() + local provider, buffer = setup(cpp_content, 4, 20, "cpp") + local state = _99.__get_state() - _99.fill_in_function() + _99.fill_in_function() - eq(1, state:active_request_count()) - eq(cpp_content, read(buffer)) + eq(1, state:active_request_count()) + eq(cpp_content, read(buffer)) - provider:resolve( - "success", - "const auto say_42 = []() -> int {\n return 42;\n };" - ) - test_utils.next_frame() + provider:resolve( + "success", + "const auto say_42 = []() -> int {\n return 42;\n };" + ) + test_utils.next_frame() - local expected_state = { - "", - "auto test() -> void", - "{", - " const auto say_42 = []() -> int {", - " return 42;", - " };", - "}", - } - eq(expected_state, read(buffer)) - eq(0, state:active_request_count()) - end) + local expected_state = { + "", + "auto test() -> void", + "{", + " const auto say_42 = []() -> int {", + " return 42;", + " };", + "}", + } + eq(expected_state, read(buffer)) + eq(0, state:active_request_count()) + end) end) diff --git a/lua/99/test/fill_in_function_spec.lua b/lua/99/test/fill_in_function_spec.lua index 8d28df4..cc356f1 100644 --- a/lua/99/test/fill_in_function_spec.lua +++ b/lua/99/test/fill_in_function_spec.lua @@ -10,111 +10,109 @@ local Levels = require("99.logger.level") --- @param lang string? --- @return _99.test.Provider, number local function setup(content, row, col, lang) - lang = lang or "lua" - local p = test_utils.TestProvider.new() - _99.setup({ - provider = p, - logger = { - error_cache_level = Levels.ERROR, - type = "print", - }, - }) + lang = lang or "lua" + local p = test_utils.TestProvider.new() + _99.setup({ + provider = p, + logger = { + error_cache_level = Levels.ERROR, + type = "print", + }, + }) - local buffer = test_utils.create_file(content, lang, row, col) - return p, buffer + local buffer = test_utils.create_file(content, lang, row, col) + return p, buffer end --- @param buffer number --- @return string[] local function r(buffer) - return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) + return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) end local content = { - "", - "local foo = function() end", + "", + "local foo = function() end", } describe("fill_in_function", function() - it("replace function contents", function() - local p, buffer = setup(content, 2, 12) - local state = _99.__get_state() + it("replace function contents", function() + local p, buffer = setup(content, 2, 12) + local state = _99.__get_state() - _99.fill_in_function() + _99.fill_in_function() - eq(1, state:active_request_count()) - eq(content, r(buffer)) + eq(1, state:active_request_count()) + eq(content, r(buffer)) - p:resolve("success", "function()\n return 42\nend") - test_utils.next_frame() + p:resolve("success", "function()\n return 42\nend") + test_utils.next_frame() - local expected_state = { - "", - "local foo = function()", - " return 42", - "end", - } - eq(expected_state, r(buffer)) - eq(0, state:active_request_count()) - end) + local expected_state = { + "", + "local foo = function()", + " return 42", + "end", + } + eq(expected_state, r(buffer)) + eq(0, state:active_request_count()) + end) - it("should test a typescript file", function() - local ts_content = { - "", - "const foo = function() {}", - } - local p, buffer = setup(ts_content, 2, 12, "typescript") - local state = _99.__get_state() + it("should test a typescript file", function() + local ts_content = { + "", + "const foo = function() {}", + } + local p, buffer = setup(ts_content, 2, 12, "typescript") + local state = _99.__get_state() - print("TEST", vim.bo[buffer].ft) + _99.fill_in_function() - _99.fill_in_function() + eq(1, state:active_request_count()) + eq(ts_content, r(buffer)) - eq(1, state:active_request_count()) - eq(ts_content, r(buffer)) + p:resolve("success", "function() {\n return 42;\n}") + test_utils.next_frame() - p:resolve("success", "function() {\n return 42;\n}") - test_utils.next_frame() + local expected_state = { + "", + "const foo = function() {", + " return 42;", + "}", + } + eq(expected_state, r(buffer)) + eq(0, state:active_request_count()) + end) - local expected_state = { - "", - "const foo = function() {", - " return 42;", - "}", - } - eq(expected_state, r(buffer)) - eq(0, state:active_request_count()) - end) + it("should cancel request when stop_all_requests is called", function() + local p, buffer = setup(content, 2, 12) + _99.fill_in_function() - it("should cancel request when stop_all_requests is called", function() - local p, buffer = setup(content, 2, 12) - _99.fill_in_function() + eq(content, r(buffer)) - eq(content, r(buffer)) + assert.is_false(p.request.request:is_cancelled()) + assert.is_not_nil(p.request) + assert.is_not_nil(p.request.request) - assert.is_false(p.request.request:is_cancelled()) - assert.is_not_nil(p.request) - assert.is_not_nil(p.request.request) + _99.stop_all_requests() + test_utils.next_frame() - _99.stop_all_requests() - test_utils.next_frame() + assert.is_true(p.request.request:is_cancelled()) - assert.is_true(p.request.request:is_cancelled()) + p:resolve("success", "function foo()\n return 42\nend") + test_utils.next_frame() - p:resolve("success", "function foo()\n return 42\nend") - test_utils.next_frame() + eq(content, r(buffer)) + end) - eq(content, r(buffer)) - end) + it("should handle error cases with graceful failures", function() + local p, buffer = setup(content, 2, 12) + _99.fill_in_function() - it("should handle error cases with graceful failures", function() - local p, buffer = setup(content, 2, 12) - _99.fill_in_function() + eq(content, r(buffer)) - eq(content, r(buffer)) + p:resolve("failed", "Something went wrong") + test_utils.next_frame() - p:resolve("failed", "Something went wrong") - test_utils.next_frame() - - eq(content, r(buffer)) - end) + eq(content, r(buffer)) + end) end) diff --git a/lua/99/test/geo_spec.lua b/lua/99/test/geo_spec.lua index c8213a7..9dc107b 100644 --- a/lua/99/test/geo_spec.lua +++ b/lua/99/test/geo_spec.lua @@ -6,103 +6,103 @@ local test_utils = require("99.test.test_utils") local eq = assert.are.same describe("Range", function() - local buffer + local buffer - before_each(function() - buffer = test_utils.create_file({ - "function foo()", - " local x = 1", - " return x", - "end", - "", - "function bar()", - " return 42", - "end", - }, "lua", 1, 0) - end) + before_each(function() + buffer = test_utils.create_file({ + "function foo()", + " local x = 1", + " return x", + "end", + "", + "function bar()", + " return 42", + "end", + }, "lua", 1, 0) + end) - after_each(function() - test_utils.clean_files() - end) + after_each(function() + test_utils.clean_files() + end) - it("replace text", function() - local start_point = Point:new(2, 3) - local end_point = Point:new(3, 11) - local range = Range:new(buffer, start_point, end_point) - local original_text = range:to_text() - eq("local x = 1\n return x", original_text) + it("replace text", function() + local start_point = Point:new(2, 3) + local end_point = Point:new(3, 11) + local range = Range:new(buffer, start_point, end_point) + local original_text = range:to_text() + eq("local x = 1\n return x", original_text) - local replace_text = { "local y = 2" } - range:replace_text(replace_text) - local lines = vim.api.nvim_buf_get_lines(buffer, 0, -1, false) - eq({ - "function foo()", - " local y = 2", - "end", - "", - "function bar()", - " return 42", - "end", - }, lines) - end) + local replace_text = { "local y = 2" } + range:replace_text(replace_text) + local lines = vim.api.nvim_buf_get_lines(buffer, 0, -1, false) + eq({ + "function foo()", + " local y = 2", + "end", + "", + "function bar()", + " return 42", + "end", + }, lines) + end) - it("replace text single line into multi-line", function() - local start_point = Point:new(2, 3) - local end_point = Point:new(3, 11) - local range = Range:new(buffer, start_point, end_point) - local original_text = range:to_text() - eq("local x = 1\n return x", original_text) + it("replace text single line into multi-line", function() + local start_point = Point:new(2, 3) + local end_point = Point:new(3, 11) + local range = Range:new(buffer, start_point, end_point) + local original_text = range:to_text() + eq("local x = 1\n return x", original_text) - local replace_text = { - "local y = 2", - " local z = 3", - } - range:replace_text(replace_text) - local lines = vim.api.nvim_buf_get_lines(buffer, 0, -1, false) - eq({ - "function foo()", - " local y = 2", - " local z = 3", - "end", - "", - "function bar()", - " return 42", - "end", - }, lines) - end) + local replace_text = { + "local y = 2", + " local z = 3", + } + range:replace_text(replace_text) + local lines = vim.api.nvim_buf_get_lines(buffer, 0, -1, false) + eq({ + "function foo()", + " local y = 2", + " local z = 3", + "end", + "", + "function bar()", + " return 42", + "end", + }, lines) + end) - it( - "should be able to visual line select an empty line and return out an empty line of text", - function() - vim.api.nvim_win_set_cursor(0, { 5, 0 }) - vim.api.nvim_feedkeys("V", "x", false) + it( + "should be able to visual line select an empty line and return out an empty line of text", + function() + vim.api.nvim_win_set_cursor(0, { 5, 0 }) + vim.api.nvim_feedkeys("V", "x", false) - test_utils.next_frame() - vim.api.nvim_feedkeys( - vim.api.nvim_replace_termcodes("<Esc>", true, false, true), - "x", - false - ) + test_utils.next_frame() + vim.api.nvim_feedkeys( + vim.api.nvim_replace_termcodes("<Esc>", true, false, true), + "x", + false + ) - local range = Range.from_visual_selection() - local text = range:to_text() - eq("", text) - end - ) + local range = Range.from_visual_selection() + local text = range:to_text() + eq("", text) + end + ) - it("should create range from simple visual line selection", function() - vim.api.nvim_win_set_cursor(0, { 2, 0 }) - vim.api.nvim_feedkeys("V", "x", false) + it("should create range from simple visual line selection", function() + vim.api.nvim_win_set_cursor(0, { 2, 0 }) + vim.api.nvim_feedkeys("V", "x", false) - test_utils.next_frame() - vim.api.nvim_feedkeys( - vim.api.nvim_replace_termcodes("<Esc>", true, false, true), - "x", - false - ) + test_utils.next_frame() + vim.api.nvim_feedkeys( + vim.api.nvim_replace_termcodes("<Esc>", true, false, true), + "x", + false + ) - local range = Range.from_visual_selection() - local text = range:to_text() - eq(" local x = 1", text) - end) + local range = Range.from_visual_selection() + local text = range:to_text() + eq(" local x = 1", text) + end) end) diff --git a/lua/99/test/implement-fn_spec-DO-NOT-USE-YET.lua b/lua/99/test/implement-fn_spec-DO-NOT-USE-YET.lua deleted file mode 100644 index 317d270..0000000 --- a/lua/99/test/implement-fn_spec-DO-NOT-USE-YET.lua +++ /dev/null @@ -1,55 +0,0 @@ --- luacheck: globals describe it assert -local _99 = require("99") -local test_utils = require("99.test.test_utils") -local eq = assert.are.same -local test_content = require("99.test.test_content") - ---- @param content string[] ---- @return _99.test.Provider, number -local function setup(content) - local p = test_utils.TestProvider.new() - _99.setup({ - provider = p, - }) - - local buffer = test_utils.create_file(content, "lua", 3, 3) - return p, buffer -end - ---- @param buffer number ---- @return string[] -local function r(buffer) - return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) -end - -local content = { - "function some_other_function() end", - "function foo()", - " bar()", - "end", - "", -} -describe("implement_function", function() - it("basic call", function() - local p, buffer = setup(content) - _99.implement_fn() - eq(content, r(buffer)) - - p:resolve("success", "function bar()\n return 42\nend") - test_utils.next_frame() - - local expected_state = { - "function some_other_function() end", - "function bar()", - " return 42", - "end", - "function foo()", - " bar()", - "end", - "", - } - eq(expected_state, r(buffer)) - end) - - it("should cancel request when stop_all_requests is called", function() end) -end) diff --git a/lua/99/test/logger_spec.lua b/lua/99/test/logger_spec.lua index a602dd5..c7bda84 100644 --- a/lua/99/test/logger_spec.lua +++ b/lua/99/test/logger_spec.lua @@ -5,7 +5,7 @@ local eq = assert.are.same local now = 0 time.now = function() - return now + return now end --- @class _99.Test.Logger.RequestLogs @@ -15,82 +15,82 @@ end --- @param all_logs string[][] --- @return _99.Test.Logger.RequestLogs local function l(all_logs) - local out = {} - for _, logs in ipairs(all_logs) do - local lines = {} - table.insert(out, lines) - for _, log_line in ipairs(logs) do - table.insert(lines, vim.json.decode(log_line)) - end + local out = {} + for _, logs in ipairs(all_logs) do + local lines = {} + table.insert(out, lines) + for _, log_line in ipairs(logs) do + table.insert(lines, vim.json.decode(log_line)) end - return out + end + return out end describe("Logger", function() - after_each(function() - Logger.reset() - now = 0 - Logger.set_max_cached_requests(2) - end) + after_each(function() + Logger.reset() + now = 0 + Logger.set_max_cached_requests(2) + end) - it("no caching of non ID'd logs. Global logs", function() - eq({}, Logger.logs()) + it("no caching of non ID'd logs. Global logs", function() + eq({}, Logger.logs()) - local ok = pcall(Logger.debug, Logger, "test log") - eq({}, Logger.logs()) - eq(ok, false) - end) + local ok = pcall(Logger.debug, Logger, "test log") + eq({}, Logger.logs()) + eq(ok, false) + end) - it("cache logs, keep max count", function() - eq({}, Logger.logs()) - local logger = Logger:set_id(69) + it("cache logs, keep max count", function() + eq({}, Logger.logs()) + local logger = Logger:set_id(69) - logger:debug("test log") + logger:debug("test log") - eq({ - { - { level = "DEBUG", id = 69, msg = "test log" }, - }, - }, l(Logger.logs())) + eq({ + { + { level = "DEBUG", id = 69, msg = "test log" }, + }, + }, l(Logger.logs())) - local logger2 = logger:set_id(420) - now = 1000 - logger2:error("error log") + local logger2 = logger:set_id(420) + now = 1000 + logger2:error("error log") - eq({ - { - { level = "ERROR", id = 420, msg = "error log" }, - }, - { - { level = "DEBUG", id = 69, msg = "test log" }, - }, - }, l(Logger.logs())) + eq({ + { + { level = "ERROR", id = 420, msg = "error log" }, + }, + { + { level = "DEBUG", id = 69, msg = "test log" }, + }, + }, l(Logger.logs())) - now = 1001 - logger:warn("warn log") + now = 1001 + logger:warn("warn log") - eq({ - { - { level = "DEBUG", id = 69, msg = "test log" }, - { level = "WARN", id = 69, msg = "warn log" }, - }, - { - { level = "ERROR", id = 420, msg = "error log" }, - }, - }, l(Logger.logs())) + eq({ + { + { level = "DEBUG", id = 69, msg = "test log" }, + { level = "WARN", id = 69, msg = "warn log" }, + }, + { + { level = "ERROR", id = 420, msg = "error log" }, + }, + }, l(Logger.logs())) - local logger3 = logger:set_id(1337) - now = 1002 - logger3:info("info log") + local logger3 = logger:set_id(1337) + now = 1002 + logger3:info("info log") - eq({ - { - { level = "INFO", id = 1337, msg = "info log" }, - }, - { - { level = "DEBUG", id = 69, msg = "test log" }, - { level = "WARN", id = 69, msg = "warn log" }, - }, - }, l(Logger.logs())) - end) + eq({ + { + { level = "INFO", id = 1337, msg = "info log" }, + }, + { + { level = "DEBUG", id = 69, msg = "test log" }, + { level = "WARN", id = 69, msg = "warn log" }, + }, + }, l(Logger.logs())) + end) end) diff --git a/lua/99/test/marks_spec.lua b/lua/99/test/marks_spec.lua index 3c28248..0b810f0 100644 --- a/lua/99/test/marks_spec.lua +++ b/lua/99/test/marks_spec.lua @@ -7,147 +7,143 @@ local test_utils = require("99.test.test_utils") local eq = assert.are.same describe("Mark", function() - local buffer + local buffer - before_each(function() - buffer = test_utils.create_file({ - "function foo()", - " local x = 1", - " return x", - "end", - "", - "function bar()", - " return 42", - "end", - }, "lua", 1, 0) - end) + before_each(function() + buffer = test_utils.create_file({ + "function foo()", + " local x = 1", + " return x", + "end", + "", + "function bar()", + " return 42", + "end", + }, "lua", 1, 0) + end) - after_each(function() - test_utils.clean_files() - end) + after_each(function() + test_utils.clean_files() + end) - it("should create a mark at a specific point", function() - local point = Point:new(2, 3) - local mark = Mark.mark_point(buffer, point) - local mark_point = Point.from_mark(mark) + it("should create a mark at a specific point", function() + local point = Point:new(2, 3) + local mark = Mark.mark_point(buffer, point) + local mark_point = Point.from_mark(mark) - eq(point, mark_point) + eq(point, mark_point) - mark:delete() - end) + mark:delete() + end) - it("marks range", function() - local start_point = Point:new(2, 3) - local end_point = Point:new(3, 10) - local range = Range:new(buffer, start_point, end_point) - local mark_start, mark_end = Mark.mark_range(range) - local actual_start = Point.from_mark(mark_start) - local actual_end = Point.from_mark(mark_end) + it("marks range", function() + local start_point = Point:new(2, 3) + local end_point = Point:new(3, 10) + local range = Range:new(buffer, start_point, end_point) + local mark_start, mark_end = Mark.mark_range(range) + local actual_start = Point.from_mark(mark_start) + local actual_end = Point.from_mark(mark_end) - eq(start_point, actual_start) - eq(end_point, actual_end) + eq(start_point, actual_start) + eq(end_point, actual_end) - mark_start:delete() - mark_end:delete() - end) + mark_start:delete() + mark_end:delete() + end) - it("should handle single-line ranges", function() - local start_point = Point:new(2, 3) - local end_point = Point:new(2, 10) - local range = Range:new(buffer, start_point, end_point) - local mark_start, mark_end = Mark.mark_range(range) - local actual_start = Point.from_mark(mark_start) - local actual_end = Point.from_mark(mark_end) + it("should handle single-line ranges", function() + local start_point = Point:new(2, 3) + local end_point = Point:new(2, 10) + local range = Range:new(buffer, start_point, end_point) + local mark_start, mark_end = Mark.mark_range(range) + local actual_start = Point.from_mark(mark_start) + local actual_end = Point.from_mark(mark_end) - eq(start_point, actual_start) - eq(end_point, actual_end) + eq(start_point, actual_start) + eq(end_point, actual_end) - mark_start:delete() - mark_end:delete() - end) + mark_start:delete() + mark_end:delete() + end) - it("should create mark one line above the range start", function() - local above_point = Point:new(2, 14) - local start_point = Point:new(3, 5) - local end_point = Point:new(4, 3) - local range = Range:new(buffer, start_point, end_point) - local mark = Mark.mark_above_range(range) - local mark_point = Point.from_mark(mark) + it("should create mark one line above the range start", function() + local above_point = Point:new(2, 14) + local start_point = Point:new(3, 5) + local end_point = Point:new(4, 3) + local range = Range:new(buffer, start_point, end_point) + local mark = Mark.mark_above_range(range) + local mark_point = Point.from_mark(mark) - eq(above_point, mark_point) + eq(above_point, mark_point) - mark:delete() - end) + mark:delete() + end) - it("should create mark at beginning when range starts at line 1", function() - local start_point = Point:new(1, 5) - local end_point = Point:new(2, 3) - local range = Range:new(buffer, start_point, end_point) - local mark = Mark.mark_above_range(range) - local mark_point = Point.from_mark(mark) + it("should create mark at beginning when range starts at line 1", function() + local start_point = Point:new(1, 5) + local end_point = Point:new(2, 3) + local range = Range:new(buffer, start_point, end_point) + local mark = Mark.mark_above_range(range) + local mark_point = Point.from_mark(mark) - local beginning_point = Point:new(1, 1) - eq(beginning_point, mark_point) - mark:delete() - end) + local beginning_point = Point:new(1, 1) + eq(beginning_point, mark_point) + mark:delete() + end) - it("should create mark at the end of the range", function() - local start_point = Point:new(2, 3) - local end_point = Point:new(3, 8) - local range = Range:new(buffer, start_point, end_point) - local mark = Mark.mark_end_of_range(buffer, range) - local mark_point = Point.from_mark(mark) - local expected_end_point = end_point:add(Point:new(0, 1)) + it("should create mark at the end of the range", function() + local start_point = Point:new(2, 3) + local end_point = Point:new(3, 8) + local range = Range:new(buffer, start_point, end_point) + local mark = Mark.mark_end_of_range(buffer, range) + local mark_point = Point.from_mark(mark) + local expected_end_point = end_point:add(Point:new(0, 1)) - eq(expected_end_point, mark_point) + eq(expected_end_point, mark_point) - mark:delete() - end) + mark:delete() + end) - it("should create mark above a function", function() - local func_start = Point:new(6, 1) - local func_end = Point:new(8, 4) - local func_range = Range:new(buffer, func_start, func_end) - local mock_func = { - function_range = func_range, - } + it("should create mark above a function", function() + local func_start = Point:new(6, 1) + local func_end = Point:new(8, 4) + local func_range = Range:new(buffer, func_start, func_end) + local mock_func = { + function_range = func_range, + } - local mark = Mark.mark_above_func(buffer, mock_func) - local mark_point = Point.from_mark(mark) - local expected_mark_point = func_start:sub(Point:new(1, 0)) + local mark = Mark.mark_above_func(buffer, mock_func) + local mark_point = Point.from_mark(mark) + local expected_mark_point = func_start:sub(Point:new(1, 0)) - eq(expected_mark_point, mark_point) - mark:delete() - end) + eq(expected_mark_point, mark_point) + mark:delete() + end) - it("should create mark at function body start", function() - local func_start = Point:new(6, 1) - local func_end = Point:new(8, 4) - local func_range = Range:new(buffer, func_start, func_end) - local mock_func = { - function_range = func_range, - } - local mark = Mark.mark_func_body(buffer, mock_func) - local mark_point = Point.from_mark(mark) + it("should create mark at function body start", function() + local func_start = Point:new(6, 1) + local func_end = Point:new(8, 4) + local func_range = Range:new(buffer, func_start, func_end) + local mock_func = { + function_range = func_range, + } + local mark = Mark.mark_func_body(buffer, mock_func) + local mark_point = Point.from_mark(mark) - eq(func_start, mark_point) + eq(func_start, mark_point) - mark:delete() - end) + mark:delete() + end) - it("should delete the extmark", function() - local point = Point:new(2, 3) - local mark = Mark.mark_point(buffer, point) - local mark_pos = Point.from_mark(mark) - eq(point, mark_pos) - mark:delete() + it("should delete the extmark", function() + local point = Point:new(2, 3) + local mark = Mark.mark_point(buffer, point) + local mark_pos = Point.from_mark(mark) + eq(point, mark_pos) + mark:delete() - local deleted_pos = vim.api.nvim_buf_get_extmark_by_id( - mark.buffer, - mark.nsid, - mark.id, - {} - ) - eq(0, #deleted_pos) - end) + local deleted_pos = + vim.api.nvim_buf_get_extmark_by_id(mark.buffer, mark.nsid, mark.id, {}) + eq(0, #deleted_pos) + end) end) diff --git a/lua/99/test/request_status_spec.lua b/lua/99/test/request_status_spec.lua index fadae04..90bb145 100644 --- a/lua/99/test/request_status_spec.lua +++ b/lua/99/test/request_status_spec.lua @@ -3,17 +3,17 @@ local eq = assert.are.same local RequestStatus = require("99.ops.request_status") describe("request_status", function() - it("setting lines and status line", function() - local status = RequestStatus.new(2000000, 3, "TITLE") - eq({ "⠙ TITLE" }, status:get()) + it("setting lines and status line", function() + local status = RequestStatus.new(2000000, 3, "TITLE") + eq({ "⠙ TITLE" }, status:get()) - status:push("foo") - status:push("bar") + status:push("foo") + status:push("bar") - eq({ "⠙ TITLE", "foo", "bar" }, status:get()) + eq({ "⠙ TITLE", "foo", "bar" }, status:get()) - status:push("baz") + status:push("baz") - eq({ "⠙ TITLE", "bar", "baz" }, status:get()) - end) + eq({ "⠙ TITLE", "bar", "baz" }, status:get()) + end) end) diff --git a/lua/99/test/test_utils.lua b/lua/99/test/test_utils.lua index 29b7344..2f60b86 100644 --- a/lua/99/test/test_utils.lua +++ b/lua/99/test/test_utils.lua @@ -1,14 +1,14 @@ local M = {} function M.next_frame() - local next = false - vim.schedule(function() - next = true - end) + local next = false + vim.schedule(function() + next = true + end) - vim.wait(1000, function() - return next - end) + vim.wait(1000, function() + return next + end) end M.created_files = {} @@ -25,64 +25,64 @@ local TestProvider = {} TestProvider.__index = TestProvider function TestProvider.new() - return setmetatable({}, TestProvider) + return setmetatable({}, TestProvider) end --- @param query string ---@param request _99.Request ---@param observer _99.ProviderObserver? function TestProvider:make_request(query, request, observer) - local logger = request.context.logger:set_area("TestProvider") - logger:debug("make_request", "tmp_file", request.context.tmp_file) - self.request = { - query = query, - request = request, - observer = observer, - logger = logger, - } + local logger = request.context.logger:set_area("TestProvider") + logger:debug("make_request", "tmp_file", request.context.tmp_file) + self.request = { + query = query, + request = request, + observer = observer, + logger = logger, + } end --- @param status _99.Request.ResponseState --- @param result string function TestProvider:resolve(status, result) - assert(self.request, "you cannot call resolve until make_request is called") - local obs = self.request.observer - if obs then - --- to match the behavior expected from the OpenCodeProvider - if self.request.request:is_cancelled() then - obs.on_complete("cancelled", result) - else - obs.on_complete(status, result) - end + assert(self.request, "you cannot call resolve until make_request is called") + local obs = self.request.observer + if obs then + --- to match the behavior expected from the OpenCodeProvider + if self.request.request:is_cancelled() then + obs.on_complete("cancelled", result) + else + obs.on_complete(status, result) end - self.request = nil + end + self.request = nil end --- @param line string function TestProvider:stdout(line) - assert(self.request, "you cannot call stdout until make_request is called") - local obs = self.request.observer - if obs then - obs.on_stdout(line) - end + assert(self.request, "you cannot call stdout until make_request is called") + local obs = self.request.observer + if obs then + obs.on_stdout(line) + end end --- @param line string function TestProvider:stderr(line) - assert(self.request, "you cannot call stderr until make_request is called") - local obs = self.request.observer - if obs then - obs.on_stderr(line) - end + assert(self.request, "you cannot call stderr until make_request is called") + local obs = self.request.observer + if obs then + obs.on_stderr(line) + end end M.TestProvider = TestProvider function M.clean_files() - for _, bufnr in ipairs(M.created_files) do - vim.api.nvim_buf_delete(bufnr, { force = true }) - end - M.created_files = {} + for _, bufnr in ipairs(M.created_files) do + vim.api.nvim_buf_delete(bufnr, { force = true }) + end + M.created_files = {} end ---@param contents string[] @@ -90,17 +90,17 @@ end ---@param row number? ---@param col number? function M.create_file(contents, file_type, row, col) - assert(type(contents) == "table", "contents must be a table of strings") - file_type = file_type or "lua" - local bufnr = vim.api.nvim_create_buf(false, false) + assert(type(contents) == "table", "contents must be a table of strings") + file_type = file_type or "lua" + local bufnr = vim.api.nvim_create_buf(false, false) - vim.api.nvim_set_current_buf(bufnr) - vim.bo[bufnr].ft = file_type - vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, contents) - vim.api.nvim_win_set_cursor(0, { row or 1, col or 0 }) + vim.api.nvim_set_current_buf(bufnr) + vim.bo[bufnr].ft = file_type + vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, contents) + vim.api.nvim_win_set_cursor(0, { row or 1, col or 0 }) - table.insert(M.created_files, bufnr) - return bufnr + table.insert(M.created_files, bufnr) + return bufnr end return M diff --git a/lua/99/test/visual_spec.lua b/lua/99/test/visual_spec.lua index bb02567..c4b1394 100644 --- a/lua/99/test/visual_spec.lua +++ b/lua/99/test/visual_spec.lua @@ -13,159 +13,156 @@ local Point = require("99.geo").Point --- @param end_col number --- @return _99.test.Provider, number, _99.Range local function setup(content, start_row, start_col, end_row, end_col) - local p = test_utils.TestProvider.new() - _99.setup({ - provider = p, - logger = { - error_cache_level = Levels.ERROR, - }, - }) + local p = test_utils.TestProvider.new() + _99.setup({ + provider = p, + logger = { + error_cache_level = Levels.ERROR, + }, + }) - local buffer = test_utils.create_file(content, "lua", start_row, start_col) + local buffer = test_utils.create_file(content, "lua", start_row, start_col) - -- Create a range for the visual selection - local start_point = Point:new(start_row, start_col) - local end_point = Point:new(end_row, end_col) - local range = Range:new(buffer, start_point, end_point) + -- Create a range for the visual selection + local start_point = Point:new(start_row, start_col) + local end_point = Point:new(end_row, end_col) + local range = Range:new(buffer, start_point, end_point) - return p, buffer, range + return p, buffer, range end --- @param buffer number --- @return string[] local function r(buffer) - return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) + return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) end local content = { - "local function foo()", - " -- TODO: implement", - "end", + "local function foo()", + " -- TODO: implement", + "end", } describe("visual", function() - it("should replace visual selection with AI response", function() - local p, buffer, range = setup(content, 2, 1, 2, 23) - local state = _99.__get_state() - local visual_fn = require("99.ops.over-range") + it("should replace visual selection with AI response", function() + local p, buffer, range = setup(content, 2, 1, 2, 23) + local state = _99.__get_state() + local visual_fn = require("99.ops.over-range") - local context = - require("99.request-context").from_current_buffer(state, 100) - visual_fn(context, range) + local context = + require("99.request-context").from_current_buffer(state, 100) + visual_fn(context, range) - eq(1, state:active_request_count()) - eq(content, r(buffer)) + eq(1, state:active_request_count()) + eq(content, r(buffer)) - p:resolve("success", " return 'implemented!'") - test_utils.next_frame() + p:resolve("success", " return 'implemented!'") + test_utils.next_frame() - local expected_state = { - "local function foo()", - " return 'implemented!'", - "end", - } - eq(expected_state, r(buffer)) - -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision - end) + local expected_state = { + "local function foo()", + " return 'implemented!'", + "end", + } + eq(expected_state, r(buffer)) + -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision + end) - it("should handle multi-line replacement", function() - local multi_line_content = { - "local function bar()", - " -- TODO: implement", - " -- more comments", - " -- even more", - "end", - } - local p, buffer, range = setup(multi_line_content, 2, 1, 4, 17) - local state = _99.__get_state() - local visual_fn = require("99.ops.over-range") + it("should handle multi-line replacement", function() + local multi_line_content = { + "local function bar()", + " -- TODO: implement", + " -- more comments", + " -- even more", + "end", + } + local p, buffer, range = setup(multi_line_content, 2, 1, 4, 17) + local state = _99.__get_state() + local visual_fn = require("99.ops.over-range") - local context = - require("99.request-context").from_current_buffer(state, 200) - visual_fn(context, range) + local context = + require("99.request-context").from_current_buffer(state, 200) + visual_fn(context, range) - eq(1, state:active_request_count()) - eq(multi_line_content, r(buffer)) + eq(1, state:active_request_count()) + eq(multi_line_content, r(buffer)) - p:resolve( - "success", - " local x = 1\n local y = 2\n return x + y" - ) - test_utils.next_frame() + p:resolve("success", " local x = 1\n local y = 2\n return x + y") + test_utils.next_frame() - local expected_state = { - "local function bar()", - " local x = 1", - " local y = 2", - " return x + y", - "end", - } - eq(expected_state, r(buffer)) - -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision - end) + local expected_state = { + "local function bar()", + " local x = 1", + " local y = 2", + " return x + y", + "end", + } + eq(expected_state, r(buffer)) + -- Note: Not checking active_request_count() == 0 due to logger bug with "id" key collision + end) - it("should cancel request when stop_all_requests is called", function() - local p, buffer, range = setup(content, 2, 1, 2, 23) - local visual_fn = require("99.ops.over-range") - local state = _99.__get_state() - local context = - require("99.request-context").from_current_buffer(state, 300) + it("should cancel request when stop_all_requests is called", function() + local p, buffer, range = setup(content, 2, 1, 2, 23) + local visual_fn = require("99.ops.over-range") + local state = _99.__get_state() + local context = + require("99.request-context").from_current_buffer(state, 300) - visual_fn(context, range) + visual_fn(context, range) - eq(content, r(buffer)) + eq(content, r(buffer)) - assert.is_false(p.request.request:is_cancelled()) - assert.is_not_nil(p.request) - assert.is_not_nil(p.request.request) + assert.is_false(p.request.request:is_cancelled()) + assert.is_not_nil(p.request) + assert.is_not_nil(p.request.request) - _99.stop_all_requests() - test_utils.next_frame() + _99.stop_all_requests() + test_utils.next_frame() - assert.is_true(p.request.request:is_cancelled()) + assert.is_true(p.request.request:is_cancelled()) - p:resolve("success", " return 'should not appear'") - test_utils.next_frame() + p:resolve("success", " return 'should not appear'") + test_utils.next_frame() - -- Buffer should remain unchanged after cancellation - eq(content, r(buffer)) - end) + -- Buffer should remain unchanged after cancellation + eq(content, r(buffer)) + end) - it("should handle error cases with graceful failures", function() - local p, buffer, range = setup(content, 2, 1, 2, 23) - local visual_fn = require("99.ops.over-range") - local state = _99.__get_state() - local context = - require("99.request-context").from_current_buffer(state, 400) + it("should handle error cases with graceful failures", function() + local p, buffer, range = setup(content, 2, 1, 2, 23) + local visual_fn = require("99.ops.over-range") + local state = _99.__get_state() + local context = + require("99.request-context").from_current_buffer(state, 400) - visual_fn(context, range) + visual_fn(context, range) - eq(content, r(buffer)) + eq(content, r(buffer)) - p:resolve("failed", "Something went wrong") - test_utils.next_frame() + p:resolve("failed", "Something went wrong") + test_utils.next_frame() - -- Buffer should remain unchanged on failure - eq(content, r(buffer)) - end) + -- Buffer should remain unchanged on failure + eq(content, r(buffer)) + end) - it("should handle cancelled status gracefully", function() - local p, buffer, range = setup(content, 2, 1, 2, 23) - local visual_fn = require("99.ops.over-range") - local state = _99.__get_state() - local context = - require("99.request-context").from_current_buffer(state, 500) + it("should handle cancelled status gracefully", function() + local p, buffer, range = setup(content, 2, 1, 2, 23) + local visual_fn = require("99.ops.over-range") + local state = _99.__get_state() + local context = + require("99.request-context").from_current_buffer(state, 500) - visual_fn(context, range) + visual_fn(context, range) - eq(content, r(buffer)) + eq(content, r(buffer)) - -- Manually cancel and resolve as cancelled - p.request.request:cancel() - p:resolve("cancelled", "Request was cancelled") - test_utils.next_frame() + -- Manually cancel and resolve as cancelled + p.request.request:cancel() + p:resolve("cancelled", "Request was cancelled") + test_utils.next_frame() - -- Buffer should remain unchanged on cancellation - eq(content, r(buffer)) - end) + -- Buffer should remain unchanged on cancellation + eq(content, r(buffer)) + end) end) diff --git a/lua/99/time.lua b/lua/99/time.lua index b668ea1..66e2fe2 100644 --- a/lua/99/time.lua +++ b/lua/99/time.lua @@ -1,7 +1,7 @@ local M = {} function M.now() - return vim.uv.now() + return vim.uv.now() end return M diff --git a/lua/99/utils.lua b/lua/99/utils.lua index 7e012a2..c308624 100644 --- a/lua/99/utils.lua +++ b/lua/99/utils.lua @@ -4,11 +4,11 @@ local M = {} --- to make the _99_state have the project directory. --- @return string function M.random_file() - return string.format( - "%s/tmp/99-%d", - vim.uv.cwd(), - math.floor(math.random() * 10000) - ) + return string.format( + "%s/tmp/99-%d", + vim.uv.cwd(), + math.floor(math.random() * 10000) + ) end return M diff --git a/lua/99/window/init.lua b/lua/99/window/init.lua index 85f9da5..3475d5a 100644 --- a/lua/99/window/init.lua +++ b/lua/99/window/init.lua @@ -1,9 +1,10 @@ --- @class _99.window.Module --- @field active_windows _99.window.Window[] local M = { - active_windows = {}, + active_windows = {}, } local nsid = vim.api.nvim_create_namespace("99.window.error") +local nvim_win_is_valid = vim.api.nvim_win_is_valid --- @class _99.window.Config --- @field width number @@ -20,285 +21,329 @@ local nsid = vim.api.nvim_create_namespace("99.window.error") --- @param lines string[] --- @return string[] local function ensure_no_new_lines(lines) - local display_lines = {} - for _, line in ipairs(lines) do - local split_lines = vim.split(line, "\n") - for _, clean_line in ipairs(split_lines) do - table.insert(display_lines, clean_line) - end + local display_lines = {} + for _, line in ipairs(lines) do + local split_lines = vim.split(line, "\n") + for _, clean_line in ipairs(split_lines) do + table.insert(display_lines, clean_line) end - return display_lines + end + return display_lines end --- @return number --- @return number local function get_ui_dimensions() - local ui = vim.api.nvim_list_uis()[1] - return ui.width, ui.height + local ui = vim.api.nvim_list_uis()[1] + return ui.width, ui.height end --- @return _99.window.Config local function create_window_top_config() - local width, _ = get_ui_dimensions() - return { - width = width - 2, - height = 3, - anchor = "NE", - } + local width, _ = get_ui_dimensions() + return { + width = width - 2, + height = 3, + anchor = "NE", + } end --- @return _99.window.Config local function create_window_top_left_config() - local width, _ = get_ui_dimensions() - return { - width = math.floor(width / 3), - height = 3, - anchor = "NE", - } + local width, _ = get_ui_dimensions() + return { + width = math.floor(width / 3), + height = 3, + anchor = "NE", + } end --- @return _99.window.Config local function create_window_full_screen() - local width, height = get_ui_dimensions() - return { - width = width - 2, - height = height - 2, - anchor = "NE", - } + local width, height = get_ui_dimensions() + return { + width = width - 2, + height = height - 2, + anchor = "NE", + } +end + +--- @param config _99.window.Config +---@param offset_bottom number | nil +--- @return _99.window.Config +local function create_window_inside(config, offset_bottom) + offset_bottom = offset_bottom or 0 + return { + width = config.width - 2, + height = 1, + row = config.row + config.height - offset_bottom, + col = config.col + 1, + anchor = config.anchor, + } end --- @return _99.window.Config local function create_centered_window() - local width, height = get_ui_dimensions() - local win_width = math.floor(width * 2 / 3) - local win_height = math.floor(height / 3) - return { - width = win_width, - height = win_height, - row = math.floor((height - win_height) / 2), - col = math.floor((width - win_width) / 2), - } + local width, height = get_ui_dimensions() + local win_width = math.floor(width * 2 / 3) + local win_height = math.floor(height / 3) + return { + width = win_width, + height = win_height, + row = math.floor((height - win_height) / 2), + col = math.floor((width - win_width) / 2), + } end --- @param config _99.window.Config --- @param win_config vim.api.keyset.win_config --- @return _99.window.Window local function create_floating_window(config, win_config) - local buf_id = vim.api.nvim_create_buf(false, true) - local win_id = vim.api.nvim_open_win(buf_id, true, { - relative = "editor", - width = config.width, - height = config.height, - row = config.row or 0, - col = config.col or 0, - anchor = config.anchor, - style = "minimal", - border = win_config.border, - title = win_config.title, - title_pos = "center", - }) - local window = { - config = config, - win_id = win_id, - buf_id = buf_id, - } - vim.wo[win_id].wrap = true + local buf_id = vim.api.nvim_create_buf(false, true) + local win_id = vim.api.nvim_open_win(buf_id, true, { + relative = "editor", + width = config.width, + height = config.height, + row = config.row or 0, + col = config.col or 0, + anchor = config.anchor, + style = "minimal", + border = win_config.border, + title = win_config.title, + title_pos = "center", + zindex = win_config.zindex, + }) + local window = { + config = config, + win_id = win_id, + buf_id = buf_id, + } + vim.wo[win_id].wrap = true - table.insert(M.active_windows, window) - return window + table.insert(M.active_windows, window) + return window end --- @param window _99.window.Window local function highlight_error(window) - local line_count = vim.api.nvim_buf_line_count(window.buf_id) + local line_count = vim.api.nvim_buf_line_count(window.buf_id) - if line_count > 0 then - vim.api.nvim_buf_set_extmark(window.buf_id, nsid, 0, 0, { - end_row = 1, - hl_group = "Normal", - hl_eol = true, - }) - end + if line_count > 0 then + vim.api.nvim_buf_set_extmark(window.buf_id, nsid, 0, 0, { + end_row = 1, + hl_group = "Normal", + hl_eol = true, + }) + end - if line_count > 1 then - vim.api.nvim_buf_set_extmark(window.buf_id, nsid, 1, 0, { - end_row = line_count, - hl_group = "ErrorMsg", - hl_eol = true, - }) - end + if line_count > 1 then + vim.api.nvim_buf_set_extmark(window.buf_id, nsid, 1, 0, { + end_row = line_count, + hl_group = "ErrorMsg", + hl_eol = true, + }) + end end --- @param error_text string --- @return _99.window.Window function M.display_error(error_text) - local window = create_floating_window(create_window_top_config(), { - title = " 99 Error ", - border = "rounded", - }) - local lines = vim.split(error_text, "\n") + local window = create_floating_window(create_window_top_config(), { + title = " 99 Error ", + border = "rounded", + }) + local lines = vim.split(error_text, "\n") - table.insert(lines, 1, "") - table.insert( - lines, - 1, - "99: Fatal operational error encountered (error logs may have more in-depth information)" - ) + table.insert(lines, 1, "") + table.insert( + lines, + 1, + "99: Fatal operational error encountered (error logs may have more in-depth information)" + ) - vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, lines) - highlight_error(window) - return window + vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, lines) + highlight_error(window) + return window end --- @param window _99.window.Window local function window_close(window) - if vim.api.nvim_win_is_valid(window.win_id) then - vim.api.nvim_win_close(window.win_id, true) - end - if vim.api.nvim_buf_is_valid(window.buf_id) then - vim.api.nvim_buf_delete(window.buf_id, { force = true }) - end - - local found = false - for i, w in ipairs(M.active_windows) do - if w.buf_id == window.buf_id and w.win_id == window.win_id then - found = true - table.remove(M.active_windows, i) - break - end - end - - assert( - found, - "somehow we have closed a window that did not belong to the windows library" - ) + if nvim_win_is_valid(window.win_id) then + vim.api.nvim_win_close(window.win_id, true) + end + if vim.api.nvim_buf_is_valid(window.buf_id) then + vim.api.nvim_buf_delete(window.buf_id, { force = true }) + end end --- @param text string function M.display_cancellation_message(text) - local config = create_window_top_left_config() - local window = create_floating_window(config, { - title = " 99 Cancelled ", - border = "rounded", - }) - local lines = vim.split(text, "\n") + local config = create_window_top_left_config() + local window = create_floating_window(config, { + title = " 99 Cancelled ", + border = "rounded", + }) + local lines = vim.split(text, "\n") - vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, lines) + vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, lines) - vim.api.nvim_buf_set_extmark(window.buf_id, nsid, 0, 0, { - end_row = vim.api.nvim_buf_line_count(window.buf_id), - hl_group = "WarningMsg", - hl_eol = true, - }) + vim.api.nvim_buf_set_extmark(window.buf_id, nsid, 0, 0, { + end_row = vim.api.nvim_buf_line_count(window.buf_id), + hl_group = "WarningMsg", + hl_eol = true, + }) - vim.defer_fn(function() - window_close(window) - end, 5000) + vim.defer_fn(function() + if nvim_win_is_valid(window.win_id) then + M.clear_active_popups() + end + end, 5000) - return window + return window end --- TODO: i dont like how the other interfaces have text being passed in --- but this one is lines. probably need to revisit this --- @param lines string[] function M.display_full_screen_message(lines) - --- TODO: i really dislike that i am closing and opening windows - --- i think it would be better to perserve the one that is already open - --- but i just want this to work and then later... ohh much later, ill fix - --- this basic nonsense - M.clear_active_popups() - local window = create_floating_window(create_window_full_screen(), { - title = " 99 ", - border = "rounded", - }) - local display_lines = ensure_no_new_lines(lines) - vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, display_lines) + --- TODO: i really dislike that i am closing and opening windows + --- i think it would be better to perserve the one that is already open + --- but i just want this to work and then later... ohh much later, ill fix + --- this basic nonsense + M.clear_active_popups() + local window = create_floating_window(create_window_full_screen(), { + title = " 99 ", + border = "rounded", + }) + local display_lines = ensure_no_new_lines(lines) + vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, display_lines) end --- @return _99.window.Window --- @return _99.window.Config function M.create_centered_window() - M.clear_active_popups() - local config = create_centered_window() - local window = create_floating_window(config, { - title = " 99 ", - border = "rounded", - }) - return window, config + M.clear_active_popups() + local config = create_centered_window() + local window = create_floating_window(config, { + title = " 99 ", + border = "rounded", + }) + return window, config end --- @param message string[] function M.display_centered_message(message) - M.clear_active_popups() - local config = create_centered_window() - local window = create_floating_window(config, { - title = " 99 ", - border = "rounded", - }) - local display_lines = ensure_no_new_lines(message) + M.clear_active_popups() + local config = create_centered_window() + local window = create_floating_window(config, { + title = " 99 ", + border = "rounded", + }) + local display_lines = ensure_no_new_lines(message) - vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, display_lines) + vim.api.nvim_buf_set_lines(window.buf_id, 0, -1, false, display_lines) - return window + return window end ---- @param cb fun(success: boolean, result: string): nil ---- @param opts {} -function M.capture_input(cb, opts) - _ = opts +--- @param win _99.window.Window +--- @param name string +local function set_defaul_win_options(win, name) + vim.api.nvim_buf_set_name(win.buf_id, name) + vim.wo[win.win_id].number = true + vim.bo[win.buf_id].filetype = "99" + vim.bo[win.buf_id].buftype = "acwrite" + vim.bo[win.buf_id].bufhidden = "wipe" + vim.bo[win.buf_id].swapfile = false +end - --- TODO: styling should be extendable - M.clear_active_popups() - local config = create_centered_window() - local win = create_floating_window(config, { - title = " 99 Prompt ", - border = "rounded", - }) +--- @class _99.window.CaptureInputOpts +--- @field cb fun(success: boolean, result: string): nil +--- @field on_load? fun(): nil - vim.api.nvim_buf_set_name(win.buf_id, "99-prompt") - vim.wo[win.win_id].number = true - vim.bo[win.buf_id].filetype = "99" - vim.bo[win.buf_id].buftype = "acwrite" - vim.bo[win.buf_id].bufhidden = "wipe" - vim.bo[win.buf_id].swapfile = false +--- @param opts _99.window.CaptureInputOpts +function M.capture_input(opts) + _ = opts + M.clear_active_popups() - local group = vim.api.nvim_create_augroup( - "99_present_prompt_" .. win.buf_id, - { clear = true } - ) + local config = create_centered_window() + local win = create_floating_window(config, { + title = " 99 Prompt ", + border = "rounded", + }) - vim.api.nvim_create_autocmd("BufWriteCmd", { - group = group, - buffer = win.buf_id, - callback = function() - local lines = vim.api.nvim_buf_get_lines(win.buf_id, 0, -1, false) - local result = table.concat(lines, "\n") - M.clear_active_popups() - cb(true, result) - end, - }) + set_defaul_win_options(win, "99-prompt") + vim.api.nvim_set_current_win(win.win_id) - vim.api.nvim_create_autocmd("BufUnload", { - group = group, - buffer = win.buf_id, - callback = function() - vim.api.nvim_del_augroup_by_id(group) - end, - }) + local group = vim.api.nvim_create_augroup( + "99_present_prompt_" .. win.buf_id, + { clear = true } + ) - vim.keymap.set("n", "q", function() + vim.api.nvim_create_autocmd("BufLeave", { + group = group, + buffer = win.buf_id, + callback = function() + if nvim_win_is_valid(win.win_id) then + vim.api.nvim_set_current_win(win.win_id) + else M.clear_active_popups() - cb(false, "") - end, { buffer = win.buf_id, nowait = true }) + end + end, + }) + + vim.api.nvim_create_autocmd("BufWriteCmd", { + group = group, + buffer = win.buf_id, + callback = function() + if not nvim_win_is_valid(win.win_id) then + return + end + local lines = vim.api.nvim_buf_get_lines(win.buf_id, 0, -1, false) + local result = table.concat(lines, "\n") + M.clear_active_popups() + opts.cb(true, result) + end, + }) + + vim.api.nvim_create_autocmd("BufUnload", { + group = group, + buffer = win.buf_id, + callback = function() + if not nvim_win_is_valid(win.win_id) then + return + end + vim.api.nvim_del_augroup_by_id(group) + end, + }) + + vim.api.nvim_create_autocmd("WinClosed", { + group = group, + pattern = tostring(win.win_id), + callback = function() + if not nvim_win_is_valid(win.win_id) then + return + end + M.clear_active_popups() + opts.cb(false, "") + end, + }) + + vim.keymap.set("n", "q", function() + M.clear_active_popups() + opts.cb(false, "") + end, { buffer = win.buf_id, nowait = true }) + + if opts.on_load then + vim.schedule(opts.on_load) + end end ---- not worried about perf, we will likely only ever have 1 maybe 2 windows ---- ever open at the same time function M.clear_active_popups() - while #M.active_windows > 0 do - local window = M.active_windows[1] - window_close(window) - end + for _, window in ipairs(M.active_windows) do + window_close(window) + end + M.active_windows = {} end return M diff --git a/scratch/cursor/rules/database.mdc b/scratch/cursor/rules/database.mdc new file mode 100644 index 0000000..842a6fd --- /dev/null +++ b/scratch/cursor/rules/database.mdc @@ -0,0 +1,2 @@ +another file about my db + diff --git a/scratch/cursor/rules/my-proj.mdc b/scratch/cursor/rules/my-proj.mdc new file mode 100644 index 0000000..51fa0bd --- /dev/null +++ b/scratch/cursor/rules/my-proj.mdc @@ -0,0 +1,4 @@ +--- +--- +Rules about my project + diff --git a/scratch/custom_rules/back-end.md b/scratch/custom_rules/back-end.md new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scratch/custom_rules/back-end.md diff --git a/scratch/custom_rules/foo.md b/scratch/custom_rules/foo.md new file mode 100644 index 0000000..139597f --- /dev/null +++ b/scratch/custom_rules/foo.md @@ -0,0 +1,2 @@ + + diff --git a/scratch/custom_rules/front-end.md b/scratch/custom_rules/front-end.md new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scratch/custom_rules/front-end.md diff --git a/scratch/custom_rules/vim.lsp.md b/scratch/custom_rules/vim.lsp.md new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scratch/custom_rules/vim.lsp.md diff --git a/scratch/custom_rules/vim.md b/scratch/custom_rules/vim.md new file mode 100644 index 0000000..f3a6435 --- /dev/null +++ b/scratch/custom_rules/vim.md @@ -0,0 +1,2 @@ +whenever you write any lua code that is meant to be executed in vim +add 1 print statement per line of code describing the next code diff --git a/scratch/custom_rules/vim.treesitter.md b/scratch/custom_rules/vim.treesitter.md new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scratch/custom_rules/vim.treesitter.md diff --git a/scratch/refresh.lua b/scratch/refresh.lua index 2d7b27b..aba679b 100644 --- a/scratch/refresh.lua +++ b/scratch/refresh.lua @@ -1,15 +1,43 @@ R("99") -local foo_bar = {fizz = 3} +local _99 = require("99") +_99.setup({ + completion = { + custom_rules = { + "~/.behaviors/", + }, + source = "cmp", + }, +}) +local Ext = require("99.extensions") +local Agents = require("99.extensions.agents") +local Helpers = require("99.extensions.agents.helpers") -function fizz_buzz(count) - local result = {} - for i = 1, count do - end - return result -end +print(vim.inspect(Agents.rules(_99.__get_state()))) +print(vim.inspect(Helpers.ls("/home/theprimeagen/.behaviors"))) + +--- @class Config +--- @field width number +--- @field height number +--- @field offset_row number +--- @field offset_col number +--- @field border string +function create_window(config) + -- Create a new buffer + local buf = vim.api.nvim_create_buf(false, true) + + -- Configure the floating window + local win_config = { + relative = 'editor', + width = config.width, + height = config.height, + row = config.offset_row, + col = config.offset_col, + style = 'minimal', + border = 'rounded' + } + + -- Open the floating window + local win = vim.api.nvim_open_win(buf, true, win_config) ---- @param numbers number[] -function sort(numbers) - table.sort(numbers) - return numbers + return { buf = buf, win = win } end diff --git a/scripts/ci/install_treesitter_parsers.lua b/scripts/ci/install_treesitter_parsers.lua new file mode 100644 index 0000000..6393fd8 --- /dev/null +++ b/scripts/ci/install_treesitter_parsers.lua @@ -0,0 +1,72 @@ +local install_dir = vim.fn.stdpath("data") .. "/site" +local ok_setup, setup_err = pcall(function() + require("nvim-treesitter").setup({ + install_dir = install_dir, + }) +end) + +if not ok_setup then + vim.api.nvim_echo( + { { "Error: " .. tostring(setup_err), "ErrorMsg" } }, + true, + {} + ) + vim.cmd("cq") +end + +local ok_install, install_err = pcall(function() + require("nvim-treesitter").install({ + "c", + "cpp", + "go", + "lua", + "php", + "python", + "typescript", + "javascript", + "java", + "ruby", + "tsx", + "c_sharp", + "vue", + }):wait(300000) +end) + +if not ok_install then + vim.api.nvim_echo({ + { "Error: " .. tostring(install_err), "ErrorMsg" }, + }, true, {}) + vim.cmd("cq") +end + +local required_parsers = { + c = "c.so", + cpp = "cpp.so", + go = "go.so", + lua = "lua.so", + php = "php.so", + python = "python.so", + typescript = "typescript.so", + javascript = "javascript.so", + java = "java.so", + ruby = "ruby.so", + tsx = "tsx.so", + c_sharp = "c_sharp.so", + vue = "vue.so", +} + +for lang, filename in pairs(required_parsers) do + local parser_path = install_dir .. "/parser/" .. filename + if not vim.uv.fs_stat(parser_path) then + vim.api.nvim_echo({ + { + "Error: " + .. lang + .. " parser missing after install: " + .. parser_path, + "ErrorMsg", + }, + }, true, {}) + vim.cmd("cq") + end +end diff --git a/scripts/tests/minimal.vim b/scripts/tests/minimal.vim index c267d9e..30093ba 100644 --- a/scripts/tests/minimal.vim +++ b/scripts/tests/minimal.vim @@ -21,18 +21,67 @@ runtime! plugin/plenary.vim runtime! plugin/nvim-treesitter.lua lua <<EOF -local required_parsers = { - 'c', 'cpp', 'go', 'lua', 'php', 'python', 'typescript', 'javascript', 'java', 'ruby', 'tsx', 'c_sharp', 'vue', 'elixir' +vim.opt.rtp:append(vim.fn.stdpath('data') .. '/site') + +-- parsers to attempt to install (for user convenience) +local all_parsers = { + 'c', 'cpp', 'go', 'lua', 'php', 'python', 'typescript', + 'javascript', 'java', 'ruby', 'tsx', 'c_sharp', 'vue', 'elixir' } -local installed_parsers = require'nvim-treesitter.info'.installed_parsers() -local to_install = vim.tbl_filter(function(parser) - return not vim.tbl_contains(installed_parsers, parser) -end, required_parsers) -if #to_install > 0 then + +-- parsers actually required for tests to run +local required_parsers = { 'lua', 'typescript' } + +local function missing_parsers(parsers) + local missing = {} + local buf = vim.api.nvim_create_buf(false, true) + for _, lang in ipairs(parsers) do + local ok = pcall(vim.treesitter.get_parser, buf, lang) + if not ok then + table.insert(missing, lang) + end + end + vim.api.nvim_buf_delete(buf, { force = true }) + return missing +end + +local function install_with_main_branch_api(parsers) + local install_dir = vim.fn.stdpath('data') .. '/site' + require('nvim-treesitter').setup({ install_dir = install_dir }) + require('nvim-treesitter').install(parsers):wait(300000) +end + +-- master branch is deprecated but still widely used +local function install_with_master_branch_api(parsers) -- fixes 'pos_delta >= 0' error - https://github.com/nvim-lua/plenary.nvim/issues/52 vim.cmd('set display=lastline') -- make "TSInstall*" available - vim.cmd 'runtime! plugin/nvim-treesitter.vim' - vim.cmd('TSInstallSync ' .. table.concat(to_install, ' ')) + vim.cmd('runtime! plugin/nvim-treesitter.vim') + vim.cmd('TSInstallSync ' .. table.concat(parsers, ' ')) +end + +local to_install = missing_parsers(all_parsers) +if #to_install > 0 then + -- Detect which nvim-treesitter API is available (main vs master branch) + local has_main_api, ts = pcall(require, 'nvim-treesitter') + has_main_api = has_main_api and type(ts.install) == 'function' + + if has_main_api then + local ok, err = pcall(install_with_main_branch_api, to_install) + if not ok then + print('Tree-sitter install error (main API): ' .. tostring(err)) + end + else + local ok, err = pcall(install_with_master_branch_api, to_install) + if not ok then + print('Tree-sitter install error (master API): ' .. tostring(err)) + end + end +end + +-- only error if required parsers are still missing +local still_missing = missing_parsers(required_parsers) +if #still_missing > 0 then + error('Missing required Tree-sitter parsers: ' .. table.concat(still_missing, ', ')) end EOF |
