squirreling 0.6.1 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,16 @@
1
1
  import { unknownFunctionError } from '../parseErrors.js'
2
2
  import { invalidContextError } from '../executionErrors.js'
3
- import {
4
- aggregateError,
5
- argCountError,
6
- argValueError,
7
- castError,
8
- } from '../validationErrors.js'
9
- import { isAggregateFunc, isMathFunc } from '../validation.js'
3
+ import { aggregateError, argValueError, castError } from '../validationErrors.js'
4
+ import { isAggregateFunc, isMathFunc, isRegexpFunc, isStringFunc } from '../validation.js'
10
5
  import { applyIntervalToDate } from './date.js'
11
6
  import { executeSelect } from './execute.js'
12
7
  import { evaluateMathFunc } from './math.js'
8
+ import { evaluateRegexpFunc } from './regexp.js'
9
+ import { evaluateStringFunc } from './strings.js'
13
10
  import { applyBinaryOp, stringify } from './utils.js'
14
11
 
15
12
  /**
16
- * @import { ExprNode, AsyncRow, SqlPrimitive, AsyncDataSource, IntervalUnit } from '../types.js'
13
+ * @import { ExprNode, AsyncRow, SqlPrimitive, AsyncDataSource, UserDefinedFunction } from '../types.js'
17
14
  */
18
15
 
19
16
  /**
@@ -23,11 +20,12 @@ import { applyBinaryOp, stringify } from './utils.js'
23
20
  * @param {ExprNode} params.node - The expression node to evaluate
24
21
  * @param {AsyncRow} params.row - The data row to evaluate against
25
22
  * @param {Record<string, AsyncDataSource>} params.tables
23
+ * @param {Record<string, UserDefinedFunction>} [params.functions] - User-defined functions
26
24
  * @param {number} [params.rowIndex] - 1-based row index for error reporting
27
25
  * @param {AsyncRow[]} [params.rows] - Group of rows for aggregate functions
28
26
  * @returns {Promise<SqlPrimitive>} The result of the evaluation
29
27
  */
30
- export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
28
+ export async function evaluateExpr({ node, row, tables, functions, rowIndex, rows }) {
31
29
  if (node.type === 'literal') {
32
30
  return node.value
33
31
  }
@@ -59,16 +57,16 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
59
57
  // Unary operators
60
58
  if (node.type === 'unary') {
61
59
  if (node.op === 'NOT') {
62
- return !await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows })
60
+ return !await evaluateExpr({ node: node.argument, row, tables, functions, rowIndex, rows })
63
61
  }
64
62
  if (node.op === 'IS NULL') {
65
- return await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows }) == null
63
+ return await evaluateExpr({ node: node.argument, row, tables, functions, rowIndex, rows }) == null
66
64
  }
67
65
  if (node.op === 'IS NOT NULL') {
68
- return await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows }) != null
66
+ return await evaluateExpr({ node: node.argument, row, tables, functions, rowIndex, rows }) != null
69
67
  }
70
68
  if (node.op === '-') {
71
- const val = await evaluateExpr({ node: node.argument, row, tables, rowIndex, rows })
69
+ const val = await evaluateExpr({ node: node.argument, row, tables, functions, rowIndex, rows })
72
70
  if (val == null) return null
73
71
  return -val
74
72
  }
@@ -78,15 +76,15 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
78
76
  if (node.type === 'binary') {
79
77
  // Handle date +/- interval at AST level
80
78
  if ((node.op === '+' || node.op === '-') && node.right.type === 'interval') {
81
- const dateVal = await evaluateExpr({ node: node.left, row, tables, rowIndex, rows })
79
+ const dateVal = await evaluateExpr({ node: node.left, row, tables, functions, rowIndex, rows })
82
80
  return applyIntervalToDate(dateVal, node.right.value, node.right.unit, node.op)
83
81
  }
84
82
  if (node.op === '+' && node.left.type === 'interval') {
85
- const dateVal = await evaluateExpr({ node: node.right, row, tables, rowIndex, rows })
83
+ const dateVal = await evaluateExpr({ node: node.right, row, tables, functions, rowIndex, rows })
86
84
  return applyIntervalToDate(dateVal, node.left.value, node.left.unit, '+')
87
85
  }
88
86
 
89
- const left = await evaluateExpr({ node: node.left, row, tables, rowIndex, rows })
87
+ const left = await evaluateExpr({ node: node.left, row, tables, functions, rowIndex, rows })
90
88
 
91
89
  // Short-circuit evaluation for AND and OR
92
90
  if (node.op === 'AND') {
@@ -96,7 +94,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
96
94
  if (left) return true
97
95
  }
98
96
 
99
- const right = await evaluateExpr({ node: node.right, row, tables, rowIndex, rows })
97
+ const right = await evaluateExpr({ node: node.right, row, tables, functions, rowIndex, rows })
100
98
  return applyBinaryOp(node.op, left, right)
101
99
  }
102
100
 
@@ -124,31 +122,20 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
124
122
  })
125
123
  }
126
124
 
127
- if (node.args.length !== 1) {
128
- throw argCountError({
129
- funcName,
130
- expected: 1,
131
- received: node.args.length,
132
- positionStart: node.positionStart,
133
- positionEnd: node.positionEnd,
134
- rowNumber: rowIndex,
135
- })
136
- }
137
-
138
125
  const argNode = node.args[0]
139
126
 
140
127
  if (funcName === 'COUNT') {
141
128
  if (node.distinct) {
142
129
  const seen = new Set()
143
130
  for (const r of rows) {
144
- const v = await evaluateExpr({ node: argNode, row: r, tables })
131
+ const v = await evaluateExpr({ node: argNode, row: r, tables, functions })
145
132
  if (v != null) seen.add(v)
146
133
  }
147
134
  return seen.size
148
135
  }
149
136
  let count = 0
150
137
  for (const r of rows) {
151
- const v = await evaluateExpr({ node: argNode, row: r, tables })
138
+ const v = await evaluateExpr({ node: argNode, row: r, tables, functions })
152
139
  if (v != null) count++
153
140
  }
154
141
  return count
@@ -163,7 +150,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
163
150
  let max = null
164
151
 
165
152
  for (const r of rows) {
166
- const raw = await evaluateExpr({ node: argNode, row: r, tables })
153
+ const raw = await evaluateExpr({ node: argNode, row: r, tables, functions })
167
154
  if (raw == null) continue
168
155
  const num = Number(raw)
169
156
  if (!Number.isFinite(num)) continue
@@ -191,7 +178,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
191
178
  if (node.distinct) {
192
179
  const seen = new Set()
193
180
  for (const r of rows) {
194
- const v = await evaluateExpr({ node: argNode, row: r, tables })
181
+ const v = await evaluateExpr({ node: argNode, row: r, tables, functions })
195
182
  const key = stringify(v)
196
183
  if (!seen.has(key)) {
197
184
  seen.add(key)
@@ -200,7 +187,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
200
187
  }
201
188
  } else {
202
189
  for (const r of rows) {
203
- const v = await evaluateExpr({ node: argNode, row: r, tables })
190
+ const v = await evaluateExpr({ node: argNode, row: r, tables, functions })
204
191
  values.push(v)
205
192
  }
206
193
  }
@@ -209,222 +196,54 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
209
196
  }
210
197
 
211
198
  /** @type {SqlPrimitive[]} */
212
- const args = await Promise.all(node.args.map(arg => evaluateExpr({ node: arg, row, tables, rowIndex, rows })))
213
-
214
- if (funcName === 'UPPER') {
215
- if (args.length !== 1) {
216
- throw argCountError({
217
- funcName: 'UPPER',
218
- expected: 1,
219
- received: args.length,
220
- positionStart: node.positionStart,
221
- positionEnd: node.positionEnd,
222
- rowNumber: rowIndex,
223
- })
224
- }
225
- const val = args[0]
226
- if (val == null) return null
227
- return String(val).toUpperCase()
228
- }
229
-
230
- if (funcName === 'LOWER') {
231
- if (args.length !== 1) {
232
- throw argCountError({
233
- funcName: 'LOWER',
234
- expected: 1,
235
- received: args.length,
236
- positionStart: node.positionStart,
237
- positionEnd: node.positionEnd,
238
- rowNumber: rowIndex,
239
- })
240
- }
241
- const val = args[0]
242
- if (val == null) return null
243
- return String(val).toLowerCase()
244
- }
245
-
246
- if (funcName === 'CONCAT') {
247
- if (args.length < 1) {
248
- throw argCountError({
249
- funcName: 'CONCAT',
250
- expected: 'at least 1',
251
- received: args.length,
252
- positionStart: node.positionStart,
253
- positionEnd: node.positionEnd,
254
- rowNumber: rowIndex,
255
- })
256
- }
257
- // SQL CONCAT returns NULL if any argument is NULL
258
- if (args.some(a => a == null)) return null
259
- if (args.some(a => typeof a === 'object')) {
260
- throw argValueError({
261
- funcName: 'CONCAT',
262
- message: 'does not support object arguments',
263
- positionStart: node.positionStart,
264
- positionEnd: node.positionEnd,
265
- hint: 'Use CAST to convert objects to strings first.',
266
- rowNumber: rowIndex,
267
- })
268
- }
269
- return args.map(a => String(a)).join('')
270
- }
271
-
272
- if (funcName === 'LENGTH') {
273
- if (args.length !== 1) {
274
- throw argCountError({
275
- funcName: 'LENGTH',
276
- expected: 1,
277
- received: args.length,
278
- positionStart: node.positionStart,
279
- positionEnd: node.positionEnd,
280
- rowNumber: rowIndex,
281
- })
282
- }
283
- const val = args[0]
284
- if (val == null) return null
285
- return String(val).length
286
- }
199
+ const args = await Promise.all(node.args.map(arg => evaluateExpr({ node: arg, row, tables, functions, rowIndex, rows })))
287
200
 
288
- if (funcName === 'SUBSTRING' || funcName === 'SUBSTR') {
289
- if (args.length < 2 || args.length > 3) {
290
- throw argCountError({
291
- funcName,
292
- expected: '2 or 3',
293
- received: args.length,
294
- positionStart: node.positionStart,
295
- positionEnd: node.positionEnd,
296
- rowNumber: rowIndex,
297
- })
298
- }
299
- const str = args[0]
300
- if (str == null) return null
301
- const strVal = String(str)
302
- const start = Number(args[1])
303
- if (!Number.isInteger(start) || start < 1) {
304
- throw argValueError({
305
- funcName,
306
- message: `start position must be a positive integer, got ${args[1]}`,
307
- positionStart: node.positionStart,
308
- positionEnd: node.positionEnd,
309
- hint: 'SQL uses 1-based indexing.',
310
- rowNumber: rowIndex,
311
- })
312
- }
313
- // SQL uses 1-based indexing
314
- const startIdx = start - 1
315
- if (args.length === 3) {
316
- const len = Number(args[2])
317
- if (!Number.isInteger(len) || len < 0) {
318
- throw argValueError({
319
- funcName,
320
- message: `length must be a non-negative integer, got ${args[2]}`,
321
- positionStart: node.positionStart,
322
- positionEnd: node.positionEnd,
323
- rowNumber: rowIndex,
324
- })
325
- }
326
- return strVal.substring(startIdx, startIdx + len)
327
- }
328
- return strVal.substring(startIdx)
329
- }
330
-
331
- if (funcName === 'TRIM') {
332
- if (args.length !== 1) {
333
- throw argCountError({
334
- funcName: 'TRIM',
335
- expected: 1,
336
- received: args.length,
337
- positionStart: node.positionStart,
338
- positionEnd: node.positionEnd,
339
- rowNumber: rowIndex,
340
- })
341
- }
342
- const val = args[0]
343
- if (val == null) return null
344
- return String(val).trim()
201
+ if (isStringFunc(funcName)) {
202
+ return evaluateStringFunc({
203
+ funcName,
204
+ args,
205
+ positionStart: node.positionStart,
206
+ positionEnd: node.positionEnd,
207
+ rowIndex,
208
+ })
345
209
  }
346
210
 
347
- if (funcName === 'REPLACE') {
348
- if (args.length !== 3) {
349
- throw argCountError({
350
- funcName: 'REPLACE',
351
- expected: 3,
352
- received: args.length,
353
- positionStart: node.positionStart,
354
- positionEnd: node.positionEnd,
355
- rowNumber: rowIndex,
356
- })
357
- }
358
- const str = args[0]
359
- const searchStr = args[1]
360
- const replaceStr = args[2]
361
- // SQL REPLACE returns NULL if any argument is NULL
362
- if (str == null || searchStr == null || replaceStr == null) return null
363
- return String(str).replaceAll(String(searchStr), String(replaceStr))
211
+ if (isRegexpFunc(funcName)) {
212
+ return evaluateRegexpFunc({
213
+ funcName,
214
+ args,
215
+ positionStart: node.positionStart,
216
+ positionEnd: node.positionEnd,
217
+ rowIndex,
218
+ })
364
219
  }
365
220
 
366
- if (funcName === 'RANDOM' || funcName === 'RAND') {
367
- if (args.length !== 0) {
368
- throw argCountError({
369
- funcName,
370
- expected: 0,
371
- received: args.length,
372
- positionStart: node.positionStart,
373
- positionEnd: node.positionEnd,
374
- rowNumber: rowIndex,
375
- })
221
+ if (funcName === 'COALESCE') {
222
+ // Short-circuit: evaluate args one at a time, return first non-null
223
+ for (const arg of node.args) {
224
+ const val = await evaluateExpr({ node: arg, row, tables, functions, rowIndex, rows })
225
+ if (val != null) return val
376
226
  }
377
- return Math.random()
227
+ return null
378
228
  }
379
229
 
380
230
  if (funcName === 'CURRENT_DATE') {
381
- if (args.length !== 0) {
382
- throw argCountError({
383
- funcName: 'CURRENT_DATE',
384
- expected: 0,
385
- received: args.length,
386
- positionStart: node.positionStart,
387
- positionEnd: node.positionEnd,
388
- rowNumber: rowIndex,
389
- })
390
- }
391
231
  return new Date().toISOString().split('T')[0]
392
232
  }
393
233
 
394
234
  if (funcName === 'CURRENT_TIME') {
395
- if (args.length !== 0) {
396
- throw argCountError({
397
- funcName: 'CURRENT_TIME',
398
- expected: 0,
399
- received: args.length,
400
- positionStart: node.positionStart,
401
- positionEnd: node.positionEnd,
402
- rowNumber: rowIndex,
403
- })
404
- }
405
235
  return new Date().toISOString().split('T')[1].replace('Z', '')
406
236
  }
407
237
 
408
238
  if (funcName === 'CURRENT_TIMESTAMP') {
409
- if (args.length !== 0) {
410
- throw argCountError({
411
- funcName: 'CURRENT_TIMESTAMP',
412
- expected: 0,
413
- received: args.length,
414
- positionStart: node.positionStart,
415
- positionEnd: node.positionEnd,
416
- rowNumber: rowIndex,
417
- })
418
- }
419
239
  return new Date().toISOString()
420
240
  }
421
241
 
422
242
  if (funcName === 'JSON_OBJECT') {
423
243
  if (args.length % 2 !== 0) {
424
- throw argCountError({
244
+ throw argValueError({
425
245
  funcName: 'JSON_OBJECT',
426
- expected: 'even number',
427
- received: args.length,
246
+ message: 'requires an even number of arguments (key-value pairs)',
428
247
  positionStart: node.positionStart,
429
248
  positionEnd: node.positionEnd,
430
249
  rowNumber: rowIndex,
@@ -451,16 +270,6 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
451
270
  }
452
271
 
453
272
  if (funcName === 'JSON_VALUE' || funcName === 'JSON_QUERY') {
454
- if (args.length !== 2) {
455
- throw argCountError({
456
- funcName,
457
- expected: 2,
458
- received: args.length,
459
- positionStart: node.positionStart,
460
- positionEnd: node.positionEnd,
461
- rowNumber: rowIndex,
462
- })
463
- }
464
273
  let jsonArg = args[0]
465
274
  const pathArg = args[1]
466
275
  if (jsonArg == null || pathArg == null) return null
@@ -517,13 +326,15 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
517
326
  }
518
327
 
519
328
  if (isMathFunc(funcName)) {
520
- return evaluateMathFunc({
521
- funcName,
522
- args,
523
- positionStart: node.positionStart,
524
- positionEnd: node.positionEnd,
525
- rowNumber: rowIndex,
526
- })
329
+ return evaluateMathFunc({ funcName, args })
330
+ }
331
+
332
+ // Check user-defined functions (case-insensitive lookup)
333
+ if (functions) {
334
+ const udfName = Object.keys(functions).find(k => k.toUpperCase() === funcName)
335
+ if (udfName) {
336
+ return await functions[udfName].apply(...args)
337
+ }
527
338
  }
528
339
 
529
340
  throw unknownFunctionError({
@@ -534,7 +345,7 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
534
345
  }
535
346
 
536
347
  if (node.type === 'cast') {
537
- const val = await evaluateExpr({ node: node.expr, row, tables, rowIndex, rows })
348
+ const val = await evaluateExpr({ node: node.expr, row, tables, functions, rowIndex, rows })
538
349
  if (val == null) return null
539
350
  const toType = node.toType.toUpperCase()
540
351
  if (toType === 'TEXT' || toType === 'STRING' || toType === 'VARCHAR') {
@@ -577,16 +388,16 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
577
388
 
578
389
  // IN and NOT IN with value lists
579
390
  if (node.type === 'in valuelist') {
580
- const exprVal = await evaluateExpr({ node: node.expr, row, tables, rowIndex, rows })
391
+ const exprVal = await evaluateExpr({ node: node.expr, row, tables, functions, rowIndex, rows })
581
392
  for (const valueNode of node.values) {
582
- const val = await evaluateExpr({ node: valueNode, row, tables, rowIndex, rows })
393
+ const val = await evaluateExpr({ node: valueNode, row, tables, functions, rowIndex, rows })
583
394
  if (exprVal === val) return true
584
395
  }
585
396
  return false
586
397
  }
587
398
  // IN with subqueries
588
399
  if (node.type === 'in') {
589
- const exprVal = await evaluateExpr({ node: node.expr, row, tables, rowIndex, rows })
400
+ const exprVal = await evaluateExpr({ node: node.expr, row, tables, functions, rowIndex, rows })
590
401
  const results = executeSelect({ select: node.subquery, tables })
591
402
  for await (const resRow of results) {
592
403
  const value = await resRow.cells[resRow.columns[0]]()
@@ -608,28 +419,28 @@ export async function evaluateExpr({ node, row, tables, rowIndex, rows }) {
608
419
  // CASE expressions
609
420
  if (node.type === 'case') {
610
421
  // For simple CASE: evaluate the case expression once
611
- const caseValue = node.caseExpr && await evaluateExpr({ node: node.caseExpr, row, tables, rowIndex, rows })
422
+ const caseValue = node.caseExpr && await evaluateExpr({ node: node.caseExpr, row, tables, functions, rowIndex, rows })
612
423
 
613
424
  // Iterate through WHEN clauses
614
425
  for (const whenClause of node.whenClauses) {
615
426
  let conditionResult
616
427
  if (caseValue !== undefined) {
617
428
  // Simple CASE: compare caseValue with condition
618
- const whenValue = await evaluateExpr({ node: whenClause.condition, row, tables, rowIndex, rows })
429
+ const whenValue = await evaluateExpr({ node: whenClause.condition, row, tables, functions, rowIndex, rows })
619
430
  conditionResult = caseValue === whenValue
620
431
  } else {
621
432
  // Searched CASE: evaluate condition as boolean
622
- conditionResult = await evaluateExpr({ node: whenClause.condition, row, tables, rowIndex, rows })
433
+ conditionResult = await evaluateExpr({ node: whenClause.condition, row, tables, functions, rowIndex, rows })
623
434
  }
624
435
 
625
436
  if (conditionResult) {
626
- return evaluateExpr({ node: whenClause.result, row, tables, rowIndex, rows })
437
+ return evaluateExpr({ node: whenClause.result, row, tables, functions, rowIndex, rows })
627
438
  }
628
439
  }
629
440
 
630
441
  // No WHEN clause matched, return ELSE result or NULL
631
442
  if (node.elseResult) {
632
- return evaluateExpr({ node: node.elseResult, row, tables, rowIndex, rows })
443
+ return evaluateExpr({ node: node.elseResult, row, tables, functions, rowIndex, rows })
633
444
  }
634
445
  return null
635
446
  }
@@ -4,19 +4,21 @@ import { evaluateExpr } from './expression.js'
4
4
  import { applyBinaryOp } from './utils.js'
5
5
 
6
6
  /**
7
- * @import { AggregateFunc, AsyncDataSource, ExprNode, AsyncRow, SqlPrimitive } from '../types.js'
7
+ * @import { AggregateFunc, AsyncDataSource, ExprNode, AsyncRow, SqlPrimitive, UserDefinedFunction } from '../types.js'
8
8
  */
9
9
 
10
10
  /**
11
11
  * Evaluates a HAVING expression with support for aggregate functions
12
12
  *
13
- * @param {ExprNode} expr - the HAVING expression
14
- * @param {AsyncRow} row - the aggregated result row
15
- * @param {AsyncRow[]} group - the group of rows for re-evaluating aggregates
16
- * @param {Record<string, AsyncDataSource>} tables
13
+ * @param {Object} options
14
+ * @param {ExprNode} options.expr - the HAVING expression
15
+ * @param {AsyncRow} options.row - the aggregated result row
16
+ * @param {AsyncRow[]} options.group - the group of rows for re-evaluating aggregates
17
+ * @param {Record<string, AsyncDataSource>} options.tables
18
+ * @param {Record<string, UserDefinedFunction>} [options.functions]
17
19
  * @returns {Promise<boolean>} whether the HAVING condition is satisfied
18
20
  */
19
- export async function evaluateHavingExpr(expr, row, group, tables) {
21
+ export async function evaluateHavingExpr({ expr, row, group, tables, functions }) {
20
22
  // Having context
21
23
  const context = { ...group[0] ?? {}, ...row }
22
24
 
@@ -26,12 +28,12 @@ export async function evaluateHavingExpr(expr, row, group, tables) {
26
28
  const funcName = expr.name.toUpperCase()
27
29
  if (isAggregateFunc(funcName)) {
28
30
  // Evaluate aggregate function on the group
29
- return Boolean(await evaluateAggregateFunction(funcName, expr.args, group, tables))
31
+ return Boolean(await evaluateAggregateFunction({ funcName, args: expr.args, group, tables, functions }))
30
32
  }
31
33
  }
32
34
 
33
35
  if (expr.type === 'binary') {
34
- const left = await evaluateHavingValue(expr.left, context, group, tables)
36
+ const left = await evaluateHavingValue({ expr: expr.left, context, group, tables, functions })
35
37
 
36
38
  // Short-circuit evaluation for AND and OR
37
39
  if (expr.op === 'AND') {
@@ -41,61 +43,65 @@ export async function evaluateHavingExpr(expr, row, group, tables) {
41
43
  if (left) return true
42
44
  }
43
45
 
44
- const right = await evaluateHavingValue(expr.right, context, group, tables)
46
+ const right = await evaluateHavingValue({ expr: expr.right, context, group, tables, functions })
45
47
  return Boolean(applyBinaryOp(expr.op, left, right))
46
48
  }
47
49
 
48
50
  if (expr.type === 'unary') {
49
51
  if (expr.op === 'NOT') {
50
- return !await evaluateHavingExpr(expr.argument, context, group, tables)
52
+ return !await evaluateHavingExpr({ expr: expr.argument, row: context, group, tables, functions })
51
53
  }
52
54
  if (expr.op === 'IS NULL') {
53
- return await evaluateHavingValue(expr.argument, context, group, tables) == null
55
+ return await evaluateHavingValue({ expr: expr.argument, context, group, tables, functions }) == null
54
56
  }
55
57
  if (expr.op === 'IS NOT NULL') {
56
- return await evaluateHavingValue(expr.argument, context, group, tables) != null
58
+ return await evaluateHavingValue({ expr: expr.argument, context, group, tables, functions }) != null
57
59
  }
58
60
  }
59
61
 
60
62
  // For other expression types, use the context row
61
- return Boolean(await evaluateExpr({ node: expr, row: context, tables }))
63
+ return Boolean(await evaluateExpr({ node: expr, row: context, tables, functions }))
62
64
  }
63
65
 
64
66
  /**
65
67
  * Evaluates a value in a HAVING expression
66
68
  *
67
- * @param {ExprNode} expr
68
- * @param {AsyncRow} context - the context row
69
- * @param {AsyncRow[]} group - the group of rows
70
- * @param {Record<string, AsyncDataSource>} tables
69
+ * @param {Object} options
70
+ * @param {ExprNode} options.expr
71
+ * @param {AsyncRow} options.context - the context row
72
+ * @param {AsyncRow[]} options.group - the group of rows
73
+ * @param {Record<string, AsyncDataSource>} options.tables
74
+ * @param {Record<string, UserDefinedFunction>} [options.functions]
71
75
  * @returns {Promise<SqlPrimitive>} the evaluated value
72
76
  */
73
- function evaluateHavingValue(expr, context, group, tables) {
77
+ function evaluateHavingValue({ expr, context, group, tables, functions }) {
74
78
  if (expr.type === 'function') {
75
79
  const funcName = expr.name.toUpperCase()
76
80
  if (isAggregateFunc(funcName)) {
77
- return evaluateAggregateFunction(funcName, expr.args, group, tables)
81
+ return evaluateAggregateFunction({ funcName, args: expr.args, group, tables, functions })
78
82
  }
79
83
  }
80
84
 
81
85
  // For binary expressions, we need to use evaluateHavingExpr to properly handle aggregates
82
86
  if (expr.type === 'binary' || expr.type === 'unary') {
83
- return evaluateHavingExpr(expr, context, group, tables)
87
+ return evaluateHavingExpr({ expr, row: context, group, tables, functions })
84
88
  }
85
89
 
86
- return evaluateExpr({ node: expr, row: context, tables })
90
+ return evaluateExpr({ node: expr, row: context, tables, functions })
87
91
  }
88
92
 
89
93
  /**
90
94
  * Evaluates an aggregate function on a group
91
95
  *
92
- * @param {AggregateFunc} funcName - aggregate function name
93
- * @param {ExprNode[]} args - function arguments
94
- * @param {AsyncRow[]} group - the group of rows
95
- * @param {Record<string, AsyncDataSource>} tables
96
+ * @param {Object} options
97
+ * @param {AggregateFunc} options.funcName - aggregate function name
98
+ * @param {ExprNode[]} options.args - function arguments
99
+ * @param {AsyncRow[]} options.group - the group of rows
100
+ * @param {Record<string, AsyncDataSource>} options.tables
101
+ * @param {Record<string, UserDefinedFunction>} [options.functions]
96
102
  * @returns {Promise<SqlPrimitive>} the aggregate result
97
103
  */
98
- async function evaluateAggregateFunction(funcName, args, group, tables) {
104
+ async function evaluateAggregateFunction({ funcName, args, group, tables, functions }) {
99
105
  if (funcName === 'COUNT') {
100
106
  if (args.length === 1 && args[0].type === 'identifier' && args[0].name === '*') {
101
107
  return group.length
@@ -103,7 +109,7 @@ async function evaluateAggregateFunction(funcName, args, group, tables) {
103
109
  // COUNT(column) - count non-null values
104
110
  let count = 0
105
111
  for (const row of group) {
106
- const val = await evaluateExpr({ node: args[0], row, tables })
112
+ const val = await evaluateExpr({ node: args[0], row, tables, functions })
107
113
  if (val != null) count++
108
114
  }
109
115
  return count
@@ -112,7 +118,7 @@ async function evaluateAggregateFunction(funcName, args, group, tables) {
112
118
  if (funcName === 'SUM') {
113
119
  let sum = 0
114
120
  for (const row of group) {
115
- const val = await evaluateExpr({ node: args[0], row, tables })
121
+ const val = await evaluateExpr({ node: args[0], row, tables, functions })
116
122
  if (val != null) sum += Number(val)
117
123
  }
118
124
  return sum
@@ -122,7 +128,7 @@ async function evaluateAggregateFunction(funcName, args, group, tables) {
122
128
  let sum = 0
123
129
  let count = 0
124
130
  for (const row of group) {
125
- const val = await evaluateExpr({ node: args[0], row, tables })
131
+ const val = await evaluateExpr({ node: args[0], row, tables, functions })
126
132
  if (val != null) {
127
133
  sum += Number(val)
128
134
  count++
@@ -134,7 +140,7 @@ async function evaluateAggregateFunction(funcName, args, group, tables) {
134
140
  if (funcName === 'MIN') {
135
141
  let min = null
136
142
  for (const row of group) {
137
- const val = await evaluateExpr({ node: args[0], row, tables })
143
+ const val = await evaluateExpr({ node: args[0], row, tables, functions })
138
144
  if (val != null && (min == null || val < min)) {
139
145
  min = val
140
146
  }
@@ -145,7 +151,7 @@ async function evaluateAggregateFunction(funcName, args, group, tables) {
145
151
  if (funcName === 'MAX') {
146
152
  let max = null
147
153
  for (const row of group) {
148
- const val = await evaluateExpr({ node: args[0], row, tables })
154
+ const val = await evaluateExpr({ node: args[0], row, tables, functions })
149
155
  if (val != null && (max == null || val > max)) {
150
156
  max = val
151
157
  }