summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lua/99/editor/location.lua3
-rw-r--r--lua/99/editor/treesitter.lua15
-rw-r--r--lua/99/ops/context.lua23
-rw-r--r--lua/99/ops/fill-in-function.lua58
-rw-r--r--lua/99/ops/implement-fn.lua2
-rw-r--r--lua/99/ops/init.lua2
-rw-r--r--lua/99/ops/prompts.lua0
-rw-r--r--lua/99/ops/system-rules.lua38
-rw-r--r--lua/99/prompt_settings.lua7
-rw-r--r--scratch/refresh.lua4
10 files changed, 69 insertions, 83 deletions
diff --git a/lua/99/editor/location.lua b/lua/99/editor/location.lua
index 2da56b8..b4b3de8 100644
--- a/lua/99/editor/location.lua
+++ b/lua/99/editor/location.lua
@@ -5,16 +5,13 @@ local Point = require("99.geo")
--- @field range Range
--- @field buffer number
--- @field marks table<string, string>
---- @field cursor Point
local Location = {}
Location.__index = Location
function Location.from_range(range)
local full_path = vim.api.nvim_buf_get_name(range.buffer)
- local cursor = Point:from_cursor()
return setmetatable({
- cursor = cursor,
buffer = range.buffer,
full_path = full_path,
range = range,
diff --git a/lua/99/editor/treesitter.lua b/lua/99/editor/treesitter.lua
index dff7a3d..b76939d 100644
--- a/lua/99/editor/treesitter.lua
+++ b/lua/99/editor/treesitter.lua
@@ -100,6 +100,11 @@ function Scope:has_scope()
return #self.range > 0
end
+--- @return Range | nil
+function Scope:get_inner_scope()
+ return self.range[#self.range]
+end
+
--- @param node TSNode
function Scope:push(node)
local range = Range:from_ts_node(node, self.buffer)
@@ -120,25 +125,24 @@ end
--- @param cursor Point
--- @param buffer number?
---- @return Scope | nil
+--- @return Scope
function M.function_scopes(cursor, buffer)
buffer = buffer or vim.api.nvim_get_current_buf()
+ local scope = Scope:new(cursor, buffer)
local lang = vim.bo[buffer].ft
local root = tree_root(buffer, lang)
if not root then
Logger:debug("LSP: could not find tree root")
- return nil
+ return scope
end
local ok, query = pcall(vim.treesitter.query.get, lang, function_query)
if not ok or query == nil then
Logger:debug("LSP: not ok or query", "query", vim.inspect(query), "lang", lang, "ok", vim.inspect(ok))
- return nil
+ return scope
end
- print("function_scopes", buffer)
- local scope = Scope:new(cursor, buffer)
for _, match, _ in query:iter_matches(root, buffer, 0, -1, { all = true }) do
for _, nodes in pairs(match) do
for _, node in ipairs(nodes) do
@@ -154,6 +158,7 @@ end
--- @return TSNode[]
function M.imports()
+ assert(false, "not implemented")
local root = tree_root()
if not root then
return {}
diff --git a/lua/99/ops/context.lua b/lua/99/ops/context.lua
index d88a170..0428671 100644
--- a/lua/99/ops/context.lua
+++ b/lua/99/ops/context.lua
@@ -13,9 +13,15 @@ end
local Context = {}
Context.__index = Context
-function Context.new()
+--- @param _99 _99.State
+function Context.new(_99)
+ local mds = {}
+ for _, md in ipairs(_99.md_files) do
+ table.insert(mds, md)
+ end
+
return setmetatable({
- md_file_names = {},
+ md_file_names = mds,
ai_context = {},
tmp_file = random_file(),
}, Context)
@@ -28,13 +34,22 @@ function Context:add_md_file_name(md_file_name)
return self
end
---- @param location _99.Location
-function Context:finalize(location)
+function Context:_read_md_files()
--- @ai use location's buffer's full path and walk back until we are at cwd
--- @ai and read each of the md_file_names. if it exists then add it to
--- @ai ai_context.
end
+--- @param _99 _99.State
+--- @param location _99.Location
+--- @return self
+function Context:finalize(_99, location)
+ table.insert(self.ai_context, _99.prompts.get_file_location(location))
+ table.insert(self.ai_context, _99.prompts.get_range_text(location.range))
+ table.insert(self.ai_context, _99.prompts.tmp_file_location(self.tmp_file))
+ return self
+end
+
--- @param request _99.Request
function Context:add_to_request(request)
for _, context in ipairs(self.ai_context) do
diff --git a/lua/99/ops/fill-in-function.lua b/lua/99/ops/fill-in-function.lua
index aebb444..ed28447 100644
--- a/lua/99/ops/fill-in-function.lua
+++ b/lua/99/ops/fill-in-function.lua
@@ -1,28 +1,27 @@
local Point = require("99.geo").Point
local Logger = require("99.logger.logger")
local Request = require("99.request")
-local system_rules = require("99.request.system-rules")
local marks = require("99.ops.marks")
+local Context = require("99.ops.context")
local editor = require("99.editor")
-local ops = require("99.ops")
---- @param request _99.Request
---- @param ok boolean
--- @param res string
-local function update_file_with_changes(request, ok, res)
- if not ok then
- error("unable to fill in function, enable and check logger for more details")
- end
- local mark_pos = vim.api.nvim_buf_get_mark(request.buffer, request.mark)
+--- @param location _99.Location
+local function update_file_with_changes(res, location)
+ assert(location.marks.function_location, "function_location mark was not set, unrecoverable error")
+ local mark = location.marks.function_location
+ local buffer = location.buffer
+
+ local mark_pos = vim.api.nvim_buf_get_mark(buffer, mark)
local mark_point = Point:new(mark_pos[1], mark_pos[2] + 1)
local ts = editor.treesitter
- local scopes = ts.function_scopes(mark_point, request.buffer)
- print("update_file_with_changes buffer", request.buffer)
+ local scopes = ts.function_scopes(mark_point, buffer)
+ print("update_file_with_changes buffer", buffer)
if not scopes or not scopes:has_scope() then
Logger:error("update_file_with_changes: unable to find function at mark location")
- error("update_file_with_changes: funable to find function at mark location")
+ error("update_file_with_changes: funable to find function at mark location")
return
end
@@ -32,35 +31,40 @@ local function update_file_with_changes(request, ok, res)
local function_end_row, _ = range.end_:to_vim()
local lines = vim.split(res, "\n")
- vim.api.nvim_buf_set_lines(request.buffer, function_start_row, function_end_row + 1, false, lines)
+ vim.api.nvim_buf_set_lines(buffer, function_start_row, function_end_row + 1, false, lines)
end
-
--- @param _99 _99.State
--- @return _99.Request
local function fill_in_function(_99)
local ts = editor.treesitter
local cursor = Point:from_cursor()
local scopes = ts.function_scopes(cursor)
- local buffer = vim.api.nvim_get_current_buf()
+ local range = scopes:get_inner_scope()
- local location = editor.Location.from_range(range)
- local request = Request.new({
- model = _99.model,
- on_complete = update_file_with_changes,
- })
-
- if not request:has_scopes() then
- Logger:warn("fill_in_function: unable to find any containing function")
+ if not range then
+ Logger:error("fill_in_function: unable to find any containing function")
error("you cannot call fill_in_function not in a function")
end
- local range = request:get_inner_scope()
+ local location = editor.Location.from_range(range)
+ local context = Context.new(_99):finalize(_99, location)
+ local request = Request.new({
+ model = _99.model,
+ on_complete = function(_, ok, response)
+ if not ok then
+ Logger:fatal("unable to fill in function, enable and check logger for more details")
+ end
+ update_file_with_changes(response, location)
+ end,
+ context = context,
+ })
- request:set_system_prompt(system_rules(request))
- request.mark = marks(request.buffer, range)
+ context:add_to_request(request)
+ location.marks.function_location = marks(location.buffer, range)
+ request:add_prompt_content(_99.prompts.prompts.fill_in_function)
- return request
+ return request
end
return fill_in_function
diff --git a/lua/99/ops/implement-fn.lua b/lua/99/ops/implement-fn.lua
index a0e6e94..ef28c10 100644
--- a/lua/99/ops/implement-fn.lua
+++ b/lua/99/ops/implement-fn.lua
@@ -1,7 +1,5 @@
local Logger = require("99.logger.logger")
local Request = require("99.request")
-local system_rules = require("99.request.system-rules")
-local marks = require("99.ops.marks")
local editor = require("99.editor")
local Range = require("99.geo").Range
diff --git a/lua/99/ops/init.lua b/lua/99/ops/init.lua
index 8b569bc..2e6bced 100644
--- a/lua/99/ops/init.lua
+++ b/lua/99/ops/init.lua
@@ -1,6 +1,6 @@
return {
fill_in_function = require("99.ops.fill-in-function"),
implement_fn = require("99.ops.implement-fn"),
- context = require("99.ops.context"),
+ Context = require("99.ops.context"),
}
diff --git a/lua/99/ops/prompts.lua b/lua/99/ops/prompts.lua
deleted file mode 100644
index e69de29..0000000
--- a/lua/99/ops/prompts.lua
+++ /dev/null
diff --git a/lua/99/ops/system-rules.lua b/lua/99/ops/system-rules.lua
deleted file mode 100644
index c71828e..0000000
--- a/lua/99/ops/system-rules.lua
+++ /dev/null
@@ -1,38 +0,0 @@
-local _99_settings = {
- fill_in_function = "fill in the function. dont change the function signature. do not edit anything outside of this function. prioritize using internal functions for work that has already been done. any NOTE's left in the function should be removed but instructions followed",
-}
-
---- @param tmp_file string
---- @return string
-local function system_rules(tmp_file)
- return string.format(
- "<MustObey>\n%s\n%s\n</MustObey><TEMP_FILE>%s</TEMP_FILE>",
- _99_settings.output_file,
- tmp_file
- )
-end
-
---- @param buffer number
----@param range Range
----@return string
-local function get_file_location(buffer, range)
- local full_path = vim.fn.expand("%:p")
- return string.format("<Location><File>%s</File><Function>%s</Function></Location>", full_path, range:to_string())
-end
-
---- @param range Range
-local function get_range_text(range)
- return string.format("<FunctionText>%s</FunctionText>", range:to_text())
-end
-
---- @param request _99.Request
---- @return string
-return function(request)
- local range = request:get_inner_scope()
- local buffer = request.buffer
- return table.concat({
- system_rules(request.tmp_name),
- get_file_location(buffer, range),
- get_range_text(range),
- })
-end
diff --git a/lua/99/prompt_settings.lua b/lua/99/prompt_settings.lua
index ffccb56..e6c90ef 100644
--- a/lua/99/prompt_settings.lua
+++ b/lua/99/prompt_settings.lua
@@ -2,17 +2,18 @@
local prompts = {
fill_in_function = "fill in the function. dont change the function signature. do not edit anything outside of this function. prioritize using internal functions for work that has already been done. any NOTE's left in the function should be removed but instructions followed",
implement_function = "implement the function that the cursor is on. make sure you inspect the current file carefully and any imports that look related. being thorough is better than being fast. being correct is better than being speedy.",
+ output_file = "never alter any file other than TEMP_FILE.",
+ read_tmp = "never attempt to read TEMP_FILE. It is purely for output. Previous contents, which may not exist, can be written over without worry",
}
--- @class _99.Prompts
local prompt_settings = {
prompts = prompts,
- --- @param output_file string
--- @param tmp_file string
--- @return string
- tmp_file_location = function(output_file, tmp_file)
- return string.format("<MustObey>\n%s\n%s\n</MustObey><TEMP_FILE>%s</TEMP_FILE>", output_file, tmp_file)
+ tmp_file_location = function(tmp_file)
+ return string.format("<MustObey>\n%s\n%s\n</MustObey>\n<TEMP_FILE>%s</TEMP_FILE>", prompts.output_file, prompts.read_tmp, tmp_file)
end,
---@param location _99.Location
diff --git a/scratch/refresh.lua b/scratch/refresh.lua
index db78ea5..c90f76f 100644
--- a/scratch/refresh.lua
+++ b/scratch/refresh.lua
@@ -1,3 +1,7 @@
+local function pick_a_number_that_is_42()
+ return 42
+end
+
function return_42()
local number = pick_a_number_that_is_42()
return number