summaryrefslogtreecommitdiffstatshomepage
path: root/runtime/lua/vim/treesitter/_range.lua
blob: 8ec938acc53e6428727ca95daf61f916803e3c30 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
local api = vim.api

-- Range utilities for vim.treesitter. Ranges are plain integer arrays in one
-- of three shapes (Range2, Range4, Range6). Rows are 0-based (`get_offset`
-- maps row 0 to byte 0). End positions are exclusive: `M.intercepts` treats
-- a range ending exactly where another starts as non-overlapping.
local M = {}

---@class Range2
---@inlinedoc
---@field [1] integer start row
---@field [2] integer end row

---@class Range4
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer end row
---@field [4] integer end column

---@class Range6
---@inlinedoc
---@field [1] integer start row
---@field [2] integer start column
---@field [3] integer start byte offset
---@field [4] integer end row
---@field [5] integer end column
---@field [6] integer end byte offset

---@alias Range Range2|Range4|Range6

---Three-way comparison of two (row, col) positions, ordered by row first and
---then by column.
---@param a_row integer
---@param a_col integer
---@param b_row integer
---@param b_col integer
---@return integer
--- 1: a > b
--- 0: a == b
--- -1: a < b
local function cmp_pos(a_row, a_col, b_row, b_col)
  -- Different rows decide the ordering on their own.
  if a_row ~= b_row then
    return a_row > b_row and 1 or -1
  end

  -- Same row: fall back to the column.
  if a_col > b_col then
    return 1
  end
  if a_col < b_col then
    return -1
  end
  return 0
end

---Position comparison helpers. Every function takes
---`(a_row, a_col, b_row, b_col)` and returns a boolean (lt, le, gt, ge, eq, ne).
---Calling `M.cmp_pos(a_row, a_col, b_row, b_col)` directly returns the raw
---three-way result: 1 (a > b), 0 (a == b) or -1 (a < b).
M.cmp_pos = {
  lt = function(...)
    return cmp_pos(...) == -1
  end,
  le = function(...)
    return cmp_pos(...) ~= 1
  end,
  gt = function(...)
    return cmp_pos(...) == 1
  end,
  ge = function(...)
    return cmp_pos(...) ~= -1
  end,
  eq = function(...)
    return cmp_pos(...) == 0
  end,
  ne = function(...)
    return cmp_pos(...) ~= 0
  end,
}

-- A `__call` metamethod receives the called table as its first argument.
-- Drop it so only the four position values are forwarded to `cmp_pos`;
-- passing `cmp_pos` directly would compare the table against a row number.
setmetatable(M.cmp_pos, {
  __call = function(_, ...)
    return cmp_pos(...)
  end,
})

---Check if a variable is a valid range object.
---
---Accepts tables of length 4 (Range4) or 6 (Range6) in which every element is
---a number. Range2 tables are not accepted.
---@param r any
---@return boolean
function M.validate(r)
  if type(r) ~= 'table' then
    return false
  end

  local n = #r
  if n ~= 4 and n ~= 6 then
    return false
  end

  -- Check every slot by index. `#` on a table with a nil hole may still
  -- report 4 or 6, and `ipairs` would stop at the hole without noticing it.
  for i = 1, n do
    if type(r[i]) ~= 'number' then
      return false
    end
  end

  return true
end

---Test whether two ranges overlap.
---Ends are exclusive, so ranges that merely touch (one ends exactly where the
---other begins) are not considered overlapping.
---@param r1 Range
---@param r2 Range
---@return boolean
function M.intercepts(r1, r2)
  local a_srow, a_scol, a_erow, a_ecol = M.unpack4(r1)
  local b_srow, b_scol, b_erow, b_ecol = M.unpack4(r2)

  -- r1 must end strictly after r2 starts (otherwise r1 lies above r2) and
  -- start strictly before r2 ends (otherwise r1 lies below r2).
  return M.cmp_pos.gt(a_erow, a_ecol, b_srow, b_scol)
    and M.cmp_pos.lt(a_srow, a_scol, b_erow, b_ecol)
end

---Compute the overlapping part of two ranges.
---
---Returns nil when the ranges do not intercept (see `M.intercepts`).
---
---When both inputs are Range6, the result is a Range6. Its start (including
---the byte offset) comes from the later-starting range, and its end comes from
---the earlier-ending range. For any other combination, both inputs are
---normalized with `M.unpack4` and a Range4 is returned. This means mixed
---Range4/Range6 inputs read their positions from the correct slots.
---@param r1 Range6
---@param r2 Range6
---@return Range6?
---@overload fun(r1:Range4,r2:Range4):Range4?
function M.intersection(r1, r2)
  if not M.intercepts(r1, r2) then
    return nil
  end

  if #r1 == 6 and #r2 == 6 then
    local rs = M.cmp_pos.le(r1[1], r1[2], r2[1], r2[2]) and r2 or r1
    local re = M.cmp_pos.ge(r1[4], r1[5], r2[4], r2[5]) and r2 or r1
    return { rs[1], rs[2], rs[3], re[4], re[5], re[6] }
  end

  -- Normalize via unpack4. A Range6 stores its end at [4]/[5], not [3]/[4],
  -- so indexing both inputs as Range4 would be wrong for mixed shapes.
  local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
  local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)

  -- Later start wins (r2 on ties, as before).
  local srow, scol = srow_1, scol_1
  if M.cmp_pos.le(srow_1, scol_1, srow_2, scol_2) then
    srow, scol = srow_2, scol_2
  end

  -- Earlier end wins (r2 on ties, as before).
  local erow, ecol = erow_1, ecol_1
  if M.cmp_pos.ge(erow_1, ecol_1, erow_2, ecol_2) then
    erow, ecol = erow_2, ecol_2
  end

  return { srow, scol, erow, ecol }
end

---Normalize any Range into its (start_row, start_col, end_row, end_col) tuple.
---A Range2 gets zero for both columns. A Range6 has its byte offsets skipped.
---@param r Range
---@return integer, integer, integer, integer
function M.unpack4(r)
  local len = #r
  if len == 2 then
    return r[1], 0, r[2], 0
  elseif len == 6 then
    return r[1], r[2], r[4], r[5]
  end
  return r[1], r[2], r[3], r[4]
end

---Spread a Range6 into its six components:
---start row, start column, start byte, end row, end column, end byte.
---@param r Range6
---@return integer, integer, integer, integer, integer, integer
function M.unpack6(r)
  local start_row, start_col, start_byte = r[1], r[2], r[3]
  local end_row, end_col, end_byte = r[4], r[5], r[6]
  return start_row, start_col, start_byte, end_row, end_col, end_byte
end

---Test whether `r1` fully encloses `r2`.
---The start of r1 must not be after the start of r2, and the end of r1 must
---not be before the end of r2. Equal ranges therefore contain each other.
---@param r1 Range
---@param r2 Range
---@return boolean whether r1 contains r2
function M.contains(r1, r2)
  local outer_srow, outer_scol, outer_erow, outer_ecol = M.unpack4(r1)
  local inner_srow, inner_scol, inner_erow, inner_ecol = M.unpack4(r2)

  -- Short-circuit: the end check only runs when the start already fits.
  return M.cmp_pos.le(outer_srow, outer_scol, inner_srow, inner_scol)
    and M.cmp_pos.ge(outer_erow, outer_ecol, inner_erow, inner_ecol)
end

--- Test whether two ranges share the same start and end positions.
--- Both inputs go through `M.unpack4`, so only rows and columns are compared;
--- the byte offsets of a Range6 are not consulted.
--- @param r1 Range4
--- @param r2 Range4
--- @return boolean
function M.equal(r1, r2)
  local a_srow, a_scol, a_erow, a_ecol = M.unpack4(r1)
  local b_srow, b_scol, b_erow, b_ecol = M.unpack4(r2)

  if a_srow ~= b_srow or a_scol ~= b_scol then
    return false
  end
  return a_erow == b_erow and a_ecol == b_ecol
end

--- Byte offset at which line `index` (0-based) begins in `source`.
--- For a buffer handle this defers to `nvim_buf_get_offset`. For a string it
--- returns the 1-based position of the `index`-th newline, which equals the
--- 0-based offset of the first byte after that newline. If the string
--- contains fewer than `index` newlines, the result is nil.
--- @param source integer|string
--- @param index integer
--- @return integer
local function get_offset(source, index)
  if index == 0 then
    return 0
  end

  if type(source) == 'number' then
    return api.nvim_buf_get_offset(source, index)
  end

  -- Hop from newline to newline, once per requested line (plain find, no
  -- pattern magic).
  local pos = 0
  for _ = 1, index do
    pos = source:find('\n', pos + 1, true)
    if not pos then
      break
    end
  end

  return pos --[[@as integer]]
end

---Ensure a range carries byte offsets by converting it to a Range6.
---A table that already has six elements is returned as-is (the same table).
---Otherwise each byte offset is the line's starting offset from `get_offset`
---(buffer handle or source string) plus the column.
---@param source integer|string
---@param range Range
---@return Range6
function M.add_bytes(source, range)
  local is_range6 = type(range) == 'table' and #range == 6
  if is_range6 then
    return range --[[@as Range6]]
  end

  local srow, scol, erow, ecol = M.unpack4(range)
  -- TODO(vigoux): proper byte computation here, and account for EOL ?
  local sbyte = get_offset(source, srow) + scol
  local ebyte = get_offset(source, erow) + ecol

  return { srow, scol, sbyte, erow, ecol, ebyte }
end

return M