squirreling 0.6.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "squirreling",
3
- "version": "0.6.0",
3
+ "version": "0.6.1",
4
4
  "description": "Squirreling SQL Engine",
5
5
  "author": "Hyperparam",
6
6
  "homepage": "https://hyperparam.app",
@@ -37,11 +37,11 @@
37
37
  "test": "vitest run"
38
38
  },
39
39
  "devDependencies": {
40
- "@types/node": "24.10.2",
41
- "@vitest/coverage-v8": "4.0.15",
42
- "eslint": "9.39.1",
40
+ "@types/node": "24.10.4",
41
+ "@vitest/coverage-v8": "4.0.16",
42
+ "eslint": "9.39.2",
43
43
  "eslint-plugin-jsdoc": "61.5.0",
44
44
  "typescript": "5.9.3",
45
- "vitest": "4.0.15"
45
+ "vitest": "4.0.16"
46
46
  }
47
47
  }
@@ -1,8 +1,7 @@
1
1
  /**
2
- * @import { AsyncCell, AsyncCells, AsyncDataSource, AsyncRow, SqlPrimitive } from '../types.js'
2
+ * @import { AsyncCell, AsyncCells, AsyncDataSource, AsyncRow, ScanOptions, SqlPrimitive } from '../types.js'
3
3
  */
4
4
 
5
-
6
5
  /**
7
6
  * Wraps an async generator of plain objects into an AsyncDataSource
8
7
  *
@@ -11,8 +10,11 @@
11
10
  */
12
11
  export function generatorSource(gen) {
13
12
  return {
14
- async *scan() {
15
- yield* gen
13
+ async *scan({ signal }) {
14
+ for await (const row of gen) {
15
+ if (signal?.aborted) break
16
+ yield row
17
+ }
16
18
  },
17
19
  }
18
20
  }
@@ -40,8 +42,9 @@ function asyncRow(obj) {
40
42
  */
41
43
  export function memorySource(data) {
42
44
  return {
43
- async *scan() {
45
+ async *scan({ signal }) {
44
46
  for (const item of data) {
47
+ if (signal?.aborted) break
45
48
  yield asyncRow(item)
46
49
  }
47
50
  },
@@ -58,11 +61,14 @@ export function cachedDataSource(source) {
58
61
  const cache = new Map()
59
62
  return {
60
63
  /**
64
+ * @param {ScanOptions} options
61
65
  * @yields {AsyncRow}
62
66
  */
63
- async *scan() {
67
+ async *scan(options) {
68
+ const { signal } = options
64
69
  let index = 0
65
- for await (const row of source.scan()) {
70
+ for await (const row of source.scan(options)) {
71
+ if (signal?.aborted) break
66
72
  const rowIndex = index
67
73
  /** @type {AsyncCells} */
68
74
  const cells = {}
@@ -1,7 +1,46 @@
1
+ import { isAggregateFunc } from '../validation.js'
2
+
1
3
  /**
2
4
  * @import { ExprNode, SelectStatement, SelectColumn } from '../types.js'
3
5
  */
4
6
 
7
+ /**
8
+ * Checks if an expression contains any aggregate function calls
9
+ *
10
+ * @param {ExprNode | undefined} expr
11
+ * @returns {boolean}
12
+ */
13
+ export function containsAggregate(expr) {
14
+ if (!expr) return false
15
+ if (expr.type === 'function' && isAggregateFunc(expr.name.toUpperCase())) {
16
+ return true
17
+ }
18
+ if (expr.type === 'binary') {
19
+ return containsAggregate(expr.left) || containsAggregate(expr.right)
20
+ }
21
+ if (expr.type === 'unary') {
22
+ return containsAggregate(expr.argument)
23
+ }
24
+ if (expr.type === 'cast') {
25
+ return containsAggregate(expr.expr)
26
+ }
27
+ if (expr.type === 'case') {
28
+ if (expr.caseExpr && containsAggregate(expr.caseExpr)) return true
29
+ for (const when of expr.whenClauses) {
30
+ if (containsAggregate(when.condition) || containsAggregate(when.result)) return true
31
+ }
32
+ if (containsAggregate(expr.elseResult)) return true
33
+ }
34
+ if (expr.type === 'in valuelist') {
35
+ if (containsAggregate(expr.expr)) return true
36
+ for (const val of expr.values) {
37
+ if (containsAggregate(val)) return true
38
+ }
39
+ }
40
+ // Note: Don't recurse into subqueries - they have their own aggregate scope
41
+ return false
42
+ }
43
+
5
44
  /**
6
45
  * Extracts column names needed from a SELECT statement.
7
46
  *
@@ -50,11 +89,6 @@ export function extractColumns(select) {
50
89
  function collectColumnsFromSelectColumn(col, columns) {
51
90
  if (col.kind === 'derived') {
52
91
  collectColumnsFromExpr(col.expr, columns)
53
- } else if (col.kind === 'aggregate') {
54
- if (col.arg.kind === 'expression') {
55
- collectColumnsFromExpr(col.arg.expr, columns)
56
- }
57
- // 'star' aggregate (COUNT(*)) doesn't reference specific columns
58
92
  }
59
93
  // 'star' columns handled separately (returns undefined for all columns)
60
94
  }
@@ -2,8 +2,7 @@ import { missingClauseError } from '../parseErrors.js'
2
2
  import { tableNotFoundError, unsupportedOperationError } from '../executionErrors.js'
3
3
  import { generatorSource, memorySource } from '../backend/dataSource.js'
4
4
  import { parseSql } from '../parse/parse.js'
5
- import { defaultAggregateAlias, evaluateAggregate } from './aggregates.js'
6
- import { extractColumns } from './columns.js'
5
+ import { containsAggregate, extractColumns } from './columns.js'
7
6
  import { evaluateExpr } from './expression.js'
8
7
  import { evaluateHavingExpr } from './having.js'
9
8
  import { executeJoins } from './join.js'
@@ -19,7 +18,7 @@ import { compareForTerm, defaultDerivedAlias, stringify } from './utils.js'
19
18
  * @param {ExecuteSqlOptions} options - the execution options
20
19
  * @yields {AsyncRow} async generator yielding result rows
21
20
  */
22
- export async function* executeSql({ tables, query }) {
21
+ export async function* executeSql({ tables, query, signal }) {
23
22
  const select = typeof query === 'string' ? parseSql(query) : query
24
23
 
25
24
  // Check for unsupported operations
@@ -41,17 +40,23 @@ export async function* executeSql({ tables, query }) {
41
40
  }
42
41
  }
43
42
 
44
- yield* executeSelect(select, normalizedTables)
43
+ yield* executeSelect({ select, tables: normalizedTables, signal })
45
44
  }
46
45
 
46
+ /**
47
+ * @typedef {Object} ExecuteSelectOptions
48
+ * @property {SelectStatement} select
49
+ * @property {Record<string, AsyncDataSource>} tables
50
+ * @property {AbortSignal} [signal]
51
+ */
52
+
47
53
  /**
48
54
  * Executes a SELECT query against the provided tables
49
55
  *
50
- * @param {SelectStatement} select
51
- * @param {Record<string, AsyncDataSource>} tables
56
+ * @param {ExecuteSelectOptions} options
52
57
  * @yields {AsyncRow}
53
58
  */
54
- export async function* executeSelect(select, tables) {
59
+ export async function* executeSelect({ select, tables, signal }) {
55
60
  /** @type {AsyncDataSource} */
56
61
  let dataSource
57
62
  /** @type {string} */
@@ -67,7 +72,7 @@ export async function* executeSelect(select, tables) {
67
72
  } else {
68
73
  // Nested subquery - recursively resolve
69
74
  fromTableName = select.from.alias
70
- dataSource = generatorSource(executeSelect(select.from.query, tables))
75
+ dataSource = generatorSource(executeSelect({ select: select.from.query, tables, signal }))
71
76
  }
72
77
 
73
78
  // Execute JOINs if present
@@ -75,7 +80,7 @@ export async function* executeSelect(select, tables) {
75
80
  dataSource = await executeJoins(dataSource, select.joins, fromTableName, tables)
76
81
  }
77
82
 
78
- yield* evaluateSelectAst(select, dataSource, tables)
83
+ yield* evaluateSelectAst({ select, dataSource, tables, signal })
79
84
  }
80
85
 
81
86
  /**
@@ -201,40 +206,52 @@ async function sortRows(rows, orderBy, tables) {
201
206
  return groups.flat().map(i => rows[i])
202
207
  }
203
208
 
209
+ /**
210
+ * @typedef {Object} EvaluateSelectAstOptions
211
+ * @property {SelectStatement} select
212
+ * @property {AsyncDataSource} dataSource
213
+ * @property {Record<string, AsyncDataSource>} tables
214
+ * @property {AbortSignal} [signal]
215
+ */
216
+
204
217
  /**
205
218
  * Evaluates a select with a resolved FROM data source
206
219
  *
207
- * @param {SelectStatement} select
208
- * @param {AsyncDataSource} dataSource
209
- * @param {Record<string, AsyncDataSource>} tables
220
+ * @param {EvaluateSelectAstOptions} options
210
221
  * @yields {AsyncRow}
211
222
  */
212
- async function* evaluateSelectAst(select, dataSource, tables) {
223
+ async function* evaluateSelectAst({ select, dataSource, tables, signal }) {
213
224
  // SQL priority: from, where, group by, having, select, order by, offset, limit
214
225
 
215
- const hasAggregate = select.columns.some(col => col.kind === 'aggregate')
226
+ const hasAggregate = select.columns.some(col => col.kind === 'derived' && containsAggregate(col.expr))
216
227
  const useGrouping = hasAggregate || select.groupBy.length > 0
217
228
  const needsBuffering = useGrouping || select.orderBy.length > 0
218
229
 
219
230
  if (needsBuffering) {
220
231
  // BUFFERING PATH: Collect all rows, process, then yield
221
- yield* evaluateBuffered(select, dataSource, tables, hasAggregate, useGrouping)
232
+ yield* evaluateBuffered({ select, dataSource, tables, hasAggregate, useGrouping, signal })
222
233
  } else {
223
234
  // STREAMING PATH: Yield rows one by one
224
- yield* evaluateStreaming(select, dataSource, tables)
235
+ yield* evaluateStreaming({ select, dataSource, tables, signal })
225
236
  }
226
237
  }
227
238
 
239
+ /**
240
+ * @typedef {Object} EvaluateStreamingOptions
241
+ * @property {SelectStatement} select
242
+ * @property {AsyncDataSource} dataSource
243
+ * @property {Record<string, AsyncDataSource>} tables
244
+ * @property {AbortSignal} [signal]
245
+ */
246
+
228
247
  /**
229
248
  * Streaming evaluation for simple queries (no ORDER BY or GROUP BY)
230
249
  * Supports DISTINCT by tracking seen row keys without buffering full rows
231
250
  *
232
- * @param {SelectStatement} select
233
- * @param {AsyncDataSource} dataSource
234
- * @param {Record<string, AsyncDataSource>} tables
251
+ * @param {EvaluateStreamingOptions} options
235
252
  * @yields {AsyncRow}
236
253
  */
237
- async function* evaluateStreaming(select, dataSource, tables) {
254
+ async function* evaluateStreaming({ select, dataSource, tables, signal }) {
238
255
  let rowsYielded = 0
239
256
  let rowsSkipped = 0
240
257
  let rowIndex = 0
@@ -255,7 +272,7 @@ async function* evaluateStreaming(select, dataSource, tables) {
255
272
  offset: select.offset,
256
273
  }
257
274
 
258
- for await (const row of dataSource.scan(hints)) {
275
+ for await (const row of dataSource.scan({ hints, signal })) {
259
276
  rowIndex++
260
277
  // WHERE filter
261
278
  if (select.where) {
@@ -285,10 +302,6 @@ async function* evaluateStreaming(select, dataSource, tables) {
285
302
  const alias = col.alias ?? defaultDerivedAlias(col.expr)
286
303
  columns.push(alias)
287
304
  cells[alias] = () => evaluateExpr({ node: col.expr, row, tables, rowIndex: currentRowIndex })
288
- } else if (col.kind === 'aggregate') {
289
- throw new Error(
290
- 'Aggregate functions require GROUP BY or will act on the whole dataset; add GROUP BY or remove aggregates'
291
- )
292
305
  }
293
306
  }
294
307
 
@@ -312,17 +325,23 @@ async function* evaluateStreaming(select, dataSource, tables) {
312
325
  }
313
326
  }
314
327
 
328
+ /**
329
+ * @typedef {Object} EvaluateBufferedOptions
330
+ * @property {SelectStatement} select
331
+ * @property {AsyncDataSource} dataSource
332
+ * @property {Record<string, AsyncDataSource>} tables
333
+ * @property {boolean} hasAggregate
334
+ * @property {boolean} useGrouping
335
+ * @property {AbortSignal} [signal]
336
+ */
337
+
315
338
  /**
316
339
  * Buffered evaluation for complex queries (with ORDER BY or GROUP BY)
317
340
  *
318
- * @param {SelectStatement} select
319
- * @param {AsyncDataSource} dataSource
320
- * @param {Record<string, AsyncDataSource>} tables
321
- * @param {boolean} hasAggregate
322
- * @param {boolean} useGrouping
341
+ * @param {EvaluateBufferedOptions} options
323
342
  * @yields {AsyncRow}
324
343
  */
325
- async function* evaluateBuffered(select, dataSource, tables, hasAggregate, useGrouping) {
344
+ async function* evaluateBuffered({ select, dataSource, tables, hasAggregate, useGrouping, signal }) {
326
345
  // Build hints for data source optimization
327
346
  // Note: limit/offset not passed here since buffering needs all rows for sorting/grouping
328
347
  /** @type {QueryHints} */
@@ -334,7 +353,7 @@ async function* evaluateBuffered(select, dataSource, tables, hasAggregate, useGr
334
353
  // Step 1: Collect all rows from data source
335
354
  /** @type {AsyncRow[]} */
336
355
  const working = []
337
- for await (const row of dataSource.scan(hints)) {
356
+ for await (const row of dataSource.scan({ hints, signal })) {
338
357
  working.push(row)
339
358
  }
340
359
 
@@ -414,18 +433,9 @@ async function* evaluateBuffered(select, dataSource, tables, hasAggregate, useGr
414
433
  if (col.kind === 'derived') {
415
434
  const alias = col.alias ?? defaultDerivedAlias(col.expr)
416
435
  columns.push(alias)
417
- if (group.length > 0) {
418
- cells[alias] = () => evaluateExpr({ node: col.expr, row: group[0], tables })
419
- } else {
420
- delete cells[alias]
421
- }
422
- continue
423
- }
424
-
425
- if (col.kind === 'aggregate') {
426
- const alias = col.alias ?? defaultAggregateAlias(col)
427
- columns.push(alias)
428
- cells[alias] = () => evaluateAggregate({ col, rows: group, tables })
436
+ // Pass group to evaluateExpr so it can handle aggregate functions within expressions
437
+ // For empty groups, still provide an empty row context for aggregates to return appropriate values
438
+ cells[alias] = () => evaluateExpr({ node: col.expr, row: group[0] ?? { columns: [], cells: {} }, tables, rows: group })
429
439
  continue
430
440
  }
431
441
  }
@@ -1,11 +1,12 @@
1
1
  import { unknownFunctionError } from '../parseErrors.js'
2
2
  import { invalidContextError } from '../executionErrors.js'
3
3
  import {
4
+ aggregateError,
4
5
  argCountError,
5
6
  argValueError,
6
7
  castError,
7
8
  } from '../validationErrors.js'
8
- import { isMathFunc } from '../validation.js'
9
+ import { isAggregateFunc, isMathFunc } from '../validation.js'
9
10
  import { applyIntervalToDate } from './date.js'
10
11
  import { executeSelect } from './execute.js'
11
12
  import { evaluateMathFunc } from './math.js'
@@ -23,9 +24,10 @@ import { applyBinaryOp, stringify } from './utils.js'
23
24
  * @param {AsyncRow} params.row - The data row to evaluate against
24
25
  * @param {Record<string, AsyncDataSource>} params.tables
25
26
  * @param {number} [params.rowIndex] - 1-based row index for error reporting
27
+ * @param {AsyncRow[]} [params.rows] - Group of rows for aggregate functions
26
28
  * @returns {Promise<SqlPrimitive>} The result of the evaluation
27
29
  */
28
- export async function evaluateExpr({ node, row, tables, rowIndex }) {
30
+ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
29
31
  if (node.type === 'literal') {
30
32
  return node.value
31
33
  }
@@ -47,7 +49,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
47
49
 
48
50
  // Scalar subquery - returns a single value
49
51
  if (node.type === 'subquery') {
50
- const gen = executeSelect(node.subquery, tables)
52
+ const gen = executeSelect({ select: node.subquery, tables })
51
53
  const { value } = await gen.next() // Start the generator
52
54
  gen.return(undefined) // Stop further execution
53
55
  if (!value) return null
@@ -57,16 +59,16 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
57
59
  // Unary operators
58
60
  if (node.type === 'unary') {
59
61
  if (node.op === 'NOT') {
60
- return !await evaluateExpr({ node: node.argument, row, tables, rowIndex })
62
+ return !await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows })
61
63
  }
62
64
  if (node.op === 'IS NULL') {
63
- return await evaluateExpr({ node: node.argument, row, tables, rowIndex }) == null
65
+ return await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows }) == null
64
66
  }
65
67
  if (node.op === 'IS NOT NULL') {
66
- return await evaluateExpr({ node: node.argument, row, tables, rowIndex }) != null
68
+ return await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows }) != null
67
69
  }
68
70
  if (node.op === '-') {
69
- const val = await evaluateExpr({ node: node.argument, row, tables, rowIndex })
71
+ const val = await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows })
70
72
  if (val == null) return null
71
73
  return -val
72
74
  }
@@ -76,15 +78,15 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
76
78
  if (node.type === 'binary') {
77
79
  // Handle date +/- interval at AST level
78
80
  if ((node.op === '+' || node.op === '-') && node.right.type === 'interval') {
79
- const dateVal = await evaluateExpr({ node: node.left, row, tables, rowIndex })
81
+ const dateVal = await evaluateExpr({ node: node.left, row, tables, rowIndex, rows })
80
82
  return applyIntervalToDate(dateVal, node.right.value, node.right.unit, node.op)
81
83
  }
82
84
  if (node.op === '+' && node.left.type === 'interval') {
83
- const dateVal = await evaluateExpr({ node: node.right, row, tables, rowIndex })
85
+ const dateVal = await evaluateExpr({ node: node.right, row, tables, rowIndex, rows })
84
86
  return applyIntervalToDate(dateVal, node.left.value, node.left.unit, '+')
85
87
  }
86
88
 
87
- const left = await evaluateExpr({ node: node.left, row, tables, rowIndex })
89
+ const left = await evaluateExpr({ node: node.left, row, tables, rowIndex, rows })
88
90
 
89
91
  // Short-circuit evaluation for AND and OR
90
92
  if (node.op === 'AND') {
@@ -94,15 +96,120 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
94
96
  if (left) return true
95
97
  }
96
98
 
97
- const right = await evaluateExpr({ node: node.right, row, tables, rowIndex })
99
+ const right = await evaluateExpr({ node: node.right, row, tables, rowIndex, rows })
98
100
  return applyBinaryOp(node.op, left, right)
99
101
  }
100
102
 
101
103
  // Function calls
102
104
  if (node.type === 'function') {
103
105
  const funcName = node.name.toUpperCase()
106
+
107
+ // Handle aggregate functions
108
+ if (isAggregateFunc(funcName)) {
109
+ if (!rows) {
110
+ throw aggregateError({
111
+ funcName,
112
+ issue: 'requires GROUP BY or will act on the whole dataset',
113
+ })
114
+ }
115
+
116
+ // Check for star argument (COUNT(*))
117
+ if (node.args.length === 1 && node.args[0].type === 'identifier' && node.args[0].name === '*') {
118
+ if (funcName === 'COUNT') {
119
+ return rows.length
120
+ }
121
+ throw aggregateError({
122
+ funcName,
123
+ issue: '(*) is not supported, use a column name',
124
+ })
125
+ }
126
+
127
+ if (node.args.length !== 1) {
128
+ throw argCountError({
129
+ funcName,
130
+ expected: 1,
131
+ received: node.args.length,
132
+ positionStart: node.positionStart,
133
+ positionEnd: node.positionEnd,
134
+ rowNumber: rowIndex,
135
+ })
136
+ }
137
+
138
+ const argNode = node.args[0]
139
+
140
+ if (funcName === 'COUNT') {
141
+ if (node.distinct) {
142
+ const seen = new Set()
143
+ for (const r of rows) {
144
+ const v = await evaluateExpr({ node: argNode, row: r, tables })
145
+ if (v != null) seen.add(v)
146
+ }
147
+ return seen.size
148
+ }
149
+ let count = 0
150
+ for (const r of rows) {
151
+ const v = await evaluateExpr({ node: argNode, row: r, tables })
152
+ if (v != null) count++
153
+ }
154
+ return count
155
+ }
156
+
157
+ if (funcName === 'SUM' || funcName === 'AVG' || funcName === 'MIN' || funcName === 'MAX') {
158
+ let sum = 0
159
+ let count = 0
160
+ /** @type {number | null} */
161
+ let min = null
162
+ /** @type {number | null} */
163
+ let max = null
164
+
165
+ for (const r of rows) {
166
+ const raw = await evaluateExpr({ node: argNode, row: r, tables })
167
+ if (raw == null) continue
168
+ const num = Number(raw)
169
+ if (!Number.isFinite(num)) continue
170
+
171
+ if (count === 0) {
172
+ min = num
173
+ max = num
174
+ } else {
175
+ if (min == null || num < min) min = num
176
+ if (max == null || num > max) max = num
177
+ }
178
+ sum += num
179
+ count++
180
+ }
181
+
182
+ if (funcName === 'SUM') return sum
183
+ if (funcName === 'AVG') return count === 0 ? null : sum / count
184
+ if (funcName === 'MIN') return min
185
+ if (funcName === 'MAX') return max
186
+ }
187
+
188
+ if (funcName === 'JSON_ARRAYAGG') {
189
+ /** @type {SqlPrimitive[]} */
190
+ const values = []
191
+ if (node.distinct) {
192
+ const seen = new Set()
193
+ for (const r of rows) {
194
+ const v = await evaluateExpr({ node: argNode, row: r, tables })
195
+ const key = stringify(v)
196
+ if (!seen.has(key)) {
197
+ seen.add(key)
198
+ values.push(v)
199
+ }
200
+ }
201
+ } else {
202
+ for (const r of rows) {
203
+ const v = await evaluateExpr({ node: argNode, row: r, tables })
204
+ values.push(v)
205
+ }
206
+ }
207
+ return values
208
+ }
209
+ }
210
+
104
211
  /** @type {SqlPrimitive[]} */
105
- const args = await Promise.all(node.args.map(arg => evaluateExpr({ node: arg, row, tables, rowIndex })))
212
+ const args = await Promise.all(node.args.map(arg => evaluateExpr({ node: arg, row, tables, rowIndex, rows })))
106
213
 
107
214
  if (funcName === 'UPPER') {
108
215
  if (args.length !== 1) {
@@ -427,7 +534,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
427
534
  }
428
535
 
429
536
  if (node.type === 'cast') {
430
- const val = await evaluateExpr({ node: node.expr, row, tables, rowIndex })
537
+ const val = await evaluateExpr({ node: node.expr, row, tables, rowIndex, rows })
431
538
  if (val == null) return null
432
539
  const toType = node.toType.toUpperCase()
433
540
  if (toType === 'TEXT' || toType === 'STRING' || toType === 'VARCHAR') {
@@ -470,17 +577,17 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
470
577
 
471
578
  // IN and NOT IN with value lists
472
579
  if (node.type === 'in valuelist') {
473
- const exprVal = await evaluateExpr({ node: node.expr, row, tables, rowIndex })
580
+ const exprVal = await evaluateExpr({ node: node.expr, row, tables, rowIndex, rows })
474
581
  for (const valueNode of node.values) {
475
- const val = await evaluateExpr({ node: valueNode, row, tables, rowIndex })
582
+ const val = await evaluateExpr({ node: valueNode, row, tables, rowIndex, rows })
476
583
  if (exprVal === val) return true
477
584
  }
478
585
  return false
479
586
  }
480
587
  // IN with subqueries
481
588
  if (node.type === 'in') {
482
- const exprVal = await evaluateExpr({ node: node.expr, row, tables, rowIndex })
483
- const results = executeSelect(node.subquery, tables)
589
+ const exprVal = await evaluateExpr({ node: node.expr, row, tables, rowIndex, rows })
590
+ const results = executeSelect({ select: node.subquery, tables })
484
591
  for await (const resRow of results) {
485
592
  const value = await resRow.cells[resRow.columns[0]]()
486
593
  if (exprVal === value) return true
@@ -490,39 +597,39 @@ export async function evaluateExpr({ node, row, tables, rowIndex }) {
490
597
 
491
598
  // EXISTS and NOT EXISTS with subqueries
492
599
  if (node.type === 'exists') {
493
- const results = await executeSelect(node.subquery, tables).next()
600
+ const results = await executeSelect({ select: node.subquery, tables }).next()
494
601
  return results.done === false
495
602
  }
496
603
  if (node.type === 'not exists') {
497
- const results = await executeSelect(node.subquery, tables).next()
604
+ const results = await executeSelect({ select: node.subquery, tables }).next()
498
605
  return results.done === true
499
606
  }
500
607
 
501
608
  // CASE expressions
502
609
  if (node.type === 'case') {
503
610
  // For simple CASE: evaluate the case expression once
504
- const caseValue = node.caseExpr && await evaluateExpr({ node: node.caseExpr, row, tables, rowIndex })
611
+ const caseValue = node.caseExpr && await evaluateExpr({ node: node.caseExpr, row, tables, rowIndex, rows })
505
612
 
506
613
  // Iterate through WHEN clauses
507
614
  for (const whenClause of node.whenClauses) {
508
615
  let conditionResult
509
616
  if (caseValue !== undefined) {
510
617
  // Simple CASE: compare caseValue with condition
511
- const whenValue = await evaluateExpr({ node: whenClause.condition, row, tables, rowIndex })
618
+ const whenValue = await evaluateExpr({ node: whenClause.condition, row, tables, rowIndex, rows })
512
619
  conditionResult = caseValue === whenValue
513
620
  } else {
514
621
  // Searched CASE: evaluate condition as boolean
515
- conditionResult = await evaluateExpr({ node: whenClause.condition, row, tables, rowIndex })
622
+ conditionResult = await evaluateExpr({ node: whenClause.condition, row, tables, rowIndex, rows })
516
623
  }
517
624
 
518
625
  if (conditionResult) {
519
- return evaluateExpr({ node: whenClause.result, row, tables, rowIndex })
626
+ return evaluateExpr({ node: whenClause.result, row, tables, rowIndex, rows })
520
627
  }
521
628
  }
522
629
 
523
630
  // No WHEN clause matched, return ELSE result or NULL
524
631
  if (node.elseResult) {
525
- return evaluateExpr({ node: node.elseResult, row, tables, rowIndex })
632
+ return evaluateExpr({ node: node.elseResult, row, tables, rowIndex, rows })
526
633
  }
527
634
  return null
528
635
  }
@@ -30,7 +30,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
30
30
  // Buffer right rows for hash index (required for hash join)
31
31
  /** @type {AsyncRow[]} */
32
32
  const rightRows = []
33
- for await (const row of rightSource.scan()) {
33
+ for await (const row of rightSource.scan({})) {
34
34
  rightRows.push(row)
35
35
  }
36
36
 
@@ -39,14 +39,16 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
39
39
 
40
40
  // Return streaming data source - left rows stream through without buffering
41
41
  return {
42
- async *scan() {
42
+ async *scan(options) {
43
+ const { signal } = options
43
44
  yield* hashJoin({
44
- leftRows: leftSource.scan(), // Stream directly, not buffered
45
+ leftRows: leftSource.scan(options), // Stream directly, not buffered
45
46
  rightRows,
46
47
  join,
47
48
  leftTable: currentLeftTable,
48
49
  rightTable: rightTableName,
49
50
  tables,
51
+ signal,
50
52
  })
51
53
  },
52
54
  }
@@ -55,7 +57,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
55
57
  // Multiple joins: buffer intermediate results, stream final join
56
58
  /** @type {AsyncRow[]} */
57
59
  let leftRows = []
58
- for await (const row of leftSource.scan()) {
60
+ for await (const row of leftSource.scan({})) {
59
61
  leftRows.push(row)
60
62
  }
61
63
 
@@ -69,7 +71,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
69
71
 
70
72
  /** @type {AsyncRow[]} */
71
73
  const rightRows = []
72
- for await (const row of rightSource.scan()) {
74
+ for await (const row of rightSource.scan({})) {
73
75
  rightRows.push(row)
74
76
  }
75
77
 
@@ -105,7 +107,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
105
107
 
106
108
  /** @type {AsyncRow[]} */
107
109
  const rightRows = []
108
- for await (const row of rightSource.scan()) {
110
+ for await (const row of rightSource.scan({})) {
109
111
  rightRows.push(row)
110
112
  }
111
113
 
@@ -113,7 +115,8 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
113
115
  const lastRightTableName = lastJoin.alias ?? lastJoin.table
114
116
 
115
117
  return {
116
- async *scan() {
118
+ async *scan(options) {
119
+ const { signal } = options
117
120
  yield* hashJoin({
118
121
  leftRows,
119
122
  rightRows,
@@ -121,6 +124,7 @@ export async function executeJoins(leftSource, joins, leftTableName, tables) {
121
124
  leftTable: currentLeftTable,
122
125
  rightTable: lastRightTableName,
123
126
  tables,
127
+ signal,
124
128
  })
125
129
  },
126
130
  }
@@ -232,9 +236,10 @@ function mergeRows(leftRow, rightRow, leftTable, rightTable) {
232
236
  * @param {string} params.leftTable - name of left table (for column prefixing)
233
237
  * @param {string} params.rightTable - name of right table (for column prefixing, may be alias)
234
238
  * @param {Record<string, AsyncDataSource>} params.tables - all tables for expression evaluation
239
+ * @param {AbortSignal} [params.signal] - abort signal for cancellation
235
240
  * @yields {AsyncRow} joined rows
236
241
  */
237
- async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tables }) {
242
+ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tables, signal }) {
238
243
  const { joinType, on: onCondition } = join
239
244
 
240
245
  if (!onCondition) {
@@ -281,6 +286,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
281
286
 
282
287
  // PROBE PHASE: Stream through left rows, yield matches immediately
283
288
  for await (const leftRow of leftRows) {
289
+ if (signal?.aborted) break
284
290
  // Capture left column info from first row (for NULL row generation)
285
291
  if (!leftPrefixedCols) {
286
292
  leftPrefixedCols = leftRow.columns.flatMap(col =>
@@ -322,6 +328,7 @@ async function* hashJoin({ leftRows, rightRows, join, leftTable, rightTable, tab
322
328
  const matchedRightRows = joinType === 'RIGHT' || joinType === 'FULL' ? new Set() : null
323
329
 
324
330
  for await (const leftRow of leftRows) {
331
+ if (signal?.aborted) break
325
332
  // Capture left column info from first row (for NULL row generation)
326
333
  if (!leftPrefixedCols) {
327
334
  leftPrefixedCols = leftRow.columns.flatMap(col =>
@@ -132,6 +132,10 @@ export function defaultDerivedAlias(expr) {
132
132
  return defaultDerivedAlias(expr.left) + '_' + expr.op + '_' + defaultDerivedAlias(expr.right)
133
133
  }
134
134
  if (expr.type === 'function') {
135
+ // Handle aggregate functions with star (COUNT(*) -> count_all)
136
+ if (expr.args.length === 1 && expr.args[0].type === 'identifier' && expr.args[0].name === '*') {
137
+ return expr.name.toLowerCase() + '_all'
138
+ }
135
139
  return expr.name.toLowerCase() + '_' + expr.args.map(defaultDerivedAlias).join('_')
136
140
  }
137
141
  if (expr.type === 'interval') {
@@ -133,6 +133,15 @@ export function parsePrimary(state) {
133
133
 
134
134
  /** @type {ExprNode[]} */
135
135
  const args = []
136
+ let distinct = false
137
+
138
+ // Check for DISTINCT or ALL keyword (for aggregate functions like COUNT(DISTINCT x))
139
+ if (current(state).type === 'keyword' && current(state).value === 'DISTINCT') {
140
+ consume(state) // consume DISTINCT
141
+ distinct = true
142
+ } else if (current(state).type === 'keyword' && current(state).value === 'ALL') {
143
+ consume(state) // consume ALL (default behavior, just consume it)
144
+ }
136
145
 
137
146
  if (current(state).type !== 'paren' || current(state).value !== ')') {
138
147
  while (true) {
@@ -156,10 +165,21 @@ export function parsePrimary(state) {
156
165
 
157
166
  expect(state, 'paren', ')')
158
167
 
168
+ // Aggregate functions require at least one argument
169
+ if (isAggregateFunc(funcName) && args.length === 0) {
170
+ throw syntaxError({
171
+ expected: 'expression',
172
+ received: '")"',
173
+ positionStart: positionStart + funcName.length + 1, // position after opening paren
174
+ positionEnd: lastPosition(state),
175
+ })
176
+ }
177
+
159
178
  return {
160
179
  type: 'function',
161
180
  name: funcName,
162
181
  args,
182
+ distinct: distinct || undefined,
163
183
  positionStart,
164
184
  positionEnd: lastPosition(state),
165
185
  }
@@ -1,11 +1,11 @@
1
1
  import { tokenize } from './tokenize.js'
2
2
  import { parseExpression } from './expression.js'
3
- import { RESERVED_AFTER_COLUMN, RESERVED_AFTER_TABLE, isAggregateFunc } from '../validation.js'
3
+ import { RESERVED_AFTER_COLUMN, RESERVED_AFTER_TABLE } from '../validation.js'
4
4
  import { consume, current, expect, expectIdentifier, match, parseError, peekToken } from './state.js'
5
5
  import { parseJoins } from './joins.js'
6
6
 
7
7
  /**
8
- * @import { AggregateColumn, AggregateArg, AggregateFunc, ExprNode, FromSubquery, FromTable, OrderByItem, ParserState, SelectStatement, SelectColumn } from '../types.js'
8
+ * @import { ExprNode, FromSubquery, FromTable, OrderByItem, ParserState, SelectStatement, SelectColumn } from '../types.js'
9
9
  */
10
10
 
11
11
  /**
@@ -79,61 +79,12 @@ function parseSelectItem(state) {
79
79
  throw parseError(state, 'column name or expression')
80
80
  }
81
81
 
82
- const next = peekToken(state, 1)
83
- if (next.type === 'paren' && next.value === '(') {
84
- const upper = tok.value.toUpperCase()
85
- if (isAggregateFunc(upper)) {
86
- expectIdentifier(state) // consume function name
87
- return parseAggregateItem(state, upper)
88
- }
89
- }
90
-
91
- // Delegate to expression parser
82
+ // Delegate to expression parser (handles all expressions including aggregates)
92
83
  const expr = parseExpression(state)
93
84
  const alias = parseAs(state)
94
85
  return { kind: 'derived', expr, alias }
95
86
  }
96
87
 
97
- /**
98
- * @param {ParserState} state
99
- * @param {AggregateFunc} func
100
- * @returns {AggregateColumn}
101
- */
102
- function parseAggregateItem(state, func) {
103
- expect(state, 'paren', '(')
104
-
105
- /** @type {AggregateArg} */
106
- let arg
107
-
108
- const cur = current(state)
109
- if (cur.type === 'operator' && cur.value === '*') {
110
- consume(state)
111
- arg = { kind: 'star' }
112
- } else {
113
- /** @type {'all' | 'distinct'} */
114
- let quantifier = 'all'
115
- if (cur.type === 'keyword' && cur.value === 'ALL') {
116
- consume(state) // consume ALL
117
- } else if (cur.type === 'keyword' && cur.value === 'DISTINCT') {
118
- consume(state)
119
- quantifier = 'distinct'
120
- }
121
-
122
- const expr = parseExpression(state)
123
- arg = {
124
- kind: 'expression',
125
- expr,
126
- quantifier,
127
- }
128
- }
129
-
130
- expect(state, 'paren', ')')
131
-
132
- const alias = parseAs(state)
133
-
134
- return { kind: 'aggregate', func, arg, alias }
135
- }
136
-
137
88
  /**
138
89
  * Parses an optional table alias (e.g., "FROM users u" or "FROM users AS u")
139
90
  * @param {ParserState} state
package/src/types.d.ts CHANGED
@@ -1,4 +1,31 @@
1
+ // executeSql(options)
2
+ export interface ExecuteSqlOptions {
3
+ tables: Record<string, Row | AsyncDataSource>
4
+ query: string | SelectStatement
5
+ signal?: AbortSignal
6
+ }
7
+
8
+ // AsyncRow represents a row with async cell values
9
+ export interface AsyncRow {
10
+ columns: string[]
11
+ cells: AsyncCells
12
+ }
13
+ export type AsyncCells = Record<string, AsyncCell>
14
+ export type AsyncCell = () => Promise<SqlPrimitive>
15
+
16
+ export type Row = Record<string, SqlPrimitive>[]
1
17
 
18
+ /**
19
+ * Async data source for streaming SQL execution.
20
+ * Provides an async iterator over rows.
21
+ */
22
+ export interface AsyncDataSource {
23
+ scan(options: ScanOptions): AsyncIterable<AsyncRow>
24
+ }
25
+ export interface ScanOptions {
26
+ hints?: QueryHints
27
+ signal?: AbortSignal
28
+ }
2
29
  /**
3
30
  * Hints passed to data sources for query optimization.
4
31
  * All hints are optional and "best effort" - sources may ignore them.
@@ -15,27 +42,6 @@ export interface QueryHints {
15
42
  offset?: number
16
43
  }
17
44
 
18
- /**
19
- * Async data source for streaming SQL execution.
20
- * Provides an async iterator over rows.
21
- */
22
- export interface AsyncDataSource {
23
- scan(hints?: QueryHints): AsyncIterable<AsyncRow>
24
- }
25
- export interface AsyncRow {
26
- columns: string[]
27
- cells: AsyncCells
28
- }
29
- export type AsyncCells = Record<string, AsyncCell>
30
- export type AsyncCell = () => Promise<SqlPrimitive>
31
-
32
- export type Row = Record<string, SqlPrimitive>[]
33
-
34
- export interface ExecuteSqlOptions {
35
- tables: Record<string, Row | AsyncDataSource>
36
- query: string | SelectStatement
37
- }
38
-
39
45
  export type SqlPrimitive =
40
46
  | string
41
47
  | number
@@ -109,6 +115,7 @@ export interface FunctionNode extends ExprNodeBase {
109
115
  type: 'function'
110
116
  name: string
111
117
  args: ExprNode[]
118
+ distinct?: boolean
112
119
  }
113
120
 
114
121
  export interface CastNode extends ExprNodeBase {
@@ -220,32 +227,13 @@ export type StringFunc =
220
227
  | 'CURRENT_TIME'
221
228
  | 'CURRENT_TIMESTAMP'
222
229
 
223
- export interface AggregateArgStar {
224
- kind: 'star'
225
- }
226
-
227
- export interface AggregateArgExpression {
228
- kind: 'expression'
229
- expr: ExprNode
230
- quantifier: 'all' | 'distinct'
231
- }
232
-
233
- export type AggregateArg = AggregateArgStar | AggregateArgExpression
234
-
235
- export interface AggregateColumn {
236
- kind: 'aggregate'
237
- func: AggregateFunc
238
- arg: AggregateArg
239
- alias?: string
240
- }
241
-
242
230
  export interface DerivedColumn {
243
231
  kind: 'derived'
244
232
  expr: ExprNode
245
233
  alias?: string
246
234
  }
247
235
 
248
- export type SelectColumn = StarColumn | AggregateColumn | DerivedColumn
236
+ export type SelectColumn = StarColumn | DerivedColumn
249
237
 
250
238
  export interface OrderByItem {
251
239
  expr: ExprNode
@@ -1,119 +0,0 @@
1
- import { unknownFunctionError } from '../parseErrors.js'
2
- import { aggregateError } from '../validationErrors.js'
3
- import { evaluateExpr } from './expression.js'
4
- import { defaultDerivedAlias, stringify } from './utils.js'
5
-
6
- /**
7
- * Evaluates an aggregate function over a set of rows
8
- *
9
- * @import { AggregateColumn, AsyncDataSource, AsyncRow, SqlPrimitive } from '../types.js'
10
- * @param {Object} options
11
- * @param {AggregateColumn} options.col - aggregate column definition
12
- * @param {AsyncRow[]} options.rows - rows to aggregate
13
- * @param {Record<string, AsyncDataSource>} options.tables
14
- * @returns {Promise<SqlPrimitive>} aggregated result
15
- */
16
- export async function evaluateAggregate({ col, rows, tables }) {
17
- const { arg, func } = col
18
-
19
- if (func === 'COUNT') {
20
- if (arg.kind === 'star') return rows.length
21
- if (arg.quantifier === 'distinct') {
22
- const seen = new Set()
23
- for (const row of rows) {
24
- const v = await evaluateExpr({ node: arg.expr, row, tables })
25
- if (v != null) {
26
- seen.add(v)
27
- }
28
- }
29
- return seen.size
30
- }
31
- let count = 0
32
- for (const row of rows) {
33
- const v = await evaluateExpr({ node: arg.expr, row, tables })
34
- if (v != null) {
35
- count += 1
36
- }
37
- }
38
- return count
39
- }
40
-
41
- if (func === 'SUM' || func === 'AVG' || func === 'MIN' || func === 'MAX') {
42
- if (arg.kind === 'star') {
43
- throw aggregateError({ funcName: func, issue: '(*) is not supported, use a column name' })
44
- }
45
- let sum = 0
46
- let count = 0
47
- /** @type {number | null} */
48
- let min = null
49
- /** @type {number | null} */
50
- let max = null
51
-
52
- for (const row of rows) {
53
- const raw = await evaluateExpr({ node: arg.expr, row, tables })
54
- if (raw == null) continue
55
- const num = Number(raw)
56
- if (!Number.isFinite(num)) continue
57
-
58
- if (count === 0) {
59
- min = num
60
- max = num
61
- } else {
62
- if (min == null || num < min) min = num
63
- if (max == null || num > max) max = num
64
- }
65
- sum += num
66
- count += 1
67
- }
68
-
69
- if (func === 'SUM') return sum
70
- if (func === 'AVG') return count === 0 ? null : sum / count
71
- if (func === 'MIN') return min
72
- if (func === 'MAX') return max
73
- }
74
-
75
- if (func === 'JSON_ARRAYAGG') {
76
- if (arg.kind === 'star') {
77
- throw aggregateError({ funcName: 'JSON_ARRAYAGG', issue: '(*) is not supported, use a column name or expression' })
78
- }
79
- /** @type {SqlPrimitive[]} */
80
- const values = []
81
- if (arg.quantifier === 'distinct') {
82
- const seen = new Set()
83
- for (const row of rows) {
84
- const v = await evaluateExpr({ node: arg.expr, row, tables })
85
- const key = stringify(v)
86
- if (!seen.has(key)) {
87
- seen.add(key)
88
- values.push(v)
89
- }
90
- }
91
- } else {
92
- for (const row of rows) {
93
- const v = await evaluateExpr({ node: arg.expr, row, tables })
94
- values.push(v)
95
- }
96
- }
97
- return values
98
- }
99
-
100
- throw unknownFunctionError({
101
- funcName: func,
102
- positionStart: 0,
103
- positionEnd: 0,
104
- validFunctions: 'COUNT, SUM, AVG, MIN, MAX, JSON_ARRAYAGG',
105
- })
106
- }
107
-
108
- /**
109
- * Generates a default alias name for an aggregate function
110
- * (e.g., "count_all", "sum_amount")
111
- *
112
- * @param {AggregateColumn} col
113
- * @returns {string}
114
- */
115
- export function defaultAggregateAlias(col) {
116
- const base = col.func.toLowerCase()
117
- if (col.arg.kind === 'star') return base + '_all'
118
- return base + '_' + defaultDerivedAlias(col.arg.expr)
119
- }