diff options
| author | ThePrimeAgain <theprimeagain@theprimeagain.com> | 2025-12-09 14:15:06 -0700 |
|---|---|---|
| committer | ThePrimeAgain <theprimeagain@theprimeagain.com> | 2025-12-09 14:15:06 -0700 |
| commit | 87a762d45e26ce49f3a6834c0b810cf1e2cba885 (patch) | |
| tree | 4c054d68e8da69495269b5391208258f28af014e | |
| parent | 2c40101e1369061e95690ea5509199c7e42413a6 (diff) | |
| download | a4-87a762d45e26ce49f3a6834c0b810cf1e2cba885.tar.xz a4-87a762d45e26ce49f3a6834c0b810cf1e2cba885.zip | |
a working request... its looking good
| -rw-r--r-- | Makefile | 2 | ||||
| -rw-r--r-- | lua/99/editor/treesitter.lua | 40 | ||||
| -rw-r--r-- | lua/99/geo.lua | 13 | ||||
| -rw-r--r-- | lua/99/init.lua | 8 | ||||
| -rw-r--r-- | lua/99/ops/fill-in-function.lua | 50 | ||||
| -rw-r--r-- | lua/99/ops/marks.lua | 25 | ||||
| -rw-r--r-- | lua/99/ops/request_status.lua | 68 | ||||
| -rw-r--r-- | lua/99/request/init.lua | 4 | ||||
| -rw-r--r-- | lua/99/test/fill_in_function_spec.lua | 2 | ||||
| -rw-r--r-- | lua/99/test/request_status_spec.lua | 19 | ||||
| -rw-r--r-- | scratch/refresh.lua | 3 |
11 files changed, 167 insertions, 67 deletions
@@ -9,7 +9,7 @@ lua_lint: lua_test: echo "===> Testing" nvim --headless --noplugin -u scripts/tests/minimal.vim \ - -c "PlenaryBustedDirectory lua/vim-with-me {minimal_init = 'scripts/tests/minimal.vim'}" + -c "PlenaryBustedDirectory lua/99 {minimal_init = 'scripts/tests/minimal.vim'}" lua_clean: echo "===> Cleaning" diff --git a/lua/99/editor/treesitter.lua b/lua/99/editor/treesitter.lua index ce6b7bc..9c8ef55 100644 --- a/lua/99/editor/treesitter.lua +++ b/lua/99/editor/treesitter.lua @@ -102,12 +102,14 @@ Function.__index = Function function Function.from_ts_node(ts_node, lang, buffer, cursor) local ok, query = pcall(vim.treesitter.query.get, lang, function_query) if not ok or query == nil then - Logger:fatal("INVARIANT: not query or not ok") + Logger:fatal("not query or not ok") error("failed") end - local func = { } - for id, node, _ in query:iter_captures(ts_node, buffer, 0, -1, { all = true }) do + local func = {} + for id, node, _ in + query:iter_captures(ts_node, buffer, 0, -1, { all = true }) + do local range = Range:from_ts_node(node, buffer) local name = query.captures[id] if range:contains(cursor) then @@ -122,14 +124,14 @@ function Function.from_ts_node(ts_node, lang, buffer, cursor) end -- Not all functions have bodies, example: function foo() end - assert(func.function_node ~= nil, "INVARIANT: function_node not found") - assert(func.function_range ~= nil, "INVARIANT: function_range not found") + assert(func.function_node ~= nil, "function_node not found") + assert(func.function_range ~= nil, "function_range not found") return setmetatable(func, Function) end --- @class _99.Scope ---- @field scope _99.treesitter.Node[] +--- @field scope _99.treesitter.TSNode[] --- @field range _99.Range[] --- @field buffer number --- @field cursor _99.Point @@ -153,7 +155,7 @@ function Scope:has_scope() return #self.range > 0 end ---- @return _99.treesitter.Node | nil +--- @return _99.treesitter.TSNode | nil function Scope:get_inner_scope() return self.scope[#self.scope] end @@ -163,7 +165,7 @@ function Scope:get_inner_range() return self.range[#self.range] end ---- @param node _99.treesitter.Node +--- @param node _99.treesitter.TSNode function Scope:push(node) local range = Range:from_ts_node(node, self.buffer) if not range:contains(self.cursor) then @@ -212,9 +214,21 @@ function M.containing_function(buffer, cursor) for id, node, _ in query:iter_captures(root, buffer, 0, -1, { all = true }) do local range = Range:from_ts_node(node, buffer) local name = query.captures[id] - Logger:debug("containing_function#capture", "range", range:to_string(), "name", name) + Logger:debug( + "containing_function#capture", + "range", + range:to_string(), + "name", + name + ) if name == "context.function" and range:contains(cursor) then - Logger:debug(" containing_function#capture#found", "cursor", cursor:to_string(), "range", range:to_string()) + Logger:debug( + " containing_function#capture#found", + "cursor", + cursor:to_string(), + "range", + range:to_string() + ) if not found_range then found_range = range found_node = node @@ -223,7 +237,11 @@ function M.containing_function(buffer, cursor) found_node = node end end - Logger:debug("containing_function#capture finished loop", "found_range", found_range and found_range:to_string() or "found_range is nil") + Logger:debug( + "containing_function#capture finished loop", + "found_range", + found_range and found_range:to_string() or "found_range is nil" + ) end if not found_range then diff --git a/lua/99/geo.lua b/lua/99/geo.lua index 88b2545..43962f6 100644 --- a/lua/99/geo.lua +++ b/lua/99/geo.lua @@ -159,6 +159,19 @@ function Point:eq(point) return project(self) == project(point) end +--- @param mark _99.Mark +--- @return _99.Point +function Point.from_mark(mark) + --- buf extmark by id is a 0 based api + local pos = + vim.api.nvim_buf_get_extmark_by_id(mark.buffer, mark.nsid, mark.id, {}) + + return setmetatable({ + row = pos[1] + 1, + col = pos[2] + 1, + }, Point) +end + --- @class _99.Range --- @field start _99.Point --- @field end_ _99.Point diff --git a/lua/99/init.lua b/lua/99/init.lua index c89bd9b..f712132 100644 --- a/lua/99/init.lua +++ b/lua/99/init.lua @@ -50,6 +50,14 @@ function _99.fill_in_function() ops.fill_in_function(_99_state) end +--- As a warning do not use this function unless you intend to use it for +--- debugging purposes. Any other use will likely result in this library +--- not working properly +--- @return _99.State +function _99.__get_state() + return _99_state +end + --- @param opts _99.Options? function _99.setup(opts) opts = opts or {} diff --git a/lua/99/ops/fill-in-function.lua b/lua/99/ops/fill-in-function.lua index 2b7c064..8439fd3 100644 --- a/lua/99/ops/fill-in-function.lua +++ b/lua/99/ops/fill-in-function.lua @@ -5,24 +5,25 @@ local Request = require("99.request") local Mark = require("99.ops.marks") local Context = require("99.ops.context") local editor = require("99.editor") +local RequestStatus = require("99.ops.request_status") --- @param res string --- @param location _99.Location local function update_file_with_changes(res, location) - assert( - location.marks.function_location, - "function_location mark was not set, unrecoverable error" - ) - local mark = location.marks.function_location local buffer = location.buffer + local mark = location.marks.function_location - local mark_pos = vim.api.nvim_buf_get_mark(buffer, mark) - local mark_point = Point:new(mark_pos[1], mark_pos[2] + 1) + assert( + mark and buffer, + "mark and buffer have to be set on the location object" + ) + local func_start = Point.from_mark(mark) local ts = editor.treesitter - local scopes = ts.function_scopes(mark_point, buffer) + local func = ts.containing_function(buffer, func_start) - if not scopes or not scopes:has_scope() then + mark:delete() + if not func then Logger:error( "update_file_with_changes: unable to find function at mark location" ) @@ -32,8 +33,7 @@ local function update_file_with_changes(res, location) return end - local range = scopes.range[#scopes.range] - + local range = func.function_range local function_start_row, _ = range.start:to_vim() local function_end_row, _ = range.end_:to_vim() @@ -48,16 +48,6 @@ local function update_file_with_changes(res, location) end --- @param _99 _99.State ---- @param location _99.Location ----@param thoughts string[] -local function update_with_cot(_99, location, thoughts) - local lines = _99.ai_stdout_rows - --- use nvim_buf_set_extmark({buffer}, {ns_id}, {line}, {col}, {opts}) - --- only show the last few thoughts lines - --- i want to display virtual text of the latest thoughts -end - ---- @param _99 _99.State local function fill_in_function(_99) local ts = editor.treesitter local buffer = vim.api.nvim_get_current_buf() @@ -73,8 +63,8 @@ local function fill_in_function(_99) editor.Location.from_ts_node(func.function_node, func.function_range) local virt_line_count = _99.ai_stdout_rows if virt_line_count >= 0 then - location.marks.function_location = Mark.mark_func_body(buffer, func) - :set_max_virt_lines(virt_line_count) + location.marks.function_location = + Mark.mark_func_body(buffer, func) end local context = Context.new(_99):finalize(_99, location) @@ -87,15 +77,19 @@ local function fill_in_function(_99) context:add_to_request(request) request:add_prompt_content(_99.prompts.prompts.fill_in_function) + local request_status = RequestStatus.new( + 250, + _99.ai_stdout_rows, + location.marks.function_location + ) + request_status:start() + request:start({ on_stdout = function(line) - local mark = location.marks.function_location - if mark then - mark:set_virtual_text({ line }) - end + request_status:push(line) end, on_complete = function(ok, response) - location:clear_marks() + request_status:stop() if not ok then Logger:fatal( "unable to fill in function, enable and check logger for more details" diff --git a/lua/99/ops/marks.lua b/lua/99/ops/marks.lua index bea4607..544f155 100644 --- a/lua/99/ops/marks.lua +++ b/lua/99/ops/marks.lua @@ -9,8 +9,7 @@ local nsid = vim.api.nvim_create_namespace("99.marks") --- @class _99.Mark --- @field id any -- whatever extmark returns --- @field buffer number ---- @field max_lines number ---- @field lines string[] +--- @field nsid any local Mark = {} Mark.__index = Mark @@ -21,45 +20,27 @@ function Mark.mark_func_body(buffer, func) local line, col = start:to_vim() local id = vim.api.nvim_buf_set_extmark(buffer, nsid, line, col, {}) - Logger:debug("mark_func_body", "function_range", func.function_range:to_string(), "function_range#start", func.function_range.start:to_string()) return setmetatable({ id = id, buffer = buffer, - max_lines = 1, - lines = {}, + nsid = nsid, }, Mark) end ---- @param count number ---- @return _99.Mark -function Mark:set_max_virt_lines(count) - self.max_lines = count - return self -end - --- @param lines string[] function Mark:set_virtual_text(lines) local pos = vim.api.nvim_buf_get_extmark_by_id(self.buffer, nsid, self.id, {}) assert(#pos > 0, "extmark is broken. it does not exist") local row, col = pos[1], pos[2] - Logger:warn("set_virt", "pos", pos) - - for _, line in ipairs(lines) do - table.insert(self.lines, line) - if #self.lines > self.max_lines then - table.remove(self.lines, 1) - end - end local formatted_lines = {} - for _, line in ipairs(self.lines) do + for _, line in ipairs(lines) do table.insert(formatted_lines, { { line, "Comment" }, }) end - vim.api.nvim_buf_set_extmark(self.buffer, nsid, row, col, { id = self.id, virt_lines = formatted_lines, diff --git a/lua/99/ops/request_status.lua b/lua/99/ops/request_status.lua new file mode 100644 index 0000000..357f568 --- /dev/null +++ b/lua/99/ops/request_status.lua @@ -0,0 +1,68 @@ +--- @class _99.RequestStatus +--- @field update_time number the milliseconds per update to the virtual text +--- @field status_line string +--- @field lines string[] +--- @field max_lines number +--- @field running boolean +--- @field mark _99.Mark? +local RequestStatus = {} +RequestStatus.__index = RequestStatus + +--- @param update_time number +--- @param max_lines number +--- @param mark _99.Mark? +--- @return _99.RequestStatus +function RequestStatus.new(update_time, max_lines, mark) + local self = setmetatable({}, RequestStatus) + self.update_time = update_time + self.max_lines = max_lines + self.status_line = "⠋" + self.lines = {} + self.running = false + self.mark = mark + return self +end + +--- @return string[] +function RequestStatus:get() + local result = { self.status_line } + for _, line in ipairs(self.lines) do + table.insert(result, line) + end + return result +end + +--- @param line string +function RequestStatus:push(line) + table.insert(self.lines, line) + if #self.lines > self.max_lines - 1 then + table.remove(self.lines, 1) + end +end + +function RequestStatus:start() + local braille_chars = {"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"} + local index = 0 + + local function update_spinner() + if not self.running then + return + end + + self.status_line = braille_chars[index % #braille_chars + 1] + if self.mark then + self.mark:set_virtual_text(self:get()) + end + index = index + 1 + vim.defer_fn(update_spinner, self.update_time) + end + + self.running = true + update_spinner() +end + +function RequestStatus:stop() + self.running = false +end + +return RequestStatus diff --git a/lua/99/request/init.lua b/lua/99/request/init.lua index fd9e3b0..885e4f6 100644 --- a/lua/99/request/init.lua +++ b/lua/99/request/init.lua @@ -124,10 +124,10 @@ Request.__index = Request --- @param opts _99.Request.Opts function Request.new(opts) - validate_opts(opts) - opts.provider = opts.provider or OpenCodeProvider + validate_opts(opts) + local config = opts --[[ @as _99.Request.Config ]] return setmetatable({ diff --git a/lua/99/test/fill_in_function_spec.lua b/lua/99/test/fill_in_function_spec.lua index 4aad1a7..c55c775 100644 --- a/lua/99/test/fill_in_function_spec.lua +++ b/lua/99/test/fill_in_function_spec.lua @@ -47,6 +47,4 @@ describe("fill_in_function", function() eq(expected_state, r(buffer)) end) end - it("stdout into function virtual text", function() - end) end) diff --git a/lua/99/test/request_status_spec.lua b/lua/99/test/request_status_spec.lua new file mode 100644 index 0000000..a6e7d42 --- /dev/null +++ b/lua/99/test/request_status_spec.lua @@ -0,0 +1,19 @@ +-- luacheck: globals describe it assert +local eq = assert.are.same +local RequestStatus = require("99.ops.request_status") + +describe("request_status", function() + it("setting lines and status line", function() + local status = RequestStatus.new(250, 3) + eq({"⠋"}, status:get()) + + status:push("foo") + status:push("bar") + + eq({"⠋", "foo", "bar"}, status:get()) + + status:push("baz") + + eq({"⠋", "bar", "baz"}, status:get()) + end) +end) diff --git a/scratch/refresh.lua b/scratch/refresh.lua index 5a6cee0..1bd4fb4 100644 --- a/scratch/refresh.lua +++ b/scratch/refresh.lua @@ -55,7 +55,8 @@ function create_mark() local fn = ts.containing_function(buffer, Point:from_cursor()) assert(fn, "could not find containing function") - local m = Mark.mark_func_body(buffer, fn):set_max_virt_lines(3) + local _99 = require("99") + local m = Mark.mark_func_body(_99.__get_state(), buffer, fn) write(m, { "hello, world", |
