squirreling 0.9.1 → 0.9.3

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -10,8 +10,8 @@ import { argValueError } from '../validationErrors.js'
10
10
  * @param {Object} options
11
11
  * @param {StringFunc} options.funcName - Uppercase function name
12
12
  * @param {SqlPrimitive[]} options.args - Function arguments
13
- * @param {number} [options.positionStart] - Start position for error reporting
14
- * @param {number} [options.positionEnd] - End position for error reporting
13
+ * @param {number} options.positionStart - Start position for error reporting
14
+ * @param {number} options.positionEnd - End position for error reporting
15
15
  * @param {number} [options.rowIndex] - Row index for error reporting
16
16
  * @returns {SqlPrimitive}
17
17
  */
@@ -194,6 +194,8 @@ export function parsePrimary(state) {
194
194
  throw missingClauseError({
195
195
  missing: 'at least one WHEN clause',
196
196
  context: 'CASE expression',
197
+ positionStart,
198
+ positionEnd: state.lastPos,
197
199
  })
198
200
  }
199
201
 
@@ -38,8 +38,7 @@ export function parseFunctionCall(state, funcName, positionStart) {
38
38
  const starTok = current(state)
39
39
  consume(state)
40
40
  args.push({
41
- type: 'identifier',
42
- name: '*',
41
+ type: 'star',
43
42
  positionStart: starTok.positionStart,
44
43
  positionEnd: state.lastPos,
45
44
  })
@@ -74,7 +73,7 @@ export function parseFunctionCall(state, funcName, positionStart) {
74
73
 
75
74
  // Validate star argument at parse time (only COUNT supports *)
76
75
  const funcNameUpper = funcName.toUpperCase()
77
- const hasStar = args.length === 1 && args[0].type === 'identifier' && args[0].name === '*'
76
+ const hasStar = args.length === 1 && args[0].type === 'star'
78
77
  if (hasStar && isAggregateFunc(funcNameUpper) && funcNameUpper !== 'COUNT') {
79
78
  throw new ParseError({
80
79
  message: `${funcName} cannot be applied to "*"`,
@@ -139,8 +139,8 @@ export function argCountParseError({ funcName, expected, received, positionStart
139
139
  * @param {Object} options
140
140
  * @param {string} options.missing - What is missing (e.g., 'WHEN clause', 'FROM clause', 'ON condition')
141
141
  * @param {string} options.context - Where it's missing from (e.g., 'CASE expression', 'SELECT statement', 'JOIN')
142
- * @param {number} [options.positionStart] - Start position in query
143
- * @param {number} [options.positionEnd] - End position in query
142
+ * @param {number} options.positionStart - Start position in query
143
+ * @param {number} options.positionEnd - End position in query
144
144
  * @returns {ParseError}
145
145
  */
146
146
  export function missingClauseError({ missing, context, positionStart, positionEnd }) {
@@ -10,11 +10,13 @@
10
10
  * @returns {Map<string, string[] | undefined>}
11
11
  */
12
12
  export function extractColumns(select) {
13
+ /** @type {Map<string, string[] | undefined>} */
14
+ const result = new Map()
15
+
13
16
  // Build alias list from FROM + JOINs
14
17
  const fromAlias = select.from.kind === 'table'
15
18
  ? select.from.alias ?? select.from.table
16
19
  : select.from.alias
17
- /** @type {string[]} */
18
20
  const aliases = [fromAlias]
19
21
  for (const join of select.joins) {
20
22
  aliases.push(join.alias ?? join.table)
@@ -30,20 +32,18 @@ export function extractColumns(select) {
30
32
  return result
31
33
  }
32
34
 
33
- // Track which tables need all columns (SELECT table.*)
34
- /** @type {Set<string>} */
35
- const allColumnsNeeded = new Set()
36
- for (const col of select.columns) {
37
- if (col.kind === 'star' && col.table) {
38
- allColumnsNeeded.add(col.table)
39
- }
40
- }
35
+ // Track per-table columns needed; undefined means all columns (table.*)
36
+ /** @type {Map<string, Set<string> | undefined>} */
37
+ const perTable = new Map(aliases.map(alias => [alias, new Set()]))
41
38
 
42
39
  // Collect all identifiers from all clauses
43
40
  /** @type {Set<string>} */
44
41
  const identifiers = new Set()
45
42
  for (const col of select.columns) {
46
- if (col.kind === 'derived') {
43
+ if (col.kind === 'star' && col.table) {
44
+ // SELECT table.* means all columns needed
45
+ perTable.set(col.table, undefined)
46
+ } else if (col.kind === 'derived') {
47
47
  collectColumnsFromExpr(col.expr, identifiers)
48
48
  }
49
49
  }
@@ -59,15 +59,6 @@ export function extractColumns(select) {
59
59
  collectColumnsFromExpr(join.on, identifiers)
60
60
  }
61
61
 
62
- // Initialize per-table sets (skip tables needing all columns)
63
- /** @type {Map<string, Set<string>>} */
64
- const perTable = new Map()
65
- for (const alias of aliases) {
66
- if (!allColumnsNeeded.has(alias)) {
67
- perTable.set(alias, new Set())
68
- }
69
- }
70
-
71
62
  // Partition identifiers by table prefix
72
63
  for (const name of identifiers) {
73
64
  const dotIndex = name.indexOf('.')
@@ -76,27 +67,19 @@ export function extractColumns(select) {
76
67
  const tablePrefix = name.substring(0, dotIndex)
77
68
  const columnName = name.substring(dotIndex + 1)
78
69
  const set = perTable.get(tablePrefix)
79
- if (set) {
80
- set.add(columnName)
81
- }
70
+ if (set) set.add(columnName)
82
71
  } else {
83
72
  // Unqualified: add to all tables (ambiguous)
84
73
  for (const [, set] of perTable) {
85
- set.add(name)
74
+ if (set) set.add(name)
86
75
  }
87
76
  }
88
77
  }
89
78
 
90
79
  // Build result map: convert Sets to arrays, undefined for all-columns tables
91
- /** @type {Map<string, string[] | undefined>} */
92
- const result = new Map()
93
80
  for (const alias of aliases) {
94
- if (allColumnsNeeded.has(alias)) {
95
- result.set(alias, undefined)
96
- } else {
97
- const set = perTable.get(alias)
98
- result.set(alias, set ? [...set] : undefined)
99
- }
81
+ const set = perTable.get(alias)
82
+ result.set(alias, set ? [...set] : undefined)
100
83
  }
101
84
  return result
102
85
  }
@@ -104,15 +87,13 @@ export function extractColumns(select) {
104
87
  /**
105
88
  * Recursively collects column names (identifiers) from an expression
106
89
  *
107
- * @param {ExprNode | undefined} expr
90
+ * @param {ExprNode} expr
108
91
  * @param {Set<string>} columns
109
92
  */
110
93
  function collectColumnsFromExpr(expr, columns) {
111
94
  if (!expr) return
112
- if (expr.type === 'identifier' && expr.name !== '*') {
95
+ if (expr.type === 'identifier') {
113
96
  columns.add(expr.name)
114
- } else if (expr.type === 'literal') {
115
- // No columns
116
97
  } else if (expr.type === 'binary') {
117
98
  collectColumnsFromExpr(expr.left, columns)
118
99
  collectColumnsFromExpr(expr.right, columns)
@@ -122,6 +103,7 @@ function collectColumnsFromExpr(expr, columns) {
122
103
  for (const arg of expr.args) {
123
104
  collectColumnsFromExpr(arg, columns)
124
105
  }
106
+ collectColumnsFromExpr(expr.filter, columns)
125
107
  } else if (expr.type === 'cast') {
126
108
  collectColumnsFromExpr(expr.expr, columns)
127
109
  } else if (expr.type === 'in valuelist') {
@@ -131,9 +113,6 @@ function collectColumnsFromExpr(expr, columns) {
131
113
  }
132
114
  } else if (expr.type === 'in') {
133
115
  collectColumnsFromExpr(expr.expr, columns)
134
- // Subquery columns are from a different scope, don't collect
135
- } else if (expr.type === 'exists' || expr.type === 'not exists') {
136
- // Subquery columns are from a different scope, don't collect
137
116
  } else if (expr.type === 'case') {
138
117
  if (expr.caseExpr) {
139
118
  collectColumnsFromExpr(expr.caseExpr, columns)
@@ -146,4 +125,5 @@ function collectColumnsFromExpr(expr, columns) {
146
125
  collectColumnsFromExpr(expr.elseResult, columns)
147
126
  }
148
127
  }
128
+ // No columns: count(*), literal, interval, exists, not exists, subquery
149
129
  }
package/src/plan/plan.js CHANGED
@@ -3,7 +3,7 @@ import { findAggregate } from '../validation.js'
3
3
  import { extractColumns } from './columns.js'
4
4
 
5
5
  /**
6
- * @import { ExprNode, JoinClause, PlanSqlOptions, ScanOptions, SelectStatement } from '../types.js'
6
+ * @import { ExprNode, DerivedColumn, JoinClause, PlanSqlOptions, ScanOptions, SelectColumn, SelectStatement } from '../types.js'
7
7
  * @import { QueryPlan } from './types.d.ts'
8
8
  */
9
9
 
@@ -51,11 +51,19 @@ function planSelect({ select, ctePlans }) {
51
51
  ? select.from.alias ?? select.from.table
52
52
  : select.from.alias
53
53
 
54
- // Determine per-table column hints for pushdown
54
+ // Determine scan hints for direct table scans (WHERE and LIMIT/OFFSET are
55
+ // included so they are only applied to fresh scans, not CTE/subquery plans)
55
56
  /** @type {ScanOptions} */
56
57
  const hints = {}
57
58
  const perTableColumns = extractColumns(select)
58
59
  hints.columns = perTableColumns.get(sourceAlias)
60
+ if (!select.joins.length) {
61
+ hints.where = select.where
62
+ if (!needsBuffering && !select.distinct) {
63
+ hints.limit = select.limit
64
+ hints.offset = select.offset
65
+ }
66
+ }
59
67
 
60
68
  // Start with the data source (FROM clause)
61
69
  /** @type {QueryPlan} */
@@ -66,27 +74,24 @@ function planSelect({ select, ctePlans }) {
66
74
  plan = planJoin({ left: plan, joins: select.joins, leftTable: sourceAlias, ctePlans, perTableColumns })
67
75
  }
68
76
 
69
- // Delegate WHERE and LIMIT/OFFSET to scan when plan is a direct table scan
70
- if (plan.type === 'Scan') {
71
- plan.hints.where = select.where
72
- if (!needsBuffering && !select.distinct) {
73
- plan.hints.limit = select.limit
74
- plan.hints.offset = select.offset
75
- }
76
- }
77
+ // Whether FROM resolved to our own direct table scan
78
+ const isOwnScan = plan.type === 'Scan' && plan.hints === hints
77
79
 
78
- // Add WHERE filter when scan can't handle it (JOINs, subqueries, CTEs)
79
- const isScan = plan.type === 'Scan'
80
- if (select.where && !isScan) {
80
+ // Add WHERE filter when the scan didn't receive it
81
+ if (select.where && !isOwnScan) {
81
82
  plan = { type: 'Filter', condition: select.where, child: plan }
82
83
  }
83
84
 
84
85
  if (useGrouping) {
85
86
  // Aggregation path: GROUP BY or scalar aggregate
86
87
  // HAVING is integrated into aggregate nodes for access to group context
87
- plan = select.groupBy.length
88
- ? { type: 'HashAggregate', groupBy: select.groupBy, columns: select.columns, having: select.having, child: plan }
89
- : { type: 'ScalarAggregate', columns: select.columns, having: select.having, child: plan }
88
+ if (select.groupBy.length) {
89
+ plan = { type: 'HashAggregate', groupBy: select.groupBy, columns: select.columns, having: select.having, child: plan }
90
+ } else if (!select.having && !select.where && plan.type === 'Scan' && isOwnScan && isAllCountStar(select.columns)) {
91
+ plan = { type: 'Count', table: plan.table, columns: select.columns }
92
+ } else {
93
+ plan = { type: 'ScalarAggregate', columns: select.columns, having: select.having, child: plan }
94
+ }
90
95
 
91
96
  // ORDER BY (after aggregation)
92
97
  if (select.orderBy.length) {
@@ -99,7 +104,7 @@ function planSelect({ select, ctePlans }) {
99
104
  }
100
105
 
101
106
  // LIMIT/OFFSET
102
- if (select.limit !== undefined || select.offset !== undefined) {
107
+ if (select.limit !== undefined || select.offset) {
103
108
  plan = { type: 'Limit', limit: select.limit, offset: select.offset, child: plan }
104
109
  }
105
110
  } else {
@@ -124,13 +129,18 @@ function planSelect({ select, ctePlans }) {
124
129
  // DISTINCT needs to come after projection but before LIMIT
125
130
  // However, for streaming distinct we need to project first
126
131
  // So the order is: Sort -> Project -> Distinct -> Limit
127
- plan = { type: 'Project', columns: select.columns, child: plan }
132
+
133
+ // Fast path for SELECT *
134
+ const isPassthrough = select.columns.length === 1 && select.columns[0].kind === 'star'
135
+ if (!isPassthrough) {
136
+ plan = { type: 'Project', columns: select.columns, child: plan }
137
+ }
128
138
 
129
139
  if (select.distinct) {
130
140
  plan = { type: 'Distinct', child: plan }
131
141
  }
132
142
 
133
- if (!(isScan && !needsBuffering && !select.distinct) && (select.limit !== undefined || select.offset !== undefined)) {
143
+ if (!(isOwnScan && !needsBuffering && !select.distinct) && (select.limit !== undefined || select.offset)) {
134
144
  plan = { type: 'Limit', limit: select.limit, offset: select.offset, child: plan }
135
145
  }
136
146
  }
@@ -302,3 +312,22 @@ function extractSimpleJoinKeys({ condition, leftTable, rightTable }) {
302
312
 
303
313
  return { leftKey: left, rightKey: right }
304
314
  }
315
+
316
+ /**
317
+ * Checks if every SELECT column is a plain COUNT(*).
318
+ *
319
+ * @param {SelectColumn[]} columns
320
+ * @returns {columns is DerivedColumn[]}
321
+ */
322
+ function isAllCountStar(columns) {
323
+ if (columns.length === 0) return false
324
+ return columns.every(col =>
325
+ col.kind === 'derived' &&
326
+ col.expr.type === 'function' &&
327
+ col.expr.name.toUpperCase() === 'COUNT' &&
328
+ col.expr.args.length === 1 &&
329
+ col.expr.args[0].type === 'star' &&
330
+ !col.expr.distinct &&
331
+ !col.expr.filter
332
+ )
333
+ }
@@ -1,7 +1,8 @@
1
- import { ExprNode, JoinType, OrderByItem, ScanOptions, SelectColumn } from '../types.js'
1
+ import { DerivedColumn, ExprNode, JoinType, OrderByItem, ScanOptions, SelectColumn } from '../types.js'
2
2
 
3
3
  export type QueryPlan =
4
4
  | ScanNode
5
+ | CountNode
5
6
  | FilterNode
6
7
  | ProjectNode
7
8
  | SortNode
@@ -20,6 +21,13 @@ export interface ScanNode {
20
21
  hints: ScanOptions
21
22
  }
22
23
 
24
+ // Count node for COUNT(*) optimization
25
+ export interface CountNode {
26
+ type: 'Count'
27
+ table: string
28
+ columns: DerivedColumn[]
29
+ }
30
+
23
31
  // Single-child nodes
24
32
  export interface FilterNode {
25
33
  type: 'Filter'
package/src/types.d.ts CHANGED
@@ -42,6 +42,7 @@ export type Row = Record<string, SqlPrimitive>[]
42
42
  * Async data source for streaming SQL execution.
43
43
  */
44
44
  export interface AsyncDataSource {
45
+ numRows?: number
45
46
  scan(options: ScanOptions): ScanResults
46
47
  }
47
48
 
@@ -214,6 +215,10 @@ export interface IntervalNode extends ExprNodeBase {
214
215
  unit: IntervalUnit
215
216
  }
216
217
 
218
+ export interface StarNode extends ExprNodeBase {
219
+ type: 'star'
220
+ }
221
+
217
222
  export type ExprNode =
218
223
  | LiteralNode
219
224
  | IdentifierNode
@@ -227,6 +232,7 @@ export type ExprNode =
227
232
  | CaseNode
228
233
  | SubqueryNode
229
234
  | IntervalNode
235
+ | StarNode
230
236
 
231
237
  export type AggregateFunc = 'COUNT' | 'SUM' | 'AVG' | 'MIN' | 'MAX' | 'JSON_ARRAYAGG' | 'STDDEV_SAMP' | 'STDDEV_POP'
232
238
 
@@ -270,6 +276,22 @@ export type StringFunc =
270
276
  | 'RIGHT'
271
277
  | 'INSTR'
272
278
 
279
+ export type SpatialFunc =
280
+ | 'ST_INTERSECTS'
281
+ | 'ST_CONTAINS'
282
+ | 'ST_CONTAINSPROPERLY'
283
+ | 'ST_WITHIN'
284
+ | 'ST_OVERLAPS'
285
+ | 'ST_TOUCHES'
286
+ | 'ST_EQUALS'
287
+ | 'ST_CROSSES'
288
+ | 'ST_COVERS'
289
+ | 'ST_COVEREDBY'
290
+ | 'ST_DWITHIN'
291
+ | 'ST_GEOMFROMTEXT'
292
+ | 'ST_MAKEENVELOPE'
293
+ | 'ST_ASTEXT'
294
+
273
295
  export interface StarColumn {
274
296
  kind: 'star'
275
297
  table?: string
package/src/validation.js CHANGED
@@ -1,7 +1,7 @@
1
1
  import { ParseError } from './parseErrors.js'
2
2
 
3
3
  /**
4
- * @import { AggregateFunc, BinaryOp, ExprNode, FunctionNode, IntervalUnit, MathFunc, StringFunc, UserDefinedFunction } from './types.js'
4
+ * @import { AggregateFunc, BinaryOp, ExprNode, FunctionNode, IntervalUnit, MathFunc, SpatialFunc, StringFunc, UserDefinedFunction } from './types.js'
5
5
  * @param {string} name
6
6
  * @returns {name is AggregateFunc}
7
7
  */
@@ -79,6 +79,19 @@ export function isRegexpFunc(name) {
79
79
  return ['REGEXP_SUBSTR', 'REGEXP_REPLACE'].includes(name)
80
80
  }
81
81
 
82
+ /**
83
+ * @param {string} name
84
+ * @returns {name is SpatialFunc}
85
+ */
86
+ export function isSpatialFunc(name) {
87
+ return [
88
+ 'ST_INTERSECTS', 'ST_CONTAINS', 'ST_CONTAINSPROPERLY', 'ST_WITHIN',
89
+ 'ST_OVERLAPS', 'ST_TOUCHES', 'ST_EQUALS', 'ST_CROSSES',
90
+ 'ST_COVERS', 'ST_COVEREDBY', 'ST_DWITHIN',
91
+ 'ST_GEOMFROMTEXT', 'ST_MAKEENVELOPE', 'ST_ASTEXT',
92
+ ].includes(name)
93
+ }
94
+
82
95
  /**
83
96
  * @param {string} name
84
97
  * @returns {name is MathFunc}
@@ -178,6 +191,12 @@ export const FUNCTION_ARG_COUNTS = {
178
191
  JSON_OBJECT: { min: 0 },
179
192
  JSON_ARRAYAGG: { min: 1, max: 1 },
180
193
 
194
+ // Array functions
195
+ ARRAY_LENGTH: { min: 1, max: 1 },
196
+ ARRAY_POSITION: { min: 2, max: 2 },
197
+ ARRAY_SORT: { min: 1, max: 1 },
198
+ CARDINALITY: { min: 1, max: 1 },
199
+
181
200
  // Conditional functions
182
201
  COALESCE: { min: 1 },
183
202
  NULLIF: { min: 2, max: 2 },
@@ -190,6 +209,22 @@ export const FUNCTION_ARG_COUNTS = {
190
209
  MAX: { min: 1, max: 1 },
191
210
  STDDEV_SAMP: { min: 1, max: 1 },
192
211
  STDDEV_POP: { min: 1, max: 1 },
212
+
213
+ // Spatial predicate functions
214
+ ST_INTERSECTS: { min: 2, max: 2 },
215
+ ST_CONTAINS: { min: 2, max: 2 },
216
+ ST_CONTAINSPROPERLY: { min: 2, max: 2 },
217
+ ST_WITHIN: { min: 2, max: 2 },
218
+ ST_OVERLAPS: { min: 2, max: 2 },
219
+ ST_TOUCHES: { min: 2, max: 2 },
220
+ ST_EQUALS: { min: 2, max: 2 },
221
+ ST_CROSSES: { min: 2, max: 2 },
222
+ ST_COVERS: { min: 2, max: 2 },
223
+ ST_COVEREDBY: { min: 2, max: 2 },
224
+ ST_DWITHIN: { min: 3, max: 3 },
225
+ ST_GEOMFROMTEXT: { min: 1, max: 1 },
226
+ ST_MAKEENVELOPE: { min: 4, max: 4 },
227
+ ST_ASTEXT: { min: 1, max: 1 },
193
228
  }
194
229
 
195
230
  /**
@@ -249,7 +284,8 @@ export function isKnownFunction(funcName, functions) {
249
284
  isAggregateFunc(funcName) ||
250
285
  isMathFunc(funcName) ||
251
286
  isStringFunc(funcName) ||
252
- isRegexpFunc(funcName)
287
+ isRegexpFunc(funcName) ||
288
+ isSpatialFunc(funcName)
253
289
  ) {
254
290
  return true
255
291
  }
@@ -258,6 +294,7 @@ export function isKnownFunction(funcName, functions) {
258
294
  if ([
259
295
  'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
260
296
  'JSON_VALUE', 'JSON_QUERY', 'JSON_OBJECT',
297
+ 'ARRAY_LENGTH', 'ARRAY_POSITION', 'ARRAY_SORT', 'CARDINALITY',
261
298
  'COALESCE', 'NULLIF', 'CAST',
262
299
  ].includes(funcName)) {
263
300
  return true
@@ -68,6 +68,22 @@ export const FUNCTION_SIGNATURES = {
68
68
  MAX: 'expression',
69
69
  STDDEV_SAMP: 'expression',
70
70
  STDDEV_POP: 'expression',
71
+
72
+ // Spatial predicate functions
73
+ ST_INTERSECTS: 'geometry, geometry',
74
+ ST_CONTAINS: 'geometry, geometry',
75
+ ST_CONTAINSPROPERLY: 'geometry, geometry',
76
+ ST_WITHIN: 'geometry, geometry',
77
+ ST_OVERLAPS: 'geometry, geometry',
78
+ ST_TOUCHES: 'geometry, geometry',
79
+ ST_EQUALS: 'geometry, geometry',
80
+ ST_CROSSES: 'geometry, geometry',
81
+ ST_COVERS: 'geometry, geometry',
82
+ ST_COVEREDBY: 'geometry, geometry',
83
+ ST_DWITHIN: 'geometry, geometry, distance',
84
+ ST_GEOMFROMTEXT: 'wkt',
85
+ ST_MAKEENVELOPE: 'xmin, ymin, xmax, ymax',
86
+ ST_ASTEXT: 'geometry',
71
87
  }
72
88
 
73
89
  /**
@@ -99,7 +115,7 @@ export function argValueError({ funcName, message, positionStart, positionEnd, h
99
115
  */
100
116
  export function aggregateError({ funcName, positionStart, positionEnd }) {
101
117
  return new ExecutionError({
102
- message: `Aggregate function ${funcName} must exist in a GROUP BY clause or be part of an aggregate SELECT list`,
118
+ message: `Aggregate function ${funcName} is not available in this context`,
103
119
  positionStart,
104
120
  positionEnd,
105
121
  })