diff options
| author | ThePrimeAgain <theprimeagain@theprimeagain.com> | 2025-12-06 11:12:20 -0700 |
|---|---|---|
| committer | ThePrimeAgain <theprimeagain@theprimeagain.com> | 2025-12-06 11:12:20 -0700 |
| commit | 045ac5e104765415337d691ec923afa1cbd032c8 (patch) | |
| tree | 3697940fcb62c920daecad66b7bc3ce30b6e444c /lua | |
| parent | 3ae40439f36baeb96979a77cb5faf0dd6890bf31 (diff) | |
| download | a4-045ac5e104765415337d691ec923afa1cbd032c8.tar.xz a4-045ac5e104765415337d691ec923afa1cbd032c8.zip | |
virtual text is running! no testings yet
Diffstat (limited to 'lua')
| -rw-r--r-- | lua/99/editor/location.lua | 18 | ||||
| -rw-r--r-- | lua/99/editor/lsp.lua | 6 | ||||
| -rw-r--r-- | lua/99/editor/treesitter.lua | 152 | ||||
| -rw-r--r-- | lua/99/geo.lua | 46 | ||||
| -rw-r--r-- | lua/99/language/init.lua | 7 | ||||
| -rw-r--r-- | lua/99/ops/fill-in-function.lua | 48 | ||||
| -rw-r--r-- | lua/99/ops/marks.lua | 78 | ||||
| -rw-r--r-- | lua/99/prompt_settings.lua | 2 | ||||
| -rw-r--r-- | lua/99/test/fill_in_function_spec.lua | 52 | ||||
| -rw-r--r-- | lua/99/test/integration_spec.lua | 53 | ||||
| -rw-r--r-- | lua/99/test/test_utils.lua | 2 |
11 files changed, 309 insertions, 155 deletions
diff --git a/lua/99/editor/location.lua b/lua/99/editor/location.lua index 8432443..b33c5e0 100644 --- a/lua/99/editor/location.lua +++ b/lua/99/editor/location.lua @@ -1,18 +1,16 @@ -local Range = require("99.geo").Range - --- @class _99.Location --- @field full_path string ---- @field range Range ---- @field node TSNode +--- @field range _99.Range +--- @field node _99.treesitter.Node --- @field buffer number --- @field file_type string ---- @field marks table<string, string> +--- @field marks table<string, _99.Mark> --- @field ns_id string local Location = {} Location.__index = Location ---- @param node TSNode ---- @param range Range +--- @param node _99.treesitter.Node +--- @param range _99.Range function Location.from_ts_node(node, range) local full_path = vim.api.nvim_buf_get_name(range.buffer) local file_type = vim.bo[range.buffer].ft @@ -30,4 +28,10 @@ function Location.from_ts_node(node, range) }, Location) end +function Location:clear_marks() + for _, mark in pairs(self.marks) do + mark:delete() + end +end + return Location diff --git a/lua/99/editor/lsp.lua b/lua/99/editor/lsp.lua index 6c7923a..9d7bceb 100644 --- a/lua/99/editor/lsp.lua +++ b/lua/99/editor/lsp.lua @@ -16,7 +16,7 @@ local logger = require("99.logger.logger") --- @field range LspRange --- @field uri string ---- @param node TSNode +--- @param node _99.treesitter.Node local function ts_node_to_lsp_position(node) local start_row, start_col, _, _ = node:range() -- Treesitter node range return { line = start_row, character = start_col } @@ -52,7 +52,7 @@ function Lsp:new(config) end --- @param buffer number ---- @param node TSNode[] +--- @param node _99.treesitter.Node[] --- @param cb fun(res: LspDefinitionResult | nil): nil function Lsp:get_ts_node_definition(buffer, node, cb) local range = ts_node_to_lsp_position(node) @@ -94,7 +94,7 @@ function Lsp:_filter_flatten(resultsList, buffer) end --- @param buffer number ---- @param nodes TSNode[] +--- @param nodes _99.treesitter.Node[] --- @param cb fun(res: LspDefinitionResult[]): nil function Lsp:batch_get_ts_node_definitions(buffer, nodes, cb) if #nodes == 0 then diff --git a/lua/99/editor/treesitter.lua b/lua/99/editor/treesitter.lua index 42162e5..ce6b7bc 100644 --- a/lua/99/editor/treesitter.lua +++ b/lua/99/editor/treesitter.lua @@ -1,13 +1,18 @@ local geo = require("99.geo") local Logger = require("99.logger.logger") local Range = geo.Range +local Mark = require("99.ops.marks") ---- @class TSNode ---- @field start fun(self: TSNode): number, number, number ---- @field end_ fun(self: TSNode): number, number, number ---- @field named fun(self: TSNode): boolean ---- @field type fun(self: TSNode): string ---- @field range fun(self: TSNode): number, number, number, number +--- @class _99.treesitter.TSNode +--- @field start fun(): number +--- @field end_ fun(): number + +--- @class _99.treesitter.Node +--- @field start fun(self: _99.treesitter.Node): number, number, number +--- @field end_ fun(self: _99.treesitter.Node): number, number, number +--- @field named fun(self: _99.treesitter.Node): boolean +--- @field type fun(self: _99.treesitter.Node): string +--- @field range fun(self: _99.treesitter.Node): number, number, number, number local M = {} @@ -29,8 +34,8 @@ local function tree_root(buffer, lang) end --- @param buffer number ---- @param cursor Point ---- @return TSNode | nil +--- @param cursor _99.Point +--- @return _99.treesitter.Node | nil function M.identifier(buffer, cursor) local lang = vim.bo[buffer].ft local root = tree_root(buffer, lang) @@ -81,17 +86,59 @@ function M.identifier(buffer, cursor) return found end ---- @class Scope ---- @field scope TSNode[] ---- @field range Range[] +--- @class _99.treesitter.Function +--- @field function_range _99.Range +--- @field function_node _99.treesitter.Node +--- @field body_range _99.Range +--- @field body_node _99.treesitter.Node +local Function = {} +Function.__index = Function + +--- @param ts_node _99.treesitter.TSNode +---@param lang string +---@param buffer number +---@param cursor _99.Point +---@return _99.treesitter.Function +function Function.from_ts_node(ts_node, lang, buffer, cursor) + local ok, query = pcall(vim.treesitter.query.get, lang, function_query) + if not ok or query == nil then + Logger:fatal("INVARIANT: not query or not ok") + error("failed") + end + + local func = { } + for id, node, _ in query:iter_captures(ts_node, buffer, 0, -1, { all = true }) do + local range = Range:from_ts_node(node, buffer) + local name = query.captures[id] + if range:contains(cursor) then + if name == "context.function" then + func.function_node = node + func.function_range = range + elseif name == "context.body" then + func.body_node = node + func.body_range = range + end + end + end + + -- Not all functions have bodies, example: function foo() end + assert(func.function_node ~= nil, "INVARIANT: function_node not found") + assert(func.function_range ~= nil, "INVARIANT: function_range not found") + + return setmetatable(func, Function) +end + +--- @class _99.Scope +--- @field scope _99.treesitter.Node[] +--- @field range _99.Range[] --- @field buffer number ---- @field cursor Point +--- @field cursor _99.Point local Scope = {} Scope.__index = Scope ---- @param cursor Point +--- @param cursor _99.Point --- @param buffer number ---- @return Scope +--- @return _99.Scope function Scope:new(cursor, buffer) return setmetatable({ scope = {}, @@ -106,17 +153,17 @@ function Scope:has_scope() return #self.range > 0 end ---- @return TSNode | nil +--- @return _99.treesitter.Node | nil function Scope:get_inner_scope() return self.scope[#self.scope] end ---- @return Range | nil +--- @return _99.Range | nil function Scope:get_inner_range() return self.range[#self.range] end ---- @param node TSNode +--- @param node _99.treesitter.Node function Scope:push(node) local range = Range:from_ts_node(node, self.buffer) if not range:contains(self.cursor) then @@ -134,10 +181,72 @@ function Scope:finalize() end) end ---- @param cursor Point +--- @param buffer number +--- @param cursor _99.Point +--- @return _99.treesitter.Function? +function M.containing_function(buffer, cursor) + local lang = vim.bo[buffer].ft + local root = tree_root(buffer, lang) + if not root then + Logger:debug("LSP: could not find tree root") + return nil + end + + local ok, query = pcall(vim.treesitter.query.get, lang, function_query) + if not ok or query == nil then + Logger:debug( + "LSP: not ok or query", + "query", + vim.inspect(query), + "lang", + lang, + "ok", + vim.inspect(ok) + ) + return nil + end + + --- @type _99.Range + local found_range = nil + local found_node = nil + for id, node, _ in query:iter_captures(root, buffer, 0, -1, { all = true }) do + local range = Range:from_ts_node(node, buffer) + local name = query.captures[id] + Logger:debug("containing_function#capture", "range", range:to_string(), "name", name) + if name == "context.function" and range:contains(cursor) then + Logger:debug(" containing_function#capture#found", "cursor", cursor:to_string(), "range", range:to_string()) + if not found_range then + found_range = range + found_node = node + elseif found_range:area() > range:area() then + found_range = range + found_node = node + end + end + Logger:debug("containing_function#capture finished loop", "found_range", found_range and found_range:to_string() or "found_range is nil") + end + + if not found_range then + return nil + end + assert(found_node, "INVARIANT: found_range is not nil but found node is") + + ok, query = pcall(vim.treesitter.query.get, lang, function_query) + if not ok or query == nil then + Logger:fatal("INVARIANT: found_range ", "range", found_range:to_text()) + return + end + + --- TODO: learn the diagnostics + --- @type _99.treesitter.Function + return Function.from_ts_node(found_node, lang, buffer, cursor) +end + +--- @param cursor _99.Point --- @param buffer number? ---- @return Scope -function M.function_scopes(cursor, buffer) +--- @return _99.Scope +function M.scopes(cursor, buffer) + Logger:fatal("M.scopes not implemented") buffer = buffer or vim.api.nvim_get_current_buf() local scope = Scope:new(cursor, buffer) @@ -164,7 +273,6 @@ function M.function_scopes(cursor, buffer) for id, node, _ in query:iter_captures(root, buffer, 0, -1, { all = true }) do local name = query.captures[id] - print("cursor query captures", "id", id, "name", name) if name == "context.scope" then scope:push(node) elseif name == "context.body" then @@ -177,7 +285,7 @@ function M.function_scopes(cursor, buffer) return scope end ---- @return TSNode[] +--- @return _99.treesitter.Node[] function M.imports() assert(false, "not implemented") local root = tree_root() diff --git a/lua/99/geo.lua b/lua/99/geo.lua index d5a804c..88b2545 100644 --- a/lua/99/geo.lua +++ b/lua/99/geo.lua @@ -1,6 +1,6 @@ local project_row = 100000000 ---- @param point_or_row Point | number +--- @param point_or_row _99.Point | number --- @param col number | nil --- @return number local function project(point_or_row, col) @@ -11,7 +11,7 @@ local function project(point_or_row, col) end --- stores all values as 1 based ---- @class Point +--- @class _99.Point --- @field row number --- @field col number local Point = {} @@ -49,7 +49,7 @@ end --- 1 based point --- @param row number --- @param col number ---- @return Point +--- @return _99.Point function Point:new(row, col) assert(type(row) == "number", "expected row to be a number") assert(type(col) == "number", "expected col to be a number") @@ -65,6 +65,7 @@ function Point:from_cursor() col = 0, }, self) + --- NOTE: win_get_cursor 1, 0 based return local cursor = vim.api.nvim_win_get_cursor(0) local cursor_row, cursor_col = cursor[1], cursor[2] point.row = cursor_row @@ -75,13 +76,18 @@ end --- @param ns_id string ---@param buffer number ---@param mark_id string -function Point:from_extmark(ns_id, buffer, mark_id) +function Point.from_extmark(ns_id, buffer, mark_id) + --- NOTE: get_extmark_by_id returns 0 based local row, col = vim.api.nvim_buf_get_extmark_by_id(buffer, ns_id, mark_id) + return setmetatable({ + row = row + 1, + col = col + 1, + }, Point) end --- @param row number ---@param col number ---- @return Point +--- @return _99.Point function Point:from_ts_point(row, col) return setmetatable({ row = row + 1, @@ -90,7 +96,7 @@ function Point:from_ts_point(row, col) end --- stores all 2 points ---- @param range Range +--- @param range _99.Range --- @return boolean function Point:in_ts_range(range) return range:contains(self) @@ -123,46 +129,46 @@ function Point:to_ts() return self.row - 1, self.col - 1 end ---- @param point Point +--- @param point _99.Point --- @return boolean function Point:gt(point) return project(self) > project(point) end ---- @param point Point +--- @param point _99.Point --- @return boolean function Point:lt(point) return project(self) < project(point) end ---- @param point Point +--- @param point _99.Point --- @return boolean function Point:lte(point) return project(self) <= project(point) end ---- @param point Point +--- @param point _99.Point --- @return boolean function Point:gte(point) return project(self) >= project(point) end ---- @param point Point +--- @param point _99.Point --- @return boolean function Point:eq(point) return project(self) == project(point) end ---- @class Range ---- @field start Point ---- @field end_ Point +--- @class _99.Range +--- @field start _99.Point +--- @field end_ _99.Point --- @field buffer number local Range = {} Range.__index = Range ---@param buffer number ---- @param start Point ----@param end_ Point +--- @param start _99.Point +---@param end_ _99.Point function Range:new(buffer, start, end_) return setmetatable({ start = start, @@ -171,9 +177,9 @@ function Range:new(buffer, start, end_) }, self) end ----@param node TSNode +---@param node _99.treesitter.TSNode ---@param buffer number ----@return Range +---@return _99.Range function Range:from_ts_node(node, buffer) -- ts is zero based local start_row, start_col, _ = node:start() @@ -187,7 +193,7 @@ function Range:from_ts_node(node, buffer) return setmetatable(range, self) end ---- @param point Point +--- @param point _99.Point --- @return boolean function Range:contains(point) local start = project(self.start) @@ -212,7 +218,7 @@ function Range:to_text() return table.concat(text, "\n") end ---- @param range Range +--- @param range _99.Range --- @return boolean function Range:contains_range(range) return self.start:lte(range.start) and self.end_:gte(range.end_) diff --git a/lua/99/language/init.lua b/lua/99/language/init.lua index 50710f4..80edb31 100644 --- a/lua/99/language/init.lua +++ b/lua/99/language/init.lua @@ -27,13 +27,6 @@ end --- @return number function M.add_function_spacing(_99, location) local lang = M.languages[location.file_type] - print( - "language", - "file_type", - vim.inspect(location.file_type), - "lang", - vim.inspect(lang) - ) if not lang then Logger:fatal("langauge currently not supported", "lang", lang) end diff --git a/lua/99/ops/fill-in-function.lua b/lua/99/ops/fill-in-function.lua index 066ebd8..2b7c064 100644 --- a/lua/99/ops/fill-in-function.lua +++ b/lua/99/ops/fill-in-function.lua @@ -1,12 +1,10 @@ local geo = require("99.geo") -local Range = geo.Range local Point = geo.Point local Logger = require("99.logger.logger") local Request = require("99.request") -local marks = require("99.ops.marks") +local Mark = require("99.ops.marks") local Context = require("99.ops.context") local editor = require("99.editor") -local Languages = require("99.language") --- @param res string --- @param location _99.Location @@ -62,33 +60,23 @@ end --- @param _99 _99.State local function fill_in_function(_99) local ts = editor.treesitter + local buffer = vim.api.nvim_get_current_buf() local cursor = Point:from_cursor() - local scopes = ts.function_scopes(cursor) - local scope = scopes:get_inner_scope() - local range = scopes:get_inner_range() + local func = ts.containing_function(buffer, cursor) - if not range or not scope then - Logger:error("fill_in_function: unable to find any containing function") - error("you cannot call fill_in_function not in a function") + if not func then + Logger:fatal("fill_in_function: unable to find any containing function") + return end - local location = editor.Location.from_ts_node(scope, range) - local ai_input_row = Languages.add_function_spacing(_99, location) - if ai_input_row == -1 then - Logger:warn("fill_in_function: add_function_spacing returned -1") - else - local buffer = location.buffer - local mark = marks( - buffer, - Range:new( - buffer, - Point:new(ai_input_row, 1), - Point:new(ai_input_row, 1) - ) - ) + local location = + editor.Location.from_ts_node(func.function_node, func.function_range) + local virt_line_count = _99.ai_stdout_rows + if virt_line_count >= 0 then + location.marks.function_location = Mark.mark_func_body(buffer, func) + :set_max_virt_lines(virt_line_count) end - -- location.marks.virtual_text_start = marks(buffer, local context = Context.new(_99):finalize(_99, location) local request = Request.new({ provider = _99.provider_override, @@ -97,12 +85,17 @@ local function fill_in_function(_99) }) context:add_to_request(request) - location.marks.function_location = marks(location.buffer, range) request:add_prompt_content(_99.prompts.prompts.fill_in_function) request:start({ - on_stdout = function(line) end, + on_stdout = function(line) + local mark = location.marks.function_location + if mark then + mark:set_virtual_text({ line }) + end + end, on_complete = function(ok, response) + location:clear_marks() if not ok then Logger:fatal( "unable to fill in function, enable and check logger for more details" @@ -110,8 +103,7 @@ local function fill_in_function(_99) end update_file_with_changes(response, location) end, - on_stderr = function(line) - end, + on_stderr = function(line) end, }) end diff --git a/lua/99/ops/marks.lua b/lua/99/ops/marks.lua index 51d9b4a..99b7eb2 100644 --- a/lua/99/ops/marks.lua +++ b/lua/99/ops/marks.lua @@ -1,18 +1,72 @@ -local marks_to_use = "yuiophjklnm" -local mark_index = 0 +local Logger = require("99.logger.logger") + +local nsid = vim.api.nvim_create_namespace("99.marks") + +--- @class _99.Mark.Text +--- @field text string +--- @field hlgroup string + +--- @class _99.Mark +--- @field id any -- whatever extmark returns +--- @field buffer number +--- @field max_lines number +--- @field lines string[] +local Mark = {} +Mark.__index = Mark --- @param buffer number ----@param range Range ----@return string -local function mark_function(buffer, range) - local start_row, start_col = range.start:to_vim() - local idx = (mark_index + 1) % #marks_to_use - local mark = marks_to_use:sub(idx, idx) +--- @param func _99.treesitter.Function +function Mark.mark_func_body(buffer, func) + local start = func.function_range.start + local line, col = start:to_vim() + local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col, {}) + + Logger:debug("mark_func_body", "function_range", func.function_range:to_string(), "function_range#start", func.function_range.start:to_string()) + return setmetatable({ + id = id, + buffer = buffer, + max_lines = 1, + lines = {}, + }, Mark) +end + +--- @param count number +--- @return _99.Mark +function Mark:set_max_virt_lines(count) + self.max_lines = count + return self +end + +--- @param lines string[] +function Mark:set_virtual_text(lines) + local pos = + vim.api.nvim_buf_get_extmark_by_id(self.buffer, nsid, self.id, {}) + assert(#pos > 0, "extmark is broken. it does not exist") + local row, col = pos[1], pos[2] + + for _, line in ipairs(lines) do + table.insert(self.lines, line) + if #self.lines > self.max_lines then + table.remove(self.lines, 1) + end + end - vim.api.nvim_buf_set_mark(buffer, mark, start_row + 1, start_col, {}) + local formatted_lines = {} + for _, line in ipairs(self.lines) do + table.insert(formatted_lines, { + { line, "Comment" }, + }) + end + + + vim.api.nvim_buf_set_extmark(self.buffer, nsid, row, col, { + id = self.id, + virt_lines = formatted_lines, + }) +end - mark_index = idx - return mark +function Mark:delete() + vim.api.nvim_buf_del_extmark(self.buffer, nsid, self.id) end -return mark_function +return Mark diff --git a/lua/99/prompt_settings.lua b/lua/99/prompt_settings.lua index d97a7fb..d4ba667 100644 --- a/lua/99/prompt_settings.lua +++ b/lua/99/prompt_settings.lua @@ -31,7 +31,7 @@ local prompt_settings = { ) end, - --- @param range Range + --- @param range _99.Range get_range_text = function(range) return string.format("<FunctionText>%s</FunctionText>", range:to_text()) end, diff --git a/lua/99/test/fill_in_function_spec.lua b/lua/99/test/fill_in_function_spec.lua new file mode 100644 index 0000000..4aad1a7 --- /dev/null +++ b/lua/99/test/fill_in_function_spec.lua @@ -0,0 +1,52 @@ +-- luacheck: globals describe it assert +local _99 = require("99") +local test_utils = require("99.test.test_utils") +local eq = assert.are.same +local test_content = require("99.test.test_content") + +--- @param content string[] +--- @return _99.test.Provider, number +local function setup(content) + local p = test_utils.TestProvider.new() + _99.setup({ + provider = p, + }) + + local buffer = test_utils.create_file(content, "lua", 2) + return p, buffer +end + +--- @param buffer number +--- @return string[] +local function r(buffer) + return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) +end + +local cases = { + { "single line", test_content.empty_function_single_line }, + { "multiline", test_content.empty_function_2_lines }, +} + +describe("fill_in_function", function() + for _, case in ipairs(cases) do + it(case[1], function() + local p, buffer = setup(case[2]) + _99.fill_in_function() + eq(case[2], r(buffer)) + + p:resolve(true, "function foo()\n return 42\nend") + test_utils.next_frame() + + local expected_state = { + "", + "function foo()", + " return 42", + "end", + "", + } + eq(expected_state, r(buffer)) + end) + end + it("stdout into function virtual text", function() + end) +end) diff --git a/lua/99/test/integration_spec.lua b/lua/99/test/integration_spec.lua deleted file mode 100644 index c318f8e..0000000 --- a/lua/99/test/integration_spec.lua +++ /dev/null @@ -1,53 +0,0 @@ --- luacheck: globals describe it assert -local _99 = require("99") -local test_utils = require("99.test.test_utils") -local eq = assert.are.same -local test_content = require("99.test.test_content") - ---- @param content string[] ---- @return _99.test.Provider, number -local function setup(content) - local p = test_utils.TestProvider.new() - _99.setup({ - provider = p, - }) - - local buffer = test_utils.create_file(content, "lua", 2) - return p, buffer -end - ---- @param buffer number ---- @return string[] -local function r(buffer) - return vim.api.nvim_buf_get_lines(buffer, 0, -1, false) -end - -describe("fill_in_function", function() - it("should fill in function that multiple lines", function() - local p, buffer = setup(test_content.empty_function_2_lines) - _99.fill_in_function() - - local expected_state = { - "", - "function foo()", - "", - "", - "", - "end", - "", - } - eq(expected_state, r(buffer)) - - p:resolve(true, "function foo()\n return 42\nend") - test_utils.next_frame() - - expected_state = { - "", - "function foo()", - " return 42", - "end", - "", - } - eq(expected_state, r(buffer)) - end) -end) diff --git a/lua/99/test/test_utils.lua b/lua/99/test/test_utils.lua index 410375d..372b565 100644 --- a/lua/99/test/test_utils.lua +++ b/lua/99/test/test_utils.lua @@ -43,7 +43,6 @@ function TestProvider:resolve(success, result) assert(self.request, "you cannot call resolve until make_request is called") local obs = self.request.observer if obs then - print("complete callback", vim.inspect(obs)) obs.on_complete(success, result) end self.request = nil @@ -88,7 +87,6 @@ function M.create_file(contents, file_type, row, col) vim.api.nvim_set_current_buf(bufnr) vim.bo[bufnr].ft = file_type vim.api.nvim_buf_set_lines(bufnr, 0, -1, false, contents) - print("row", row or 1, "col", col or 0) vim.api.nvim_win_set_cursor(0, { row or 1, col or 0 }) table.insert(M.created_files, bufnr) |
