squirreling 0.6.1 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,20 +4,22 @@ import { evaluateExpr } from './expression.js'
4
4
  import { stringify } from './utils.js'
5
5
 
6
6
  /**
7
- * @import { AsyncRow, AsyncDataSource, JoinClause, ExprNode, AsyncCells } from '../types.js'
7
+ * @import { AsyncRow, AsyncDataSource, JoinClause, ExprNode, AsyncCells, UserDefinedFunction } from '../types.js'
8
8
  */
9
9
 
10
10
  /**
11
11
  * Executes JOIN operations against a base data source
12
12
  *
13
- * @param {AsyncDataSource} leftSource - the left side of the join (FROM table)
14
- * @param {JoinClause[]} joins - array of join clauses to execute
15
- * @param {string} leftTableName - name of the left table (for column prefixing)
16
- * @param {Record<string, AsyncDataSource>} tables - all available tables
13
+ * @param {Object} options
14
+ * @param {AsyncDataSource} options.leftSource - the left side of the join (FROM table)
15
+ * @param {JoinClause[]} options.joins - array of join clauses to execute
16
+ * @param {string} options.leftTable - name of the left table (for column prefixing)
17
+ * @param {Record<string, AsyncDataSource>} options.tables - all available tables
18
+ * @param {Record<string, UserDefinedFunction>} [options.functions]
17
19
  * @returns {Promise<AsyncDataSource>} data source yielding joined rows
18
20
  */
19
- export async function executeJoins(leftSource, joins, leftTableName, tables) {
20
- let currentLeftTable = leftTableName
21
+ export async function executeJoins({ leftSource, joins, leftTable, tables, functions }) {
22
+ let currentLeftTable = leftTable
21
23
 
22
24
  // Single join optimization: stream left rows without buffering
23
25
  if (joins.length === 1) {
@@ -35,7 +37,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
35
37
  }
36
38
 
37
39
  // Use alias for column prefixing if present
38
- const rightTableName = join.alias ?? join.table
40
+ const rightTable = join.alias ?? join.table
39
41
 
40
42
  // Return streaming data source - left rows stream through without buffering
41
43
  return {
@@ -46,8 +48,9 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
46
48
  rightRows,
47
49
  join,
48
50
  leftTable: currentLeftTable,
49
- rightTable: rightTableName,
51
+ rightTable,
50
52
  tables,
53
+ functions,
51
54
  signal,
52
55
  })
53
56
  },
@@ -76,7 +79,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
76
79
  }
77
80
 
78
81
  // Use alias for column prefixing if present
79
- const rightTableName = join.alias ?? join.table
82
+ const rightTable = join.alias ?? join.table
80
83
 
81
84
  // Collect intermediate results into array for next join
82
85
  /** @type {AsyncRow[]} */
@@ -86,8 +89,9 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
86
89
  rightRows,
87
90
  join,
88
91
  leftTable: currentLeftTable,
89
- rightTable: rightTableName,
92
+ rightTable,
90
93
  tables,
94
+ functions,
91
95
  })
92
96
  for await (const row of joined) {
93
97
  newLeftRows.push(row)
@@ -95,14 +99,14 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
95
99
  leftRows = newLeftRows
96
100
 
97
101
  // After join, the "left" table for the next join includes all joined tables
98
- currentLeftTable = `${currentLeftTable}_${rightTableName}`
102
+ currentLeftTable = `${currentLeftTable}_${rightTable}`
99
103
  }
100
104
 
101
105
  // Final join: stream the results
102
- const lastJoin = joins[joins.length - 1]
103
- const rightSource = tables[lastJoin.table]
106
+ const join = joins[joins.length - 1]
107
+ const rightSource = tables[join.table]
104
108
  if (rightSource === undefined) {
105
- throw tableNotFoundError({ tableName: lastJoin.table })
109
+ throw tableNotFoundError({ tableName: join.table })
106
110
  }
107
111
 
108
112
  /** @type {AsyncRow[]} */
@@ -112,7 +116,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
112
116
  }
113
117
 
114
118
  // Use alias for column prefixing if present
115
- const lastRightTableName = lastJoin.alias ?? lastJoin.table
119
+ const rightTable = join.alias ?? join.table
116
120
 
117
121
  return {
118
122
  async *scan(options) {
@@ -120,10 +124,11 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
120
124
  yield* hashJoin({
121
125
  leftRows,
122
126
  rightRows,
123
- join: lastJoin,
127
+ join,
124
128
  leftTable: currentLeftTable,
125
- rightTable: lastRightTableName,
129
+ rightTable,
126
130
  tables,
131
+ functions,
127
132
  signal,
128
133
  })
129
134
  },
@@ -236,10 +241,11 @@ function mergeRows(leftRow, rightRow, leftTable, rightTable) {
236
241
  * @param {string} params.leftTable - name of left table (for column prefixing)
237
242
  * @param {string} params.rightTable - name of right table (for column prefixing, may be alias)
238
243
  * @param {Record<string, AsyncDataSource>} params.tables - all tables for expression evaluation
244
+ * @param {Record<string, UserDefinedFunction>} [params.functions]
239
245
  * @param {AbortSignal} [params.signal] - abort signal for cancellation
240
246
  * @yields {AsyncRow} joined rows
241
247
  */
242
- async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tables, signal }) {
248
+ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tables, functions, signal }) {
243
249
  const { joinType, on: onCondition } = join
244
250
 
245
251
  if (!onCondition) {
@@ -269,7 +275,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
269
275
  // BUILD PHASE: Index right rows by join key
270
276
  // Skip null keys - SQL semantics: NULL != NULL
271
277
  for (const rightRow of rightRows) {
272
- const keyValue = await evaluateExpr({ node: keys.rightKey, row: rightRow, tables })
278
+ const keyValue = await evaluateExpr({ node: keys.rightKey, row: rightRow, tables, functions })
273
279
  if (keyValue == null) continue // NULL keys never match
274
280
  const keyStr = stringify(keyValue)
275
281
  let bucket = hashMap.get(keyStr)
@@ -294,7 +300,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
294
300
  )
295
301
  }
296
302
 
297
- const keyValue = await evaluateExpr({ node: keys.leftKey, row: leftRow, tables })
303
+ const keyValue = await evaluateExpr({ node: keys.leftKey, row: leftRow, tables, functions })
298
304
  const keyStr = stringify(keyValue)
299
305
 
300
306
  const matchingRightRows = hashMap.get(keyStr)
@@ -340,7 +346,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
340
346
 
341
347
  for (const rightRow of rightRows) {
342
348
  const tempMerged = mergeRows(leftRow, rightRow, leftTable, rightTable)
343
- const matches = await evaluateExpr({ node: onCondition, row: tempMerged, tables })
349
+ const matches = await evaluateExpr({ node: onCondition, row: tempMerged, tables, functions })
344
350
 
345
351
  if (matches) {
346
352
  hasMatch = true
@@ -1,7 +1,6 @@
1
1
  /**
2
2
  * @import { MathFunc, SqlPrimitive } from '../types.js'
3
3
  */
4
- import { argCountError } from '../validationErrors.js'
5
4
 
6
5
  /**
7
6
  * Evaluate a math function
@@ -9,71 +8,28 @@ import { argCountError } from '../validationErrors.js'
9
8
  * @param {Object} options
10
9
  * @param {MathFunc} options.funcName - Uppercase function name
11
10
  * @param {SqlPrimitive[]} options.args - Function arguments
12
- * @param {number} options.positionStart - Start position in query
13
- * @param {number} options.positionEnd - End position in query
14
- * @param {number} [options.rowNumber] - 1-based row number for error reporting
15
11
  * @returns {SqlPrimitive} Result
16
12
  */
17
- export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, rowNumber }) {
13
+ export function evaluateMathFunc({ funcName, args }) {
18
14
  if (funcName === 'FLOOR') {
19
- if (args.length !== 1) {
20
- throw argCountError({
21
- funcName: 'FLOOR',
22
- expected: 1,
23
- received: args.length,
24
- positionStart,
25
- positionEnd,
26
- rowNumber,
27
- })
28
- }
29
15
  const val = args[0]
30
16
  if (val == null) return null
31
17
  return Math.floor(Number(val))
32
18
  }
33
19
 
34
20
  if (funcName === 'CEIL' || funcName === 'CEILING') {
35
- if (args.length !== 1) {
36
- throw argCountError({
37
- funcName,
38
- expected: 1,
39
- received: args.length,
40
- positionStart,
41
- positionEnd,
42
- rowNumber,
43
- })
44
- }
45
21
  const val = args[0]
46
22
  if (val == null) return null
47
23
  return Math.ceil(Number(val))
48
24
  }
49
25
 
50
26
  if (funcName === 'ABS') {
51
- if (args.length !== 1) {
52
- throw argCountError({
53
- funcName: 'ABS',
54
- expected: 1,
55
- received: args.length,
56
- positionStart,
57
- positionEnd,
58
- rowNumber,
59
- })
60
- }
61
27
  const val = args[0]
62
28
  if (val == null) return null
63
29
  return Math.abs(Number(val))
64
30
  }
65
31
 
66
32
  if (funcName === 'MOD') {
67
- if (args.length !== 2) {
68
- throw argCountError({
69
- funcName: 'MOD',
70
- expected: 2,
71
- received: args.length,
72
- positionStart,
73
- positionEnd,
74
- rowNumber,
75
- })
76
- }
77
33
  const dividend = args[0]
78
34
  const divisor = args[1]
79
35
  if (dividend == null || divisor == null) return null
@@ -81,64 +37,24 @@ export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, r
81
37
  }
82
38
 
83
39
  if (funcName === 'EXP') {
84
- if (args.length !== 1) {
85
- throw argCountError({
86
- funcName: 'EXP',
87
- expected: 1,
88
- received: args.length,
89
- positionStart,
90
- positionEnd,
91
- rowNumber,
92
- })
93
- }
94
40
  const val = args[0]
95
41
  if (val == null) return null
96
42
  return Math.exp(Number(val))
97
43
  }
98
44
 
99
45
  if (funcName === 'LN') {
100
- if (args.length !== 1) {
101
- throw argCountError({
102
- funcName: 'LN',
103
- expected: 1,
104
- received: args.length,
105
- positionStart,
106
- positionEnd,
107
- rowNumber,
108
- })
109
- }
110
46
  const val = args[0]
111
47
  if (val == null) return null
112
48
  return Math.log(Number(val))
113
49
  }
114
50
 
115
51
  if (funcName === 'LOG10') {
116
- if (args.length !== 1) {
117
- throw argCountError({
118
- funcName: 'LOG10',
119
- expected: 1,
120
- received: args.length,
121
- positionStart,
122
- positionEnd,
123
- rowNumber,
124
- })
125
- }
126
52
  const val = args[0]
127
53
  if (val == null) return null
128
54
  return Math.log10(Number(val))
129
55
  }
130
56
 
131
57
  if (funcName === 'POWER') {
132
- if (args.length !== 2) {
133
- throw argCountError({
134
- funcName: 'POWER',
135
- expected: 2,
136
- received: args.length,
137
- positionStart,
138
- positionEnd,
139
- rowNumber,
140
- })
141
- }
142
58
  const base = args[0]
143
59
  const exponent = args[1]
144
60
  if (base == null || exponent == null) return null
@@ -146,144 +62,61 @@ export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, r
146
62
  }
147
63
 
148
64
  if (funcName === 'SQRT') {
149
- if (args.length !== 1) {
150
- throw argCountError({
151
- funcName: 'SQRT',
152
- expected: 1,
153
- received: args.length,
154
- positionStart,
155
- positionEnd,
156
- rowNumber,
157
- })
158
- }
159
65
  const val = args[0]
160
66
  if (val == null) return null
161
67
  return Math.sqrt(Number(val))
162
68
  }
163
69
 
164
70
  if (funcName === 'SIN') {
165
- if (args.length !== 1) {
166
- throw argCountError({
167
- funcName: 'SIN',
168
- expected: 1,
169
- received: args.length,
170
- positionStart,
171
- positionEnd,
172
- rowNumber,
173
- })
174
- }
175
71
  const val = args[0]
176
72
  if (val == null) return null
177
73
  return Math.sin(Number(val))
178
74
  }
179
75
 
180
76
  if (funcName === 'COS') {
181
- if (args.length !== 1) {
182
- throw argCountError({
183
- funcName: 'COS',
184
- expected: 1,
185
- received: args.length,
186
- positionStart,
187
- positionEnd,
188
- rowNumber,
189
- })
190
- }
191
77
  const val = args[0]
192
78
  if (val == null) return null
193
79
  return Math.cos(Number(val))
194
80
  }
195
81
 
196
82
  if (funcName === 'TAN') {
197
- if (args.length !== 1) {
198
- throw argCountError({
199
- funcName: 'TAN',
200
- expected: 1,
201
- received: args.length,
202
- positionStart,
203
- positionEnd,
204
- rowNumber,
205
- })
206
- }
207
83
  const val = args[0]
208
84
  if (val == null) return null
209
85
  return Math.tan(Number(val))
210
86
  }
211
87
 
212
88
  if (funcName === 'COT') {
213
- if (args.length !== 1) {
214
- throw argCountError({
215
- funcName: 'COT',
216
- expected: 1,
217
- received: args.length,
218
- positionStart,
219
- positionEnd,
220
- rowNumber,
221
- })
222
- }
223
89
  const val = args[0]
224
90
  if (val == null) return null
225
91
  return 1 / Math.tan(Number(val))
226
92
  }
227
93
 
228
94
  if (funcName === 'ASIN') {
229
- if (args.length !== 1) {
230
- throw argCountError({
231
- funcName: 'ASIN',
232
- expected: 1,
233
- received: args.length,
234
- positionStart,
235
- positionEnd,
236
- rowNumber,
237
- })
238
- }
239
95
  const val = args[0]
240
96
  if (val == null) return null
241
97
  return Math.asin(Number(val))
242
98
  }
243
99
 
244
100
  if (funcName === 'ACOS') {
245
- if (args.length !== 1) {
246
- throw argCountError({
247
- funcName: 'ACOS',
248
- expected: 1,
249
- received: args.length,
250
- positionStart,
251
- positionEnd,
252
- rowNumber,
253
- })
254
- }
255
101
  const val = args[0]
256
102
  if (val == null) return null
257
103
  return Math.acos(Number(val))
258
104
  }
259
105
 
260
106
  if (funcName === 'ATAN') {
261
- if (args.length !== 1) {
262
- throw argCountError({
263
- funcName: 'ATAN',
264
- expected: 1,
265
- received: args.length,
266
- positionStart,
267
- positionEnd,
268
- rowNumber,
269
- })
107
+ if (args.length === 1) {
108
+ const val = args[0]
109
+ if (val == null) return null
110
+ return Math.atan(Number(val))
111
+ } else {
112
+ const y = args[0]
113
+ const x = args[1]
114
+ if (y == null || x == null) return null
115
+ return Math.atan2(Number(y), Number(x))
270
116
  }
271
- const val = args[0]
272
- if (val == null) return null
273
- return Math.atan(Number(val))
274
117
  }
275
118
 
276
119
  if (funcName === 'ATAN2') {
277
- if (args.length !== 2) {
278
- throw argCountError({
279
- funcName: 'ATAN2',
280
- expected: 2,
281
- received: args.length,
282
- positionStart,
283
- positionEnd,
284
- rowNumber,
285
- })
286
- }
287
120
  const y = args[0]
288
121
  const x = args[1]
289
122
  if (y == null || x == null) return null
@@ -291,50 +124,18 @@ export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, r
291
124
  }
292
125
 
293
126
  if (funcName === 'DEGREES') {
294
- if (args.length !== 1) {
295
- throw argCountError({
296
- funcName: 'DEGREES',
297
- expected: 1,
298
- received: args.length,
299
- positionStart,
300
- positionEnd,
301
- rowNumber,
302
- })
303
- }
304
127
  const val = args[0]
305
128
  if (val == null) return null
306
129
  return Number(val) * 180 / Math.PI
307
130
  }
308
131
 
309
132
  if (funcName === 'RADIANS') {
310
- if (args.length !== 1) {
311
- throw argCountError({
312
- funcName: 'RADIANS',
313
- expected: 1,
314
- received: args.length,
315
- positionStart,
316
- positionEnd,
317
- rowNumber,
318
- })
319
- }
320
133
  const val = args[0]
321
134
  if (val == null) return null
322
135
  return Number(val) * Math.PI / 180
323
136
  }
324
137
 
325
138
  if (funcName === 'PI') {
326
- if (args.length !== 0) {
327
- throw argCountError({
328
- funcName: 'PI',
329
- expected: 0,
330
- received: args.length,
331
- positionStart,
332
- positionEnd,
333
- rowNumber,
334
- })
335
- }
336
139
  return Math.PI
337
140
  }
338
-
339
- return null
340
141
  }
package/src/index.d.ts CHANGED
@@ -14,10 +14,11 @@ export function executeSql(options: ExecuteSqlOptions): AsyncGenerator<AsyncRow>
14
14
  /**
15
15
  * Parses a SQL query string into an abstract syntax tree
16
16
  *
17
- * @param query - SQL query string to parse
17
+ * @param options
18
+ * @param options.query - SQL query string to parse
18
19
  * @returns parsed SQL select statement
19
20
  */
20
- export function parseSql(query: string): SelectStatement
21
+ export function parseSql(options: { query: string }): SelectStatement
21
22
 
22
23
  /**
23
24
  * Collects all results from an async generator into an array
@@ -1,10 +1,11 @@
1
1
  import {
2
+ argCountParseError,
2
3
  invalidLiteralError,
3
4
  missingClauseError,
4
5
  syntaxError,
5
6
  unknownFunctionError,
6
7
  } from '../parseErrors.js'
7
- import { isAggregateFunc, isIntervalUnit, isMathFunc, isStringFunc } from '../validation.js'
8
+ import { isIntervalUnit, isKnownFunction, validateFunctionArgCount } from '../validation.js'
8
9
  import { parseComparison } from './comparison.js'
9
10
  import { parseSelectInternal } from './parse.js'
10
11
  import { consume, current, expect, expectIdentifier, lastPosition, match, peekToken } from './state.js'
@@ -122,10 +123,15 @@ export function parsePrimary(state) {
122
123
  // function call
123
124
  if (next.type === 'paren' && next.value === '(') {
124
125
  const funcName = tok.value
125
-
126
- // validate function names
127
- if (!isStringFunc(funcName) && !isAggregateFunc(funcName) && !isMathFunc(funcName)) {
128
- throw unknownFunctionError({ funcName, positionStart: tok.positionStart, positionEnd: tok.positionEnd })
126
+ const funcNameUpper = funcName.toUpperCase()
127
+
128
+ // Validate function existence early for better error messages
129
+ if (!isKnownFunction(funcNameUpper, state.functions)) {
130
+ throw unknownFunctionError({
131
+ funcName,
132
+ positionStart,
133
+ positionEnd: tok.positionEnd,
134
+ })
129
135
  }
130
136
 
131
137
  consume(state) // function name
@@ -165,12 +171,14 @@ export function parsePrimary(state) {
165
171
 
166
172
  expect(state, 'paren', ')')
167
173
 
168
- // Aggregate functions require at least one argument
169
- if (isAggregateFunc(funcName) && args.length === 0) {
170
- throw syntaxError({
171
- expected: 'expression',
172
- received: '")"',
173
- positionStart: positionStart + funcName.length + 1, // position after opening paren
174
+ // Validate argument count at parse time
175
+ const validation = validateFunctionArgCount(funcNameUpper, args.length)
176
+ if (!validation.valid) {
177
+ throw argCountParseError({
178
+ funcName,
179
+ expected: validation.expected,
180
+ received: args.length,
181
+ positionStart,
174
182
  positionEnd: lastPosition(state),
175
183
  })
176
184
  }
@@ -5,17 +5,17 @@ import { consume, current, expect, expectIdentifier, match, parseError, peekToke
5
5
  import { parseJoins } from './joins.js'
6
6
 
7
7
  /**
8
- * @import { ExprNode, FromSubquery, FromTable, OrderByItem, ParserState, SelectStatement, SelectColumn } from '../types.js'
8
+ * @import { ExprNode, FromSubquery, FromTable, OrderByItem, ParserState, SelectStatement, SelectColumn, UserDefinedFunction } from '../types.js'
9
9
  */
10
10
 
11
11
  /**
12
- * @param {string} query
12
+ * @param {{ query: string, functions?: Record<string, UserDefinedFunction> }} options
13
13
  * @returns {SelectStatement}
14
14
  */
15
- export function parseSql(query) {
15
+ export function parseSql({ query, functions }) {
16
16
  const tokens = tokenize(query)
17
17
  /** @type {ParserState} */
18
- const state = { tokens, pos: 0 }
18
+ const state = { tokens, pos: 0, functions }
19
19
  const select = parseSelectInternal(state)
20
20
 
21
21
  const tok = current(state)
@@ -2,6 +2,8 @@
2
2
  // PARSE ERRORS - Issues during SQL tokenization and parsing
3
3
  // ============================================================================
4
4
 
5
+ import { FUNCTION_SIGNATURES } from './validationErrors.js'
6
+
5
7
  /**
6
8
  * Structured parse error with position range.
7
9
  */
@@ -103,6 +105,33 @@ export function unknownFunctionError({ funcName, positionStart, positionEnd, val
103
105
  })
104
106
  }
105
107
 
108
+ /**
109
+ * Error for wrong number of function arguments at parse time.
110
+ *
111
+ * @param {Object} options
112
+ * @param {string} options.funcName - The function name
113
+ * @param {number | string} options.expected - Expected count (number or range like "2 to 3")
114
+ * @param {number} options.received - Actual argument count
115
+ * @param {number} options.positionStart - Start position in query
116
+ * @param {number} options.positionEnd - End position in query
117
+ * @returns {ParseError}
118
+ */
119
+ export function argCountParseError({ funcName, expected, received, positionStart, positionEnd }) {
120
+ const signature = FUNCTION_SIGNATURES[funcName] ?? ''
121
+ let expectedStr = `${expected} arguments`
122
+ if (expected === 0) expectedStr = 'no arguments'
123
+ if (expected === 1) expectedStr = '1 argument'
124
+ if (typeof expected === 'string' && expected.endsWith(' 1')) {
125
+ expectedStr = `${expected} argument`
126
+ }
127
+
128
+ return new ParseError({
129
+ message: `${funcName}(${signature}) function requires ${expectedStr}, got ${received}`,
130
+ positionStart,
131
+ positionEnd,
132
+ })
133
+ }
134
+
106
135
  /**
107
136
  * Error for missing required clause or structure.
108
137
  *
package/src/types.d.ts CHANGED
@@ -1,7 +1,11 @@
1
+ // User-defined function type
2
+ export type UserDefinedFunction = (...args: SqlPrimitive[]) => SqlPrimitive | Promise<SqlPrimitive>
3
+
1
4
  // executeSql(options)
2
5
  export interface ExecuteSqlOptions {
3
6
  tables: Record<string, Row | AsyncDataSource>
4
7
  query: string | SelectStatement
8
+ functions?: Record<string, UserDefinedFunction>
5
9
  signal?: AbortSignal
6
10
  }
7
11
 
@@ -254,6 +258,7 @@ export interface ParserState {
254
258
  tokens: Token[]
255
259
  pos: number
256
260
  lastPos?: number
261
+ functions?: Record<string, UserDefinedFunction>
257
262
  }
258
263
 
259
264
  // Tokenizer types