squirreling 0.9.0 → 0.9.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -47,8 +47,8 @@ const asyncRows: AsyncIterable<AsyncRow> = executeSql({
47
47
  })
48
48
 
49
49
  // Process rows as they arrive (streaming)
50
- for await (const { id, name } of asyncRows) {
51
- console.log(`User id=${await id()}, name=${await name()}`)
50
+ for await (const { cells } of asyncRows) {
51
+ console.log(`User id=${await cells.id()}, name=${await cells.name()}`)
52
52
  }
53
53
  ```
54
54
 
@@ -95,7 +95,7 @@ interface AsyncDataSource {
95
95
  }
96
96
 
97
97
  interface ScanOptions {
98
- columns?: string[]
98
+ columns?: string[] // columns to scan (undefined means all)
99
99
  where?: ExprNode
100
100
  limit?: number
101
101
  offset?: number
@@ -128,11 +128,12 @@ const customSource: AsyncDataSource = {
128
128
 
129
129
  Squirreling mostly follows the SQL standard. The following features are supported:
130
130
 
131
- - `SELECT` statements with `WHERE`, `ORDER BY`, `LIMIT`, `OFFSET`
131
+ - `SELECT` statements with `DISTINCT`, `WHERE`, `ORDER BY`, `LIMIT`, `OFFSET`
132
132
  - `WITH` clause for Common Table Expressions (CTEs)
133
133
  - Subqueries in `SELECT`, `FROM`, and `WHERE` clauses
134
- - `JOIN` operations: `INNER JOIN`, `LEFT JOIN`, `RIGHT JOIN`, `FULL JOIN`, `POSITIONAL JOIN`
134
+ - `JOIN` operations: `INNER JOIN`, `LEFT JOIN`, `RIGHT JOIN`, `FULL JOIN`, `CROSS JOIN`, `POSITIONAL JOIN`
135
135
  - `GROUP BY` and `HAVING` clauses
136
+ - Expressions: `CASE`, `CAST`, `BETWEEN`, `IN`, `LIKE`, `IS NULL`, `IS NOT NULL`
136
137
 
137
138
  ### Quoting
138
139
 
@@ -142,7 +143,7 @@ Squirreling mostly follows the SQL standard. The following features are supporte
142
143
 
143
144
  ### Functions
144
145
 
145
- - Aggregate: `COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, `JSON_ARRAYAGG`
146
+ - Aggregate: `COUNT`, `SUM`, `AVG`, `MIN`, `MAX`, `STDDEV_POP`, `STDDEV_SAMP`, `JSON_ARRAYAGG`
146
147
  - String: `CONCAT`, `SUBSTRING`, `REPLACE`, `LENGTH`, `UPPER`, `LOWER`, `TRIM`, `LEFT`, `RIGHT`, `INSTR`
147
148
  - Math: `ABS`, `SIGN`, `CEIL`, `FLOOR`, `ROUND`, `MOD`, `RAND`, `RANDOM`, `LN`, `LOG10`, `EXP`, `POWER`, `SQRT`
148
149
  - Trig: `SIN`, `COS`, `TAN`, `COT`, `ASIN`, `ACOS`, `ATAN`, `ATAN2`, `DEGREES`, `RADIANS`, `PI`
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "squirreling",
3
- "version": "0.9.0",
3
+ "version": "0.9.2",
4
4
  "description": "Squirreling Async SQL Engine",
5
5
  "author": "Hyperparam",
6
6
  "homepage": "https://hyperparam.app",
@@ -37,10 +37,10 @@
37
37
  "test": "vitest run"
38
38
  },
39
39
  "devDependencies": {
40
- "@types/node": "25.2.3",
40
+ "@types/node": "25.3.0",
41
41
  "@vitest/coverage-v8": "4.0.18",
42
42
  "eslint": "9.39.2",
43
- "eslint-plugin-jsdoc": "62.5.5",
43
+ "eslint-plugin-jsdoc": "62.6.1",
44
44
  "typescript": "5.9.3",
45
45
  "vitest": "4.0.18"
46
46
  }
@@ -25,6 +25,7 @@ function asyncRow(obj) {
25
25
  */
26
26
  export function memorySource(data) {
27
27
  return {
28
+ numRows: data.length,
28
29
  scan({ where, limit, offset, signal }) {
29
30
  // Only apply offset and limit if no where clause
30
31
  const start = !where ? offset ?? 0 : 0
@@ -1,6 +1,7 @@
1
+ import { derivedAlias } from '../expression/alias.js'
1
2
  import { evaluateExpr } from '../expression/evaluate.js'
2
- import { defaultDerivedAlias, stringify } from './utils.js'
3
3
  import { executePlan } from './execute.js'
4
+ import { stringify } from './utils.js'
4
5
 
5
6
  /**
6
7
  * @import { AsyncCells, AsyncRow, ExecuteContext, SelectColumn } from '../types.js'
@@ -31,7 +32,7 @@ function projectAggregateColumns(selectColumns, group, context) {
31
32
  }
32
33
  }
33
34
  } else if (col.kind === 'derived') {
34
- const alias = col.alias ?? defaultDerivedAlias(col.expr)
35
+ const alias = col.alias ?? derivedAlias(col.expr)
35
36
  columns.push(alias)
36
37
  cells[alias] = () => evaluateExpr({
37
38
  node: col.expr,
@@ -1,16 +1,17 @@
1
1
  import { memorySource } from '../backend/dataSource.js'
2
2
  import { tableNotFoundError } from '../executionErrors.js'
3
+ import { derivedAlias } from '../expression/alias.js'
3
4
  import { evaluateExpr } from '../expression/evaluate.js'
4
5
  import { parseSql } from '../parse/parse.js'
5
6
  import { planSql } from '../plan/plan.js'
6
7
  import { executeHashAggregate, executeScalarAggregate } from './aggregates.js'
7
8
  import { executeHashJoin, executeNestedLoopJoin, executePositionalJoin } from './join.js'
8
9
  import { executeSort } from './sort.js'
9
- import { defaultDerivedAlias, stableRowKey } from './utils.js'
10
+ import { stableRowKey } from './utils.js'
10
11
 
11
12
  /**
12
13
  * @import { AsyncCells, AsyncDataSource, AsyncRow, ExecuteContext, ExecuteSqlOptions, ExprNode, SelectStatement } from '../types.js'
13
- * @import { DistinctNode, FilterNode, LimitNode, ProjectNode, QueryPlan, ScanNode } from '../plan/types.js'
14
+ * @import { CountNode, DistinctNode, FilterNode, LimitNode, ProjectNode, QueryPlan, ScanNode } from '../plan/types.js'
14
15
  */
15
16
 
16
17
  /**
@@ -60,6 +61,8 @@ export async function* executeSelect({ select, context }) {
60
61
  export async function* executePlan({ plan, context }) {
61
62
  if (plan.type === 'Scan') {
62
63
  yield* executeScan(plan, context)
64
+ } else if (plan.type === 'Count') {
65
+ yield* executeCount(plan, context)
63
66
  } else if (plan.type === 'Filter') {
64
67
  yield* executeFilter(plan, context)
65
68
  } else if (plan.type === 'Project') {
@@ -104,15 +107,15 @@ async function* executeScan(plan, context) {
104
107
  const { rows, appliedWhere, appliedLimitOffset } = scanResult
105
108
 
106
109
  // Applied limit/offset without applied where is invalid
107
- const hasLimitOffset = plan.hints?.limit !== undefined || plan.hints?.offset // 0 offset is noop
108
- if (!appliedWhere && appliedLimitOffset && plan.hints?.where && hasLimitOffset) {
110
+ const hasLimitOffset = plan.hints.limit !== undefined || plan.hints.offset // 0 offset is noop
111
+ if (!appliedWhere && appliedLimitOffset && plan.hints.where && hasLimitOffset) {
109
112
  throw new Error(`Data source "${plan.table}" applied limit/offset without applying where`)
110
113
  }
111
114
 
112
115
  let result = rows
113
116
 
114
117
  // Apply WHERE if data source did not
115
- if (!appliedWhere && plan.hints?.where) {
118
+ if (!appliedWhere && plan.hints.where) {
116
119
  result = filterRows(result, plan.hints.where, context)
117
120
  }
118
121
 
@@ -124,6 +127,44 @@ async function* executeScan(plan, context) {
124
127
  yield* result
125
128
  }
126
129
 
130
+ /**
131
+ * Executes a Count node using numRows when available, falling back to scan
132
+ *
133
+ * @param {CountNode} plan
134
+ * @param {ExecuteContext} context
135
+ * @yields {AsyncRow}
136
+ */
137
+ async function* executeCount(plan, { tables, signal }) {
138
+ const dataSource = tables[plan.table]
139
+ if (dataSource === undefined) {
140
+ throw tableNotFoundError({ tableName: plan.table })
141
+ }
142
+
143
+ // Use source numRows if available
144
+ let count = dataSource.numRows
145
+ if (dataSource.numRows === undefined) {
146
+ // Fall back to counting rows via scan
147
+ count = 0
148
+ const { rows } = dataSource.scan({ signal })
149
+ // eslint-disable-next-line no-unused-vars
150
+ for await (const _ of rows) {
151
+ if (signal?.aborted) return
152
+ count++
153
+ }
154
+ }
155
+
156
+ /** @type {string[]} */
157
+ const columns = []
158
+ /** @type {AsyncCells} */
159
+ const cells = {}
160
+ for (const col of plan.columns) {
161
+ const alias = col.alias ?? derivedAlias(col.expr)
162
+ columns.push(alias)
163
+ cells[alias] = () => Promise.resolve(count)
164
+ }
165
+ yield { columns, cells }
166
+ }
167
+
127
168
  /**
128
169
  * Filters rows by a condition
129
170
  *
@@ -207,7 +248,7 @@ async function* executeProject(plan, context) {
207
248
  cells[key] = row.cells[key]
208
249
  }
209
250
  } else if (col.kind === 'derived') {
210
- const alias = col.alias ?? defaultDerivedAlias(col.expr)
251
+ const alias = col.alias ?? derivedAlias(col.expr)
211
252
  columns.push(alias)
212
253
  cells[alias] = () => evaluateExpr({
213
254
  node: col.expr,
@@ -1,5 +1,5 @@
1
1
  /**
2
- * @import { AsyncCells, AsyncRow, ExprNode, OrderByItem, SqlPrimitive } from '../types.js'
2
+ * @import { AsyncCells, AsyncRow, OrderByItem, SqlPrimitive } from '../types.js'
3
3
  */
4
4
 
5
5
  /**
@@ -57,45 +57,6 @@ export async function collect(asyncRows) {
57
57
  return results
58
58
  }
59
59
 
60
- /**
61
- * Generates a default alias for a derived column expression
62
- *
63
- * @param {ExprNode} expr - the expression node
64
- * @returns {string} the generated alias
65
- */
66
- export function defaultDerivedAlias(expr) {
67
- if (expr.type === 'identifier') {
68
- // For qualified names like 'users.name', use just the column part as alias
69
- if (expr.name.includes('.')) {
70
- return expr.name.split('.').pop()
71
- }
72
- return expr.name
73
- }
74
- if (expr.type === 'literal') {
75
- return String(expr.value)
76
- }
77
- if (expr.type === 'cast') {
78
- return defaultDerivedAlias(expr.expr) + '_as_' + expr.toType
79
- }
80
- if (expr.type === 'unary') {
81
- return expr.op + '_' + defaultDerivedAlias(expr.argument)
82
- }
83
- if (expr.type === 'binary') {
84
- return defaultDerivedAlias(expr.left) + '_' + expr.op + '_' + defaultDerivedAlias(expr.right)
85
- }
86
- if (expr.type === 'function') {
87
- // Handle aggregate functions with star (COUNT(*) -> count_all)
88
- if (expr.args.length === 1 && expr.args[0].type === 'identifier' && expr.args[0].name === '*') {
89
- return expr.name.toLowerCase() + '_all'
90
- }
91
- return expr.name.toLowerCase() + '_' + expr.args.map(defaultDerivedAlias).join('_')
92
- }
93
- if (expr.type === 'interval') {
94
- return `interval_${expr.value}_${expr.unit.toLowerCase()}`
95
- }
96
- return 'expr'
97
- }
98
-
99
60
  /**
100
61
  * @param {SqlPrimitive} value
101
62
  * @returns {string}
@@ -0,0 +1,42 @@
1
+ /**
2
+ * @import { ExprNode } from '../types.js'
3
+ */
4
+
5
+ /**
6
+ * Generates a default alias for a derived column expression
7
+ *
8
+ * @param {ExprNode} expr - the expression node
9
+ * @returns {string} the generated alias
10
+ */
11
+ export function derivedAlias(expr) {
12
+ if (expr.type === 'identifier') {
13
+ // For qualified names like 'users.name', use just the column part as alias
14
+ if (expr.name.includes('.')) {
15
+ return expr.name.split('.').pop()
16
+ }
17
+ return expr.name
18
+ }
19
+ if (expr.type === 'literal') {
20
+ return String(expr.value)
21
+ }
22
+ if (expr.type === 'cast') {
23
+ return derivedAlias(expr.expr) + '_as_' + expr.toType
24
+ }
25
+ if (expr.type === 'unary') {
26
+ return expr.op + '_' + derivedAlias(expr.argument)
27
+ }
28
+ if (expr.type === 'binary') {
29
+ return derivedAlias(expr.left) + '_' + expr.op + '_' + derivedAlias(expr.right)
30
+ }
31
+ if (expr.type === 'function') {
32
+ // Handle aggregate functions with star (COUNT(*) -> count_all)
33
+ if (expr.args.length === 1 && expr.args[0].type === 'star') {
34
+ return expr.name.toLowerCase() + '_all'
35
+ }
36
+ return expr.name.toLowerCase() + '_' + expr.args.map(derivedAlias).join('_')
37
+ }
38
+ if (expr.type === 'interval') {
39
+ return `interval_${expr.value}_${expr.unit.toLowerCase()}`
40
+ }
41
+ return 'expr'
42
+ }
@@ -1,9 +1,10 @@
1
1
  import { executeSelect } from '../execute/execute.js'
2
- import { defaultDerivedAlias, stringify } from '../execute/utils.js'
2
+ import { stringify } from '../execute/utils.js'
3
3
  import { columnNotFoundError, invalidContextError } from '../executionErrors.js'
4
4
  import { unknownFunctionError } from '../parseErrors.js'
5
5
  import { isAggregateFunc, isMathFunc, isRegexpFunc, isStringFunc } from '../validation.js'
6
6
  import { aggregateError, argValueError, castError } from '../validationErrors.js'
7
+ import { derivedAlias } from './alias.js'
7
8
  import { applyBinaryOp } from './binary.js'
8
9
  import { applyIntervalToDate } from './date.js'
9
10
  import { evaluateMathFunc } from './math.js'
@@ -104,9 +105,9 @@ export async function evaluateExpr({ node, row, rowIndex, rows, context }) {
104
105
  if (!rows) {
105
106
  // Aggregate function used outside of aggregate context
106
107
  // This is only allowed if same aggregate was in the SELECT list
107
- const derivedAlias = defaultDerivedAlias(node)
108
- if (row.columns.includes(derivedAlias)) {
109
- return row.cells[derivedAlias]()
108
+ const alias = derivedAlias(node)
109
+ if (row.columns.includes(alias)) {
110
+ return row.cells[alias]()
110
111
  } else {
111
112
  throw aggregateError({
112
113
  funcName,
@@ -126,14 +127,13 @@ export async function evaluateExpr({ node, row, rowIndex, rows, context }) {
126
127
  }
127
128
  }
128
129
 
129
- // Handle COUNT(*) special case
130
- if (node.args.length === 1 && node.args[0].type === 'identifier' && funcName === 'COUNT' && node.args[0].name === '*') {
131
- return filteredRows.length
132
- }
133
-
134
130
  const argNode = node.args[0]
135
-
136
131
  if (funcName === 'COUNT') {
132
+ // COUNT(*) special case
133
+ if (argNode.type === 'star') {
134
+ return filteredRows.length
135
+ }
136
+
137
137
  if (node.distinct) {
138
138
  const seen = new Set()
139
139
  for (const row of filteredRows) {
@@ -10,8 +10,8 @@ import { argValueError } from '../validationErrors.js'
10
10
  * @param {Object} options
11
11
  * @param {string} options.funcName - Uppercase function name
12
12
  * @param {SqlPrimitive[]} options.args - Function arguments
13
- * @param {number} [options.positionStart] - Start position in SQL string for error reporting
14
- * @param {number} [options.positionEnd] - End position in SQL string for error reporting
13
+ * @param {number} options.positionStart - Start position in SQL string for error reporting
14
+ * @param {number} options.positionEnd - End position in SQL string for error reporting
15
15
  * @param {number} [options.rowIndex] - Row number for error reporting
16
16
  * @returns {SqlPrimitive}
17
17
  */
@@ -10,8 +10,8 @@ import { argValueError } from '../validationErrors.js'
10
10
  * @param {Object} options
11
11
  * @param {StringFunc} options.funcName - Uppercase function name
12
12
  * @param {SqlPrimitive[]} options.args - Function arguments
13
- * @param {number} [options.positionStart] - Start position for error reporting
14
- * @param {number} [options.positionEnd] - End position for error reporting
13
+ * @param {number} options.positionStart - Start position for error reporting
14
+ * @param {number} options.positionEnd - End position for error reporting
15
15
  * @param {number} [options.rowIndex] - Row index for error reporting
16
16
  * @returns {SqlPrimitive}
17
17
  */
package/src/index.d.ts CHANGED
@@ -1,4 +1,4 @@
1
- import type { AsyncDataSource, AsyncRow, ExecuteContext, ExecuteSqlOptions, ParseSqlOptions, PlanSqlOptions, QueryPlan, SelectStatement, SqlPrimitive, Token } from './types.js'
1
+ import type { AsyncDataSource, AsyncRow, ExecuteContext, ExecuteSqlOptions, ExprNode, ParseSqlOptions, PlanSqlOptions, QueryPlan, SelectStatement, SqlPrimitive, Token } from './types.js'
2
2
  export type {
3
3
  AsyncCells,
4
4
  AsyncDataSource,
@@ -76,3 +76,12 @@ export function tokenizeSql(sql: string): Token[]
76
76
  export function collect<T>(asyncGen: AsyncGenerator<AsyncRow>): Promise<Record<string, SqlPrimitive>[]>
77
77
 
78
78
  export function cachedDataSource(source: AsyncDataSource): AsyncDataSource
79
+
80
+ /**
81
+ * Generates a default alias for a derived column expression.
82
+ * Useful for generating column names pre-execution.
83
+ *
84
+ * @param expr - the expression node
85
+ * @returns the generated alias
86
+ */
87
+ export function derivedAlias(expr: ExprNode): string
package/src/index.js CHANGED
@@ -4,3 +4,4 @@ export { planSql } from './plan/plan.js'
4
4
  export { tokenizeSql } from './parse/tokenize.js'
5
5
  export { collect } from './execute/utils.js'
6
6
  export { cachedDataSource } from './backend/dataSource.js'
7
+ export { derivedAlias } from './expression/alias.js'
@@ -194,6 +194,8 @@ export function parsePrimary(state) {
194
194
  throw missingClauseError({
195
195
  missing: 'at least one WHEN clause',
196
196
  context: 'CASE expression',
197
+ positionStart,
198
+ positionEnd: state.lastPos,
197
199
  })
198
200
  }
199
201
 
@@ -38,8 +38,7 @@ export function parseFunctionCall(state, funcName, positionStart) {
38
38
  const starTok = current(state)
39
39
  consume(state)
40
40
  args.push({
41
- type: 'identifier',
42
- name: '*',
41
+ type: 'star',
43
42
  positionStart: starTok.positionStart,
44
43
  positionEnd: state.lastPos,
45
44
  })
@@ -74,7 +73,7 @@ export function parseFunctionCall(state, funcName, positionStart) {
74
73
 
75
74
  // Validate star argument at parse time (only COUNT supports *)
76
75
  const funcNameUpper = funcName.toUpperCase()
77
- const hasStar = args.length === 1 && args[0].type === 'identifier' && args[0].name === '*'
76
+ const hasStar = args.length === 1 && args[0].type === 'star'
78
77
  if (hasStar && isAggregateFunc(funcNameUpper) && funcNameUpper !== 'COUNT') {
79
78
  throw new ParseError({
80
79
  message: `${funcName} cannot be applied to "*"`,
@@ -139,8 +139,8 @@ export function argCountParseError({ funcName, expected, received, positionStart
139
139
  * @param {Object} options
140
140
  * @param {string} options.missing - What is missing (e.g., 'WHEN clause', 'FROM clause', 'ON condition')
141
141
  * @param {string} options.context - Where it's missing from (e.g., 'CASE expression', 'SELECT statement', 'JOIN')
142
- * @param {number} [options.positionStart] - Start position in query
143
- * @param {number} [options.positionEnd] - End position in query
142
+ * @param {number} options.positionStart - Start position in query
143
+ * @param {number} options.positionEnd - End position in query
144
144
  * @returns {ParseError}
145
145
  */
146
146
  export function missingClauseError({ missing, context, positionStart, positionEnd }) {
@@ -10,11 +10,13 @@
10
10
  * @returns {Map<string, string[] | undefined>}
11
11
  */
12
12
  export function extractColumns(select) {
13
+ /** @type {Map<string, string[] | undefined>} */
14
+ const result = new Map()
15
+
13
16
  // Build alias list from FROM + JOINs
14
17
  const fromAlias = select.from.kind === 'table'
15
18
  ? select.from.alias ?? select.from.table
16
19
  : select.from.alias
17
- /** @type {string[]} */
18
20
  const aliases = [fromAlias]
19
21
  for (const join of select.joins) {
20
22
  aliases.push(join.alias ?? join.table)
@@ -30,20 +32,18 @@ export function extractColumns(select) {
30
32
  return result
31
33
  }
32
34
 
33
- // Track which tables need all columns (SELECT table.*)
34
- /** @type {Set<string>} */
35
- const allColumnsNeeded = new Set()
36
- for (const col of select.columns) {
37
- if (col.kind === 'star' && col.table) {
38
- allColumnsNeeded.add(col.table)
39
- }
40
- }
35
+ // Track per-table columns needed; undefined means all columns (table.*)
36
+ /** @type {Map<string, Set<string> | undefined>} */
37
+ const perTable = new Map(aliases.map(alias => [alias, new Set()]))
41
38
 
42
39
  // Collect all identifiers from all clauses
43
40
  /** @type {Set<string>} */
44
41
  const identifiers = new Set()
45
42
  for (const col of select.columns) {
46
- if (col.kind === 'derived') {
43
+ if (col.kind === 'star' && col.table) {
44
+ // SELECT table.* means all columns needed
45
+ perTable.set(col.table, undefined)
46
+ } else if (col.kind === 'derived') {
47
47
  collectColumnsFromExpr(col.expr, identifiers)
48
48
  }
49
49
  }
@@ -59,15 +59,6 @@ export function extractColumns(select) {
59
59
  collectColumnsFromExpr(join.on, identifiers)
60
60
  }
61
61
 
62
- // Initialize per-table sets (skip tables needing all columns)
63
- /** @type {Map<string, Set<string>>} */
64
- const perTable = new Map()
65
- for (const alias of aliases) {
66
- if (!allColumnsNeeded.has(alias)) {
67
- perTable.set(alias, new Set())
68
- }
69
- }
70
-
71
62
  // Partition identifiers by table prefix
72
63
  for (const name of identifiers) {
73
64
  const dotIndex = name.indexOf('.')
@@ -76,27 +67,19 @@ export function extractColumns(select) {
76
67
  const tablePrefix = name.substring(0, dotIndex)
77
68
  const columnName = name.substring(dotIndex + 1)
78
69
  const set = perTable.get(tablePrefix)
79
- if (set) {
80
- set.add(columnName)
81
- }
70
+ if (set) set.add(columnName)
82
71
  } else {
83
72
  // Unqualified: add to all tables (ambiguous)
84
73
  for (const [, set] of perTable) {
85
- set.add(name)
74
+ if (set) set.add(name)
86
75
  }
87
76
  }
88
77
  }
89
78
 
90
79
  // Build result map: convert Sets to arrays, undefined for all-columns tables
91
- /** @type {Map<string, string[] | undefined>} */
92
- const result = new Map()
93
80
  for (const alias of aliases) {
94
- if (allColumnsNeeded.has(alias)) {
95
- result.set(alias, undefined)
96
- } else {
97
- const set = perTable.get(alias)
98
- result.set(alias, set ? [...set] : undefined)
99
- }
81
+ const set = perTable.get(alias)
82
+ result.set(alias, set ? [...set] : undefined)
100
83
  }
101
84
  return result
102
85
  }
@@ -104,15 +87,13 @@ export function extractColumns(select) {
104
87
  /**
105
88
  * Recursively collects column names (identifiers) from an expression
106
89
  *
107
- * @param {ExprNode | undefined} expr
90
+ * @param {ExprNode} expr
108
91
  * @param {Set<string>} columns
109
92
  */
110
93
  function collectColumnsFromExpr(expr, columns) {
111
94
  if (!expr) return
112
- if (expr.type === 'identifier' && expr.name !== '*') {
95
+ if (expr.type === 'identifier') {
113
96
  columns.add(expr.name)
114
- } else if (expr.type === 'literal') {
115
- // No columns
116
97
  } else if (expr.type === 'binary') {
117
98
  collectColumnsFromExpr(expr.left, columns)
118
99
  collectColumnsFromExpr(expr.right, columns)
@@ -122,6 +103,7 @@ function collectColumnsFromExpr(expr, columns) {
122
103
  for (const arg of expr.args) {
123
104
  collectColumnsFromExpr(arg, columns)
124
105
  }
106
+ collectColumnsFromExpr(expr.filter, columns)
125
107
  } else if (expr.type === 'cast') {
126
108
  collectColumnsFromExpr(expr.expr, columns)
127
109
  } else if (expr.type === 'in valuelist') {
@@ -131,9 +113,6 @@ function collectColumnsFromExpr(expr, columns) {
131
113
  }
132
114
  } else if (expr.type === 'in') {
133
115
  collectColumnsFromExpr(expr.expr, columns)
134
- // Subquery columns are from a different scope, don't collect
135
- } else if (expr.type === 'exists' || expr.type === 'not exists') {
136
- // Subquery columns are from a different scope, don't collect
137
116
  } else if (expr.type === 'case') {
138
117
  if (expr.caseExpr) {
139
118
  collectColumnsFromExpr(expr.caseExpr, columns)
@@ -146,4 +125,5 @@ function collectColumnsFromExpr(expr, columns) {
146
125
  collectColumnsFromExpr(expr.elseResult, columns)
147
126
  }
148
127
  }
128
+ // No columns: count(*), literal, interval, exists, not exists, subquery
149
129
  }
package/src/plan/plan.js CHANGED
@@ -3,7 +3,7 @@ import { findAggregate } from '../validation.js'
3
3
  import { extractColumns } from './columns.js'
4
4
 
5
5
  /**
6
- * @import { ExprNode, JoinClause, PlanSqlOptions, ScanOptions, SelectStatement } from '../types.js'
6
+ * @import { ExprNode, DerivedColumn, JoinClause, PlanSqlOptions, ScanOptions, SelectColumn, SelectStatement } from '../types.js'
7
7
  * @import { QueryPlan } from './types.d.ts'
8
8
  */
9
9
 
@@ -51,11 +51,19 @@ function planSelect({ select, ctePlans }) {
51
51
  ? select.from.alias ?? select.from.table
52
52
  : select.from.alias
53
53
 
54
- // Determine per-table column hints for pushdown
54
+ // Determine scan hints for direct table scans (WHERE and LIMIT/OFFSET are
55
+ // included so they are only applied to fresh scans, not CTE/subquery plans)
55
56
  /** @type {ScanOptions} */
56
57
  const hints = {}
57
58
  const perTableColumns = extractColumns(select)
58
59
  hints.columns = perTableColumns.get(sourceAlias)
60
+ if (!select.joins.length) {
61
+ hints.where = select.where
62
+ if (!needsBuffering && !select.distinct) {
63
+ hints.limit = select.limit
64
+ hints.offset = select.offset
65
+ }
66
+ }
59
67
 
60
68
  // Start with the data source (FROM clause)
61
69
  /** @type {QueryPlan} */
@@ -66,27 +74,24 @@ function planSelect({ select, ctePlans }) {
66
74
  plan = planJoin({ left: plan, joins: select.joins, leftTable: sourceAlias, ctePlans, perTableColumns })
67
75
  }
68
76
 
69
- // Delegate WHERE and LIMIT/OFFSET to scan when plan is a direct table scan
70
- if (plan.type === 'Scan') {
71
- plan.hints.where = select.where
72
- if (!needsBuffering && !select.distinct) {
73
- plan.hints.limit = select.limit
74
- plan.hints.offset = select.offset
75
- }
76
- }
77
+ // Whether FROM resolved to our own direct table scan
78
+ const isOwnScan = plan.type === 'Scan' && plan.hints === hints
77
79
 
78
- // Add WHERE filter when scan can't handle it (JOINs, subqueries, CTEs)
79
- const isScan = plan.type === 'Scan'
80
- if (select.where && !isScan) {
80
+ // Add WHERE filter when the scan didn't receive it
81
+ if (select.where && !isOwnScan) {
81
82
  plan = { type: 'Filter', condition: select.where, child: plan }
82
83
  }
83
84
 
84
85
  if (useGrouping) {
85
86
  // Aggregation path: GROUP BY or scalar aggregate
86
87
  // HAVING is integrated into aggregate nodes for access to group context
87
- plan = select.groupBy.length
88
- ? { type: 'HashAggregate', groupBy: select.groupBy, columns: select.columns, having: select.having, child: plan }
89
- : { type: 'ScalarAggregate', columns: select.columns, having: select.having, child: plan }
88
+ if (select.groupBy.length) {
89
+ plan = { type: 'HashAggregate', groupBy: select.groupBy, columns: select.columns, having: select.having, child: plan }
90
+ } else if (!select.having && !select.where && plan.type === 'Scan' && isOwnScan && isAllCountStar(select.columns)) {
91
+ plan = { type: 'Count', table: plan.table, columns: select.columns }
92
+ } else {
93
+ plan = { type: 'ScalarAggregate', columns: select.columns, having: select.having, child: plan }
94
+ }
90
95
 
91
96
  // ORDER BY (after aggregation)
92
97
  if (select.orderBy.length) {
@@ -99,7 +104,7 @@ function planSelect({ select, ctePlans }) {
99
104
  }
100
105
 
101
106
  // LIMIT/OFFSET
102
- if (select.limit !== undefined || select.offset !== undefined) {
107
+ if (select.limit !== undefined || select.offset) {
103
108
  plan = { type: 'Limit', limit: select.limit, offset: select.offset, child: plan }
104
109
  }
105
110
  } else {
@@ -124,13 +129,18 @@ function planSelect({ select, ctePlans }) {
124
129
  // DISTINCT needs to come after projection but before LIMIT
125
130
  // However, for streaming distinct we need to project first
126
131
  // So the order is: Sort -> Project -> Distinct -> Limit
127
- plan = { type: 'Project', columns: select.columns, child: plan }
132
+
133
+ // Fast path for SELECT *
134
+ const isPassthrough = select.columns.length === 1 && select.columns[0].kind === 'star'
135
+ if (!isPassthrough) {
136
+ plan = { type: 'Project', columns: select.columns, child: plan }
137
+ }
128
138
 
129
139
  if (select.distinct) {
130
140
  plan = { type: 'Distinct', child: plan }
131
141
  }
132
142
 
133
- if (!(isScan && !needsBuffering && !select.distinct) && (select.limit !== undefined || select.offset !== undefined)) {
143
+ if (!(isOwnScan && !needsBuffering && !select.distinct) && (select.limit !== undefined || select.offset)) {
134
144
  plan = { type: 'Limit', limit: select.limit, offset: select.offset, child: plan }
135
145
  }
136
146
  }
@@ -302,3 +312,22 @@ function extractSimpleJoinKeys({ condition, leftTable, rightTable }) {
302
312
 
303
313
  return { leftKey: left, rightKey: right }
304
314
  }
315
+
316
+ /**
317
+ * Checks if every SELECT column is a plain COUNT(*).
318
+ *
319
+ * @param {SelectColumn[]} columns
320
+ * @returns {columns is DerivedColumn[]}
321
+ */
322
+ function isAllCountStar(columns) {
323
+ if (columns.length === 0) return false
324
+ return columns.every(col =>
325
+ col.kind === 'derived' &&
326
+ col.expr.type === 'function' &&
327
+ col.expr.name.toUpperCase() === 'COUNT' &&
328
+ col.expr.args.length === 1 &&
329
+ col.expr.args[0].type === 'star' &&
330
+ !col.expr.distinct &&
331
+ !col.expr.filter
332
+ )
333
+ }
@@ -1,7 +1,8 @@
1
- import { ExprNode, JoinType, OrderByItem, ScanOptions, SelectColumn } from '../types.js'
1
+ import { DerivedColumn, ExprNode, JoinType, OrderByItem, ScanOptions, SelectColumn } from '../types.js'
2
2
 
3
3
  export type QueryPlan =
4
4
  | ScanNode
5
+ | CountNode
5
6
  | FilterNode
6
7
  | ProjectNode
7
8
  | SortNode
@@ -20,6 +21,13 @@ export interface ScanNode {
20
21
  hints: ScanOptions
21
22
  }
22
23
 
24
+ // Count node for COUNT(*) optimization
25
+ export interface CountNode {
26
+ type: 'Count'
27
+ table: string
28
+ columns: DerivedColumn[]
29
+ }
30
+
23
31
  // Single-child nodes
24
32
  export interface FilterNode {
25
33
  type: 'Filter'
package/src/types.d.ts CHANGED
@@ -42,6 +42,7 @@ export type Row = Record<string, SqlPrimitive>[]
42
42
  * Async data source for streaming SQL execution.
43
43
  */
44
44
  export interface AsyncDataSource {
45
+ numRows?: number
45
46
  scan(options: ScanOptions): ScanResults
46
47
  }
47
48
 
@@ -214,6 +215,10 @@ export interface IntervalNode extends ExprNodeBase {
214
215
  unit: IntervalUnit
215
216
  }
216
217
 
218
+ export interface StarNode extends ExprNodeBase {
219
+ type: 'star'
220
+ }
221
+
217
222
  export type ExprNode =
218
223
  | LiteralNode
219
224
  | IdentifierNode
@@ -227,12 +232,7 @@ export type ExprNode =
227
232
  | CaseNode
228
233
  | SubqueryNode
229
234
  | IntervalNode
230
-
231
- export interface StarColumn {
232
- kind: 'star'
233
- table?: string
234
- alias?: string
235
- }
235
+ | StarNode
236
236
 
237
237
  export type AggregateFunc = 'COUNT' | 'SUM' | 'AVG' | 'MIN' | 'MAX' | 'JSON_ARRAYAGG' | 'STDDEV_SAMP' | 'STDDEV_POP'
238
238
 
@@ -276,6 +276,11 @@ export type StringFunc =
276
276
  | 'RIGHT'
277
277
  | 'INSTR'
278
278
 
279
+ export interface StarColumn {
280
+ kind: 'star'
281
+ table?: string
282
+ }
283
+
279
284
  export interface DerivedColumn {
280
285
  kind: 'derived'
281
286
  expr: ExprNode
@@ -99,7 +99,7 @@ export function argValueError({ funcName, message, positionStart, positionEnd, h
99
99
  */
100
100
  export function aggregateError({ funcName, positionStart, positionEnd }) {
101
101
  return new ExecutionError({
102
- message: `Aggregate function ${funcName} must exist in a GROUP BY clause or be part of an aggregate SELECT list`,
102
+ message: `Aggregate function ${funcName} is not available in this context`,
103
103
  positionStart,
104
104
  positionEnd,
105
105
  })