squirreling 0.6.0 → 0.7.0

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
@@ -4,20 +4,22 @@ import { evaluateExpr } from './expression.js'
4
4
  import { stringify } from './utils.js'
5
5
 
6
6
  /**
7
- * @import { AsyncRow, AsyncDataSource, JoinClause, ExprNode, AsyncCells } from '../types.js'
7
+ * @import { AsyncRow, AsyncDataSource, JoinClause, ExprNode, AsyncCells, UserDefinedFunction } from '../types.js'
8
8
  */
9
9
 
10
10
  /**
11
11
  * Executes JOIN operations against a base data source
12
12
  *
13
- * @param {AsyncDataSource} leftSource - the left side of the join (FROM table)
14
- * @param {JoinClause[]} joins - array of join clauses to execute
15
- * @param {string} leftTableName - name of the left table (for column prefixing)
16
- * @param {Record<string, AsyncDataSource>} tables - all available tables
13
+ * @param {Object} options
14
+ * @param {AsyncDataSource} options.leftSource - the left side of the join (FROM table)
15
+ * @param {JoinClause[]} options.joins - array of join clauses to execute
16
+ * @param {string} options.leftTable - name of the left table (for column prefixing)
17
+ * @param {Record<string, AsyncDataSource>} options.tables - all available tables
18
+ * @param {Record<string, UserDefinedFunction>} [options.functions]
17
19
  * @returns {Promise<AsyncDataSource>} data source yielding joined rows
18
20
  */
19
- export async function executeJoins(leftSource, joins, leftTableName, tables) {
20
- let currentLeftTable = leftTableName
21
+ export async function executeJoins({ leftSource, joins, leftTable, tables, functions }) {
22
+ let currentLeftTable = leftTable
21
23
 
22
24
  // Single join optimization: stream left rows without buffering
23
25
  if (joins.length === 1) {
@@ -30,23 +32,26 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
30
32
  // Buffer right rows for hash index (required for hash join)
31
33
  /** @type {AsyncRow[]} */
32
34
  const rightRows = []
33
- for await (const row of rightSource.scan()) {
35
+ for await (const row of rightSource.scan({})) {
34
36
  rightRows.push(row)
35
37
  }
36
38
 
37
39
  // Use alias for column prefixing if present
38
- const rightTableName = join.alias ?? join.table
40
+ const rightTable = join.alias ?? join.table
39
41
 
40
42
  // Return streaming data source - left rows stream through without buffering
41
43
  return {
42
- async *scan() {
44
+ async *scan(options) {
45
+ const { signal } = options
43
46
  yield* hashJoin({
44
- leftRows: leftSource.scan(), // Stream directly, not buffered
47
+ leftRows: leftSource.scan(options), // Stream directly, not buffered
45
48
  rightRows,
46
49
  join,
47
50
  leftTable: currentLeftTable,
48
- rightTable: rightTableName,
51
+ rightTable,
49
52
  tables,
53
+ functions,
54
+ signal,
50
55
  })
51
56
  },
52
57
  }
@@ -55,7 +60,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
55
60
  // Multiple joins: buffer intermediate results, stream final join
56
61
  /** @type {AsyncRow[]} */
57
62
  let leftRows = []
58
- for await (const row of leftSource.scan()) {
63
+ for await (const row of leftSource.scan({})) {
59
64
  leftRows.push(row)
60
65
  }
61
66
 
@@ -69,12 +74,12 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
69
74
 
70
75
  /** @type {AsyncRow[]} */
71
76
  const rightRows = []
72
- for await (const row of rightSource.scan()) {
77
+ for await (const row of rightSource.scan({})) {
73
78
  rightRows.push(row)
74
79
  }
75
80
 
76
81
  // Use alias for column prefixing if present
77
- const rightTableName = join.alias ?? join.table
82
+ const rightTable = join.alias ?? join.table
78
83
 
79
84
  // Collect intermediate results into array for next join
80
85
  /** @type {AsyncRow[]} */
@@ -84,8 +89,9 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
84
89
  rightRows,
85
90
  join,
86
91
  leftTable: currentLeftTable,
87
- rightTable: rightTableName,
92
+ rightTable,
88
93
  tables,
94
+ functions,
89
95
  })
90
96
  for await (const row of joined) {
91
97
  newLeftRows.push(row)
@@ -93,34 +99,37 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
93
99
  leftRows = newLeftRows
94
100
 
95
101
  // After join, the "left" table for the next join includes all joined tables
96
- currentLeftTable = `${currentLeftTable}_${rightTableName}`
102
+ currentLeftTable = `${currentLeftTable}_${rightTable}`
97
103
  }
98
104
 
99
105
  // Final join: stream the results
100
- const lastJoin = joins[joins.length - 1]
101
- const rightSource = tables[lastJoin.table]
106
+ const join = joins[joins.length - 1]
107
+ const rightSource = tables[join.table]
102
108
  if (rightSource === undefined) {
103
- throw tableNotFoundError({ tableName: lastJoin.table })
109
+ throw tableNotFoundError({ tableName: join.table })
104
110
  }
105
111
 
106
112
  /** @type {AsyncRow[]} */
107
113
  const rightRows = []
108
- for await (const row of rightSource.scan()) {
114
+ for await (const row of rightSource.scan({})) {
109
115
  rightRows.push(row)
110
116
  }
111
117
 
112
118
  // Use alias for column prefixing if present
113
- const lastRightTableName = lastJoin.alias ?? lastJoin.table
119
+ const rightTable = join.alias ?? join.table
114
120
 
115
121
  return {
116
- async *scan() {
122
+ async *scan(options) {
123
+ const { signal } = options
117
124
  yield* hashJoin({
118
125
  leftRows,
119
126
  rightRows,
120
- join: lastJoin,
127
+ join,
121
128
  leftTable: currentLeftTable,
122
- rightTable: lastRightTableName,
129
+ rightTable,
123
130
  tables,
131
+ functions,
132
+ signal,
124
133
  })
125
134
  },
126
135
  }
@@ -232,9 +241,11 @@ function mergeRows(leftRow, rightRow, leftTable, rightTable) {
232
241
  * @param {string} params.leftTable - name of left table (for column prefixing)
233
242
  * @param {string} params.rightTable - name of right table (for column prefixing, may be alias)
234
243
  * @param {Record<string, AsyncDataSource>} params.tables - all tables for expression evaluation
244
+ * @param {Record<string, UserDefinedFunction>} [params.functions]
245
+ * @param {AbortSignal} [params.signal] - abort signal for cancellation
235
246
  * @yields {AsyncRow} joined rows
236
247
  */
237
- async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tables }) {
248
+ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tables, functions, signal }) {
238
249
  const { joinType, on: onCondition } = join
239
250
 
240
251
  if (!onCondition) {
@@ -264,7 +275,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
264
275
  // BUILD PHASE: Index right rows by join key
265
276
  // Skip null keys - SQL semantics: NULL != NULL
266
277
  for (const rightRow of rightRows) {
267
- const keyValue = await evaluateExpr({ node: keys.rightKey, row: rightRow, tables })
278
+ const keyValue = await evaluateExpr({ node: keys.rightKey, row: rightRow, tables, functions })
268
279
  if (keyValue == null) continue // NULL keys never match
269
280
  const keyStr = stringify(keyValue)
270
281
  let bucket = hashMap.get(keyStr)
@@ -281,6 +292,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
281
292
 
282
293
  // PROBE PHASE: Stream through left rows, yield matches immediately
283
294
  for await (const leftRow of leftRows) {
295
+ if (signal?.aborted) break
284
296
  // Capture left column info from first row (for NULL row generation)
285
297
  if (!leftPrefixedCols) {
286
298
  leftPrefixedCols = leftRow.columns.flatMap(col =>
@@ -288,7 +300,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
288
300
  )
289
301
  }
290
302
 
291
- const keyValue = await evaluateExpr({ node: keys.leftKey, row: leftRow, tables })
303
+ const keyValue = await evaluateExpr({ node: keys.leftKey, row: leftRow, tables, functions })
292
304
  const keyStr = stringify(keyValue)
293
305
 
294
306
  const matchingRightRows = hashMap.get(keyStr)
@@ -322,6 +334,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
322
334
  const matchedRightRows = joinType === 'RIGHT' || joinType === 'FULL' ? new Set() : null
323
335
 
324
336
  for await (const leftRow of leftRows) {
337
+ if (signal?.aborted) break
325
338
  // Capture left column info from first row (for NULL row generation)
326
339
  if (!leftPrefixedCols) {
327
340
  leftPrefixedCols = leftRow.columns.flatMap(col =>
@@ -333,7 +346,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
333
346
 
334
347
  for (const rightRow of rightRows) {
335
348
  const tempMerged = mergeRows(leftRow, rightRow, leftTable, rightTable)
336
- const matches = await evaluateExpr({ node: onCondition, row: tempMerged, tables })
349
+ const matches = await evaluateExpr({ node: onCondition, row: tempMerged, tables, functions })
337
350
 
338
351
  if (matches) {
339
352
  hasMatch = true
@@ -1,7 +1,6 @@
1
1
  /**
2
2
  * @import { MathFunc, SqlPrimitive } from '../types.js'
3
3
  */
4
- import { argCountError } from '../validationErrors.js'
5
4
 
6
5
  /**
7
6
  * Evaluate a math function
@@ -9,71 +8,28 @@ import { argCountError } from '../validationErrors.js'
9
8
  * @param {Object} options
10
9
  * @param {MathFunc} options.funcName - Uppercase function name
11
10
  * @param {SqlPrimitive[]} options.args - Function arguments
12
- * @param {number} options.positionStart - Start position in query
13
- * @param {number} options.positionEnd - End position in query
14
- * @param {number} [options.rowNumber] - 1-based row number for error reporting
15
11
  * @returns {SqlPrimitive} Result
16
12
  */
17
- export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, rowNumber }) {
13
+ export function evaluateMathFunc({ funcName, args }) {
18
14
  if (funcName === 'FLOOR') {
19
- if (args.length !== 1) {
20
- throw argCountError({
21
- funcName: 'FLOOR',
22
- expected: 1,
23
- received: args.length,
24
- positionStart,
25
- positionEnd,
26
- rowNumber,
27
- })
28
- }
29
15
  const val = args[0]
30
16
  if (val == null) return null
31
17
  return Math.floor(Number(val))
32
18
  }
33
19
 
34
20
  if (funcName === 'CEIL' || funcName === 'CEILING') {
35
- if (args.length !== 1) {
36
- throw argCountError({
37
- funcName,
38
- expected: 1,
39
- received: args.length,
40
- positionStart,
41
- positionEnd,
42
- rowNumber,
43
- })
44
- }
45
21
  const val = args[0]
46
22
  if (val == null) return null
47
23
  return Math.ceil(Number(val))
48
24
  }
49
25
 
50
26
  if (funcName === 'ABS') {
51
- if (args.length !== 1) {
52
- throw argCountError({
53
- funcName: 'ABS',
54
- expected: 1,
55
- received: args.length,
56
- positionStart,
57
- positionEnd,
58
- rowNumber,
59
- })
60
- }
61
27
  const val = args[0]
62
28
  if (val == null) return null
63
29
  return Math.abs(Number(val))
64
30
  }
65
31
 
66
32
  if (funcName === 'MOD') {
67
- if (args.length !== 2) {
68
- throw argCountError({
69
- funcName: 'MOD',
70
- expected: 2,
71
- received: args.length,
72
- positionStart,
73
- positionEnd,
74
- rowNumber,
75
- })
76
- }
77
33
  const dividend = args[0]
78
34
  const divisor = args[1]
79
35
  if (dividend == null || divisor == null) return null
@@ -81,64 +37,24 @@ export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, r
81
37
  }
82
38
 
83
39
  if (funcName === 'EXP') {
84
- if (args.length !== 1) {
85
- throw argCountError({
86
- funcName: 'EXP',
87
- expected: 1,
88
- received: args.length,
89
- positionStart,
90
- positionEnd,
91
- rowNumber,
92
- })
93
- }
94
40
  const val = args[0]
95
41
  if (val == null) return null
96
42
  return Math.exp(Number(val))
97
43
  }
98
44
 
99
45
  if (funcName === 'LN') {
100
- if (args.length !== 1) {
101
- throw argCountError({
102
- funcName: 'LN',
103
- expected: 1,
104
- received: args.length,
105
- positionStart,
106
- positionEnd,
107
- rowNumber,
108
- })
109
- }
110
46
  const val = args[0]
111
47
  if (val == null) return null
112
48
  return Math.log(Number(val))
113
49
  }
114
50
 
115
51
  if (funcName === 'LOG10') {
116
- if (args.length !== 1) {
117
- throw argCountError({
118
- funcName: 'LOG10',
119
- expected: 1,
120
- received: args.length,
121
- positionStart,
122
- positionEnd,
123
- rowNumber,
124
- })
125
- }
126
52
  const val = args[0]
127
53
  if (val == null) return null
128
54
  return Math.log10(Number(val))
129
55
  }
130
56
 
131
57
  if (funcName === 'POWER') {
132
- if (args.length !== 2) {
133
- throw argCountError({
134
- funcName: 'POWER',
135
- expected: 2,
136
- received: args.length,
137
- positionStart,
138
- positionEnd,
139
- rowNumber,
140
- })
141
- }
142
58
  const base = args[0]
143
59
  const exponent = args[1]
144
60
  if (base == null || exponent == null) return null
@@ -146,144 +62,61 @@ export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, r
146
62
  }
147
63
 
148
64
  if (funcName === 'SQRT') {
149
- if (args.length !== 1) {
150
- throw argCountError({
151
- funcName: 'SQRT',
152
- expected: 1,
153
- received: args.length,
154
- positionStart,
155
- positionEnd,
156
- rowNumber,
157
- })
158
- }
159
65
  const val = args[0]
160
66
  if (val == null) return null
161
67
  return Math.sqrt(Number(val))
162
68
  }
163
69
 
164
70
  if (funcName === 'SIN') {
165
- if (args.length !== 1) {
166
- throw argCountError({
167
- funcName: 'SIN',
168
- expected: 1,
169
- received: args.length,
170
- positionStart,
171
- positionEnd,
172
- rowNumber,
173
- })
174
- }
175
71
  const val = args[0]
176
72
  if (val == null) return null
177
73
  return Math.sin(Number(val))
178
74
  }
179
75
 
180
76
  if (funcName === 'COS') {
181
- if (args.length !== 1) {
182
- throw argCountError({
183
- funcName: 'COS',
184
- expected: 1,
185
- received: args.length,
186
- positionStart,
187
- positionEnd,
188
- rowNumber,
189
- })
190
- }
191
77
  const val = args[0]
192
78
  if (val == null) return null
193
79
  return Math.cos(Number(val))
194
80
  }
195
81
 
196
82
  if (funcName === 'TAN') {
197
- if (args.length !== 1) {
198
- throw argCountError({
199
- funcName: 'TAN',
200
- expected: 1,
201
- received: args.length,
202
- positionStart,
203
- positionEnd,
204
- rowNumber,
205
- })
206
- }
207
83
  const val = args[0]
208
84
  if (val == null) return null
209
85
  return Math.tan(Number(val))
210
86
  }
211
87
 
212
88
  if (funcName === 'COT') {
213
- if (args.length !== 1) {
214
- throw argCountError({
215
- funcName: 'COT',
216
- expected: 1,
217
- received: args.length,
218
- positionStart,
219
- positionEnd,
220
- rowNumber,
221
- })
222
- }
223
89
  const val = args[0]
224
90
  if (val == null) return null
225
91
  return 1 / Math.tan(Number(val))
226
92
  }
227
93
 
228
94
  if (funcName === 'ASIN') {
229
- if (args.length !== 1) {
230
- throw argCountError({
231
- funcName: 'ASIN',
232
- expected: 1,
233
- received: args.length,
234
- positionStart,
235
- positionEnd,
236
- rowNumber,
237
- })
238
- }
239
95
  const val = args[0]
240
96
  if (val == null) return null
241
97
  return Math.asin(Number(val))
242
98
  }
243
99
 
244
100
  if (funcName === 'ACOS') {
245
- if (args.length !== 1) {
246
- throw argCountError({
247
- funcName: 'ACOS',
248
- expected: 1,
249
- received: args.length,
250
- positionStart,
251
- positionEnd,
252
- rowNumber,
253
- })
254
- }
255
101
  const val = args[0]
256
102
  if (val == null) return null
257
103
  return Math.acos(Number(val))
258
104
  }
259
105
 
260
106
  if (funcName === 'ATAN') {
261
- if (args.length !== 1) {
262
- throw argCountError({
263
- funcName: 'ATAN',
264
- expected: 1,
265
- received: args.length,
266
- positionStart,
267
- positionEnd,
268
- rowNumber,
269
- })
107
+ if (args.length === 1) {
108
+ const val = args[0]
109
+ if (val == null) return null
110
+ return Math.atan(Number(val))
111
+ } else {
112
+ const y = args[0]
113
+ const x = args[1]
114
+ if (y == null || x == null) return null
115
+ return Math.atan2(Number(y), Number(x))
270
116
  }
271
- const val = args[0]
272
- if (val == null) return null
273
- return Math.atan(Number(val))
274
117
  }
275
118
 
276
119
  if (funcName === 'ATAN2') {
277
- if (args.length !== 2) {
278
- throw argCountError({
279
- funcName: 'ATAN2',
280
- expected: 2,
281
- received: args.length,
282
- positionStart,
283
- positionEnd,
284
- rowNumber,
285
- })
286
- }
287
120
  const y = args[0]
288
121
  const x = args[1]
289
122
  if (y == null || x == null) return null
@@ -291,50 +124,18 @@ export function evaluateMathFunc({ funcName, args, positionStart, positionEnd, r
291
124
  }
292
125
 
293
126
  if (funcName === 'DEGREES') {
294
- if (args.length !== 1) {
295
- throw argCountError({
296
- funcName: 'DEGREES',
297
- expected: 1,
298
- received: args.length,
299
- positionStart,
300
- positionEnd,
301
- rowNumber,
302
- })
303
- }
304
127
  const val = args[0]
305
128
  if (val == null) return null
306
129
  return Number(val) * 180 / Math.PI
307
130
  }
308
131
 
309
132
  if (funcName === 'RADIANS') {
310
- if (args.length !== 1) {
311
- throw argCountError({
312
- funcName: 'RADIANS',
313
- expected: 1,
314
- received: args.length,
315
- positionStart,
316
- positionEnd,
317
- rowNumber,
318
- })
319
- }
320
133
  const val = args[0]
321
134
  if (val == null) return null
322
135
  return Number(val) * Math.PI / 180
323
136
  }
324
137
 
325
138
  if (funcName === 'PI') {
326
- if (args.length !== 0) {
327
- throw argCountError({
328
- funcName: 'PI',
329
- expected: 0,
330
- received: args.length,
331
- positionStart,
332
- positionEnd,
333
- rowNumber,
334
- })
335
- }
336
139
  return Math.PI
337
140
  }
338
-
339
- return null
340
141
  }
@@ -132,6 +132,10 @@ export function defaultDerivedAlias(expr) {
132
132
  return defaultDerivedAlias(expr.left) + '_' + expr.op + '_' + defaultDerivedAlias(expr.right)
133
133
  }
134
134
  if (expr.type === 'function') {
135
+ // Handle aggregate functions with star (COUNT(*) -> count_all)
136
+ if (expr.args.length === 1 && expr.args[0].type === 'identifier' && expr.args[0].name === '*') {
137
+ return expr.name.toLowerCase() + '_all'
138
+ }
135
139
  return expr.name.toLowerCase() + '_' + expr.args.map(defaultDerivedAlias).join('_')
136
140
  }
137
141
  if (expr.type === 'interval') {
package/src/index.d.ts CHANGED
@@ -14,10 +14,11 @@ export function executeSql(options: ExecuteSqlOptions): AsyncGenerator<AsyncRow>
14
14
  /**
15
15
  * Parses a SQL query string into an abstract syntax tree
16
16
  *
17
- * @param query - SQL query string to parse
17
+ * @param options
18
+ * @param options.query - SQL query string to parse
18
19
  * @returns parsed SQL select statement
19
20
  */
20
- export function parseSql(query: string): SelectStatement
21
+ export function parseSql(options: { query: string }): SelectStatement
21
22
 
22
23
  /**
23
24
  * Collects all results from an async generator into an array
@@ -1,10 +1,11 @@
1
1
  import {
2
+ argCountParseError,
2
3
  invalidLiteralError,
3
4
  missingClauseError,
4
5
  syntaxError,
5
6
  unknownFunctionError,
6
7
  } from '../parseErrors.js'
7
- import { isAggregateFunc, isIntervalUnit, isMathFunc, isStringFunc } from '../validation.js'
8
+ import { isIntervalUnit, isKnownFunction, validateFunctionArgCount } from '../validation.js'
8
9
  import { parseComparison } from './comparison.js'
9
10
  import { parseSelectInternal } from './parse.js'
10
11
  import { consume, current, expect, expectIdentifier, lastPosition, match, peekToken } from './state.js'
@@ -122,10 +123,15 @@ export function parsePrimary(state) {
122
123
  // function call
123
124
  if (next.type === 'paren' && next.value === '(') {
124
125
  const funcName = tok.value
125
-
126
- // validate function names
127
- if (!isStringFunc(funcName) && !isAggregateFunc(funcName) && !isMathFunc(funcName)) {
128
- throw unknownFunctionError({ funcName, positionStart: tok.positionStart, positionEnd: tok.positionEnd })
126
+ const funcNameUpper = funcName.toUpperCase()
127
+
128
+ // Validate function existence early for better error messages
129
+ if (!isKnownFunction(funcNameUpper, state.functions)) {
130
+ throw unknownFunctionError({
131
+ funcName,
132
+ positionStart,
133
+ positionEnd: tok.positionEnd,
134
+ })
129
135
  }
130
136
 
131
137
  consume(state) // function name
@@ -133,6 +139,15 @@ export function parsePrimary(state) {
133
139
 
134
140
  /** @type {ExprNode[]} */
135
141
  const args = []
142
+ let distinct = false
143
+
144
+ // Check for DISTINCT or ALL keyword (for aggregate functions like COUNT(DISTINCT x))
145
+ if (current(state).type === 'keyword' && current(state).value === 'DISTINCT') {
146
+ consume(state) // consume DISTINCT
147
+ distinct = true
148
+ } else if (current(state).type === 'keyword' && current(state).value === 'ALL') {
149
+ consume(state) // consume ALL (default behavior, just consume it)
150
+ }
136
151
 
137
152
  if (current(state).type !== 'paren' || current(state).value !== ')') {
138
153
  while (true) {
@@ -156,10 +171,23 @@ export function parsePrimary(state) {
156
171
 
157
172
  expect(state, 'paren', ')')
158
173
 
174
+ // Validate argument count at parse time
175
+ const validation = validateFunctionArgCount(funcNameUpper, args.length)
176
+ if (!validation.valid) {
177
+ throw argCountParseError({
178
+ funcName,
179
+ expected: validation.expected,
180
+ received: args.length,
181
+ positionStart,
182
+ positionEnd: lastPosition(state),
183
+ })
184
+ }
185
+
159
186
  return {
160
187
  type: 'function',
161
188
  name: funcName,
162
189
  args,
190
+ distinct: distinct || undefined,
163
191
  positionStart,
164
192
  positionEnd: lastPosition(state),
165
193
  }