prismiq 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,819 @@
1
+ """Calculated field expression parser and SQL generator.
2
+
3
+ Parses bracket notation expression syntax (e.g., "if([is_won]==1, [amount], 0)")
4
+ and converts to PostgreSQL SQL expressions.
5
+
6
+ Supported functions:
7
+ - if(condition, true_val, false_val)
8
+ - sum(expr), count(expr), avg(expr), min(expr), max(expr)
9
+ - find(substring, text)
10
+ - date(year, month, day, hour, min, sec)
11
+ - year(date), month(date), day(date)
12
+ - datediff(date1, date2, interval) - interval: 'd'/'day', 'm'/'month', 'y'/'year', 'h'/'hour', 'mi'/'minute', 's'/'second'
13
+ - today()
14
+ - concatenate(arg1, arg2, ...)
15
+ - Operators: +, -, *, /, ==, !=, >, <, >=, <=
16
+ - Field references: [field_name] or [Table.field]
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import re
22
+ from typing import Any
23
+
24
+ # ============================================================================
25
+ # Expression AST Nodes
26
+ # ============================================================================
27
+
28
+
29
class ExprNode:
    """Common base for every node of a parsed calculated-field expression.

    Concrete node types override :meth:`to_sql`; the base implementation only
    raises ``NotImplementedError``.
    """

    def to_sql(
        self,
        field_mapping: dict[str, str],
        use_window_functions: bool = False,
        default_table_ref: str | None = None,
    ) -> str:
        """Render this node as a PostgreSQL expression.

        Args:
            field_mapping: Reference name -> SQL text to substitute in place of
                that reference (e.g. already-resolved calculated fields).
            use_window_functions: If True, aggregations are emitted in their
                window form (``OVER ()``) where a node supports it.
            default_table_ref: Optional table name used to qualify unqualified
                column names, e.g. ``"account_view"`` ->
                ``"account_view"."column_name"``; prevents "ambiguous column"
                errors when the query has JOINs.

        Returns:
            The PostgreSQL SQL expression text.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
51
+
52
+
53
def _quote_column_ref(name: str, default_table_ref: str | None = None) -> str:
    """Quote a column reference with proper PostgreSQL identifier quoting.

    Handles three cases:
    1. Qualified "alias.column" -> "alias"."column" (split on the first dot)
    2. Unqualified with default_table_ref -> "default_table_ref"."name"
    3. Unqualified without default (or an empty one) -> "name"

    Any double quote inside an identifier part is doubled (``a"b`` ->
    ``"a""b"``), which is PostgreSQL's escape for quoted identifiers. Column
    names reach this function straight from user-written ``[...]`` references,
    so without escaping a name containing ``"`` could close the identifier
    early and splice arbitrary SQL into the query.

    Args:
        name: Column name, possibly qualified with table alias (e.g., "A.date")
        default_table_ref: Table reference to use for unqualified columns

    Returns:
        Properly quoted PostgreSQL column reference
    """

    def quote(identifier: str) -> str:
        # Wrap in double quotes, doubling any embedded double quote.
        return '"' + identifier.replace('"', '""') + '"'

    if "." in name:
        alias, column = name.split(".", 1)
        return f"{quote(alias)}.{quote(column)}"
    if default_table_ref:
        return f"{quote(default_table_ref)}.{quote(name)}"
    return quote(name)
74
+
75
+
76
class FieldRef(ExprNode):
    """Reference to a column or calculated field: ``[field_name]`` / ``[Table.field]``.

    A name that starts with one of the ``AGG_PATTERNS`` prefixes (for example
    ``[Sum of amount]``) is a post-aggregation reference and renders as the
    matching SQL aggregate over the remainder of the name.
    """

    # Prefix -> SQL aggregate for references like [Sum of X], [Count Distinct of Y].
    # Prefix matching is case-sensitive and follows this insertion order.
    AGG_PATTERNS = {  # noqa: RUF012
        "Sum of ": "SUM",
        "Average of ": "AVG",
        "Count of ": "COUNT",
        "Count Distinct of ": "COUNT_DISTINCT",
        "Min of ": "MIN",
        "Max of ": "MAX",
    }

    def __init__(self, name: str):
        self.name = name

    def to_sql(
        self,
        field_mapping: dict[str, str],
        use_window_functions: bool = False,
        default_table_ref: str | None = None,
    ) -> str:
        """Render the reference; substitutions from ``field_mapping`` win first."""
        name = self.name
        if name in field_mapping:
            # Substituted SQL is parenthesised so it composes safely.
            return f"({field_mapping[name]})"

        matched = next(
            (
                (prefix, agg_func)
                for prefix, agg_func in self.AGG_PATTERNS.items()
                if name.startswith(prefix)
            ),
            None,
        )
        if matched is None:
            # Plain column, possibly "alias.column".
            return _quote_column_ref(name, default_table_ref)

        prefix, agg_func = matched
        inner = name[len(prefix) :]
        inner_sql = (
            f"({field_mapping[inner]})"
            if inner in field_mapping
            else _quote_column_ref(inner, default_table_ref)
        )
        if agg_func == "COUNT_DISTINCT":
            return f"COUNT(DISTINCT {inner_sql})"
        return f"{agg_func}({inner_sql})"
119
+
120
+
121
class FunctionCall(ExprNode):
    """Function call: ``func(arg1, arg2, ...)``.

    Known functions (``if``, ``sum``/``avg``/``count``/``min``/``max``,
    ``find``, ``today``, ``year``/``month``/``day``, ``date``,
    ``concatenate``, ``datediff``) are translated to PostgreSQL. Any other
    name is passed through upper-cased, with its arguments translated.

    Function names are matched case-insensitively. This matches
    ``has_aggregation``, which detects aggregates case-insensitively, so
    ``SUM([x])`` is handled exactly like ``sum([x])``.
    """

    # Minimum argument counts for functions that index into their arguments.
    # Too few arguments raise ValueError (the module's error type) rather than
    # an IndexError from deep inside SQL generation.
    _MIN_ARGS = {  # noqa: RUF012
        "if": 3,
        "sum": 1,
        "avg": 1,
        "count": 1,
        "min": 1,
        "max": 1,
        "find": 2,
        "year": 1,
        "month": 1,
        "day": 1,
    }

    def __init__(self, name: str, args: list[ExprNode]):
        self.name = name
        self.args = args

    def to_sql(
        self,
        field_mapping: dict[str, str],
        use_window_functions: bool = False,
        default_table_ref: str | None = None,
    ) -> str:
        """Translate the call to PostgreSQL.

        Raises:
            ValueError: If a known function gets fewer arguments than it needs.
        """
        func = self.name.lower()
        required = self._MIN_ARGS.get(func, 0)
        if len(self.args) < required:
            raise ValueError(
                f"Function '{self.name}' expects at least {required} argument(s), "
                f"got {len(self.args)}"
            )

        def sql(index: int) -> str:
            # Translate only the arguments a function actually uses.
            return self.args[index].to_sql(
                field_mapping, use_window_functions, default_table_ref
            )

        if func == "if":
            # if(cond, a, b) -> CASE WHEN cond THEN a ELSE b END
            return f"CASE WHEN {sql(0)} THEN {sql(1)} ELSE {sql(2)} END"

        if func in ("sum", "avg", "count", "min", "max"):
            arg = sql(0)
            agg = func.upper()
            if not use_window_functions:
                return f"{agg}({arg})"
            if func == "sum" and isinstance(self.args[0], FieldRef):
                # Scalar-subquery placeholder instead of OVER ();
                # main.py replaces __SCALAR_SUM_<column>__ with the subquery.
                return f"__SCALAR_SUM_{self.args[0].name}__"
            return f"{agg}({arg}) OVER ()"

        if func == "find":
            # find(substring, text) -> POSITION(substring IN text)
            return f"POSITION({sql(0)} IN {sql(1)})"

        if func == "today":
            return "CURRENT_DATE"

        if func in ("year", "month", "day"):
            return f"EXTRACT({func.upper()} FROM {sql(0)})"

        if func == "date":
            # MAKE_TIMESTAMP(year, month, day, hour, minute, second); each arg is
            # cast to INTEGER since EXTRACT() returns NUMERIC.
            casted = [f"({sql(i)})::INTEGER" for i in range(len(self.args))]
            return f"MAKE_TIMESTAMP({', '.join(casted)})"

        if func == "concatenate":
            return " || ".join(sql(i) for i in range(len(self.args)))

        if func == "datediff" and len(self.args) >= 2:
            return self._datediff_sql(sql(0), sql(1))

        # Unknown function, or datediff with too few arguments: pass through.
        passthrough_args = ", ".join(sql(i) for i in range(len(self.args)))
        return f"{self.name.upper()}({passthrough_args})"

    def _datediff_sql(self, start: str, end: str) -> str:
        """Render ``datediff(start, end[, interval])`` as ``end - start``.

        Example: ``datediff(today(), [date], "d")`` is ``[date] - today()``, so a
        date of yesterday gives -1. The interval comes from a literal 3rd
        argument ('d'/'day', 'm'/'month', 'y'/'year', 'h'/'hour',
        'mi'/'minute', 's'/'second', plural forms accepted). It defaults to
        days, and unrecognised units also fall back to days.

        Operands are parenthesised before every ``::`` cast because ``::``
        binds tighter than operators such as ``||``. Fractional results are
        wrapped in parentheses so they compose inside a larger expression,
        e.g. ``2 / datediff(...)``.
        """
        unit = "d"
        if len(self.args) >= 3 and isinstance(self.args[2], Literal):
            unit = str(self.args[2].value).lower()

        start_date = f"({start})::date"
        end_date = f"({end})::date"

        if unit in ("m", "month", "months"):
            age = f"AGE({end_date}, {start_date})"
            return f"(EXTRACT(YEAR FROM {age}) * 12 + EXTRACT(MONTH FROM {age}))::int"
        if unit in ("y", "year", "years"):
            return f"EXTRACT(YEAR FROM AGE({end_date}, {start_date}))::int"

        epoch = f"EXTRACT(EPOCH FROM (({end})::timestamp - ({start})::timestamp))"
        if unit in ("h", "hour", "hours"):
            return f"({epoch} / 3600)"
        if unit in ("mi", "minute", "minutes"):
            return f"({epoch} / 60)"
        if unit in ("s", "second", "seconds"):
            return epoch

        # "d"/"day"/"days" and any unrecognised unit: whole-day difference.
        return f"({end_date} - {start_date})"
267
+
268
+
269
class MethodCall(ExprNode):
    """Postfix method applied to an expression: ``[field].method()``.

    Only ``concatenate`` has a dedicated translation. Any other method name is
    emitted verbatim as ``<obj_sql>.<method>()``. ``args`` is stored but not
    used when rendering.
    """

    def __init__(self, obj: ExprNode, method: str, args: list[ExprNode]):
        self.obj = obj
        self.method = method
        self.args = args

    def to_sql(
        self,
        field_mapping: dict[str, str],
        use_window_functions: bool = False,
        default_table_ref: str | None = None,
    ) -> str:
        """Render the receiver, then apply the method translation."""
        target = self.obj.to_sql(field_mapping, use_window_functions, default_table_ref)
        if self.method != "concatenate":
            # Unknown method: emitted unchanged.
            return f"{target}.{self.method}()"
        # [field].concatenate() -> COALESCE(field, '') so a NULL does not turn a
        # whole || chain into NULL.
        return f"COALESCE({target}, '')"
293
+
294
+
295
class BinaryOp(ExprNode):
    """Binary operation ``left op right``, always rendered in parentheses."""

    def __init__(self, op: str, left: ExprNode, right: ExprNode):
        self.op = op
        self.left = left
        self.right = right

    def to_sql(
        self,
        field_mapping: dict[str, str],
        use_window_functions: bool = False,
        default_table_ref: str | None = None,
    ) -> str:
        """Render both operands and join them with the SQL form of the operator.

        ``==`` becomes ``=`` and ``!=`` becomes ``<>``; every other operator is
        used as written. When a field is compared for (in)equality with the
        literal 0 or 1, the field is cast with ``::int`` so boolean columns can
        be compared to integers (e.g. ``[is_won] == 1``).

        NOTE(review): the cast is applied to any field compared with 0/1, not
        only boolean ones, so a numeric field is rounded before comparison.
        """
        lhs = self.left.to_sql(field_mapping, use_window_functions, default_table_ref)
        rhs = self.right.to_sql(field_mapping, use_window_functions, default_table_ref)
        sql_op = {"==": "=", "!=": "<>"}.get(self.op, self.op)

        if sql_op in ("=", "<>"):
            if isinstance(self.right, Literal):
                # A literal on the right takes priority, as in field == 1.
                if isinstance(self.left, FieldRef) and self.right.value in (0, 1):
                    lhs = f"({lhs})::int"
            elif (
                isinstance(self.left, Literal)
                and isinstance(self.right, FieldRef)
                and self.left.value in (0, 1)
            ):
                rhs = f"({rhs})::int"

        return f"({lhs} {sql_op} {rhs})"
328
+
329
+
330
class Literal(ExprNode):
    """Constant operand: a number, a string, or ``None`` (rendered as SQL NULL)."""

    def __init__(self, value: Any):
        self.value = value

    def to_sql(
        self,
        field_mapping: dict[str, str],
        use_window_functions: bool = False,
        default_table_ref: str | None = None,
    ) -> str:
        """Render the constant; strings become single-quoted SQL literals."""
        value = self.value
        if value is None:
            return "NULL"
        if isinstance(value, str):
            # SQL string literal: an embedded single quote is written twice.
            return "'" + value.replace("'", "''") + "'"
        return str(value)
349
+
350
+
351
+ # ============================================================================
352
+ # Expression Parser
353
+ # ============================================================================
354
+
355
+
356
class ExpressionParser:
    """Parser for bracket notation expression syntax.

    Implements a recursive descent parser for the expression language. From
    lowest to highest, precedence is: comparison, additive (``+ -``),
    multiplicative (``* /``), then unary minus / primary expressions.

    Tokens are plain strings: ``FIELD:<name>``, ``NUMBER:<digits>``,
    ``STRING:<text>``, ``ID:<identifier>``, or the raw delimiter/operator text.
    """

    # Alternatives are tried in order at each position. Multi-character
    # operators (==, <=, ...) must precede their single-character prefixes.
    _TOKEN_RE = re.compile(
        r"\[(?P<field>[^\]]+)\]"
        r"|(?P<number>\d+\.?\d*)"
        r'|(?P<dstring>"(?:[^"\\]|\\.)*")'
        r"|(?P<sstring>'(?:[^'\\]|\\.)*')"
        r"|(?P<ident>[a-zA-Z_]\w*)"
        r"|(?P<delim>\(|\)|,|\.)"
        r"|(?P<op><=|>=|==|!=|>|<|=|[\+\-*/])"
    )

    def parse(self, expression: str) -> ExprNode:
        """Parse expression string into AST.

        Args:
            expression: Expression string using bracket notation

        Returns:
            Root AST node

        Raises:
            ValueError: If expression syntax is invalid
        """
        tokens = self._tokenize(expression)

        if not tokens:
            raise ValueError("Empty expression")

        ast, pos = self._parse_expr(tokens, 0)

        # Every token must be consumed by the top-level expression.
        if pos < len(tokens):
            raise ValueError(f"Unexpected tokens after expression: {tokens[pos:]}")

        return ast

    def _tokenize(self, expr: str) -> list[str]:
        """Tokenize expression into list of tokens.

        Recognises:
        - field refs ``[name]``
        - numbers, optionally with a decimal part
        - double- or single-quoted strings, where a backslash escapes the
          quote character
        - identifiers
        - the delimiters ``( ) , .``
        - the operators ``== = != >= <= > < + - * /``; a single ``=`` is an
          alias for ``==``

        Only whitespace may separate tokens.

        Args:
            expr: Expression string

        Returns:
            List of tokens

        Raises:
            ValueError: On any character that cannot start a token.
        """
        tokens: list[str] = []
        pos = 0

        for match in self._TOKEN_RE.finditer(expr):
            gap = expr[pos : match.start()]
            if gap and not gap.isspace():
                raise ValueError(f"Invalid character at position {pos}: {gap}")

            kind = match.lastgroup
            text = match.group(kind)
            if kind == "field":
                tokens.append(f"FIELD:{text}")
            elif kind == "number":
                tokens.append(f"NUMBER:{text}")
            elif kind == "dstring":
                tokens.append("STRING:" + text[1:-1].replace('\\"', '"'))
            elif kind == "sstring":
                tokens.append("STRING:" + text[1:-1].replace("\\'", "'"))
            elif kind == "ident":
                tokens.append(f"ID:{text}")
            else:
                # Delimiter or operator: the token is its own text.
                tokens.append(text)

            pos = match.end()

        if pos < len(expr) and not expr[pos:].isspace():
            raise ValueError(f"Invalid character at position {pos}: {expr[pos:]}")

        return tokens

    def _parse_expr(self, tokens: list[str], pos: int) -> tuple[ExprNode, int]:
        """Parse a complete expression (handles all operators).

        Uses precedence climbing for operator precedence.

        Args:
            tokens: Token list
            pos: Current position in token list

        Returns:
            (AST node, next position)
        """
        # Comparison is the lowest-precedence level.
        return self._parse_comparison(tokens, pos)

    def _parse_comparison(self, tokens: list[str], pos: int) -> tuple[ExprNode, int]:
        """Parse comparison operators: ==, =, !=, >, <, >=, <="""
        left, pos = self._parse_additive(tokens, pos)

        # Note: "=" is an alias for equality operator "=="
        while pos < len(tokens) and tokens[pos] in ("==", "=", "!=", ">", "<", ">=", "<="):
            op = tokens[pos]
            right, pos = self._parse_additive(tokens, pos + 1)
            left = BinaryOp(op, left, right)

        return left, pos

    def _parse_additive(self, tokens: list[str], pos: int) -> tuple[ExprNode, int]:
        """Parse additive operators: +, -"""
        left, pos = self._parse_multiplicative(tokens, pos)

        while pos < len(tokens) and tokens[pos] in ("+", "-"):
            op = tokens[pos]
            right, pos = self._parse_multiplicative(tokens, pos + 1)
            left = BinaryOp(op, left, right)

        return left, pos

    def _parse_multiplicative(self, tokens: list[str], pos: int) -> tuple[ExprNode, int]:
        """Parse multiplicative operators: *, /"""
        left, pos = self._parse_primary(tokens, pos)

        while pos < len(tokens) and tokens[pos] in ("*", "/"):
            op = tokens[pos]
            right, pos = self._parse_primary(tokens, pos + 1)
            left = BinaryOp(op, left, right)

        return left, pos

    def _parse_primary(self, tokens: list[str], pos: int) -> tuple[ExprNode, int]:
        """Parse primary expressions: unary minus, literals, field refs,
        method calls, function calls, parentheses."""
        if pos >= len(tokens):
            raise ValueError("Unexpected end of expression")

        token = tokens[pos]

        # Unary minus, e.g. "-1" or "-[amount]"; binds tighter than * and /.
        if token == "-":  # noqa: S105
            operand, pos = self._parse_primary(tokens, pos + 1)
            if isinstance(operand, Literal) and isinstance(operand.value, (int, float)):
                # Fold numeric literals so "-1" renders as plain -1.
                return Literal(-operand.value), pos
            return BinaryOp("-", Literal(0), operand), pos

        # Field reference
        if token.startswith("FIELD:"):
            field_name = token[6:]  # Remove "FIELD:" prefix
            pos += 1

            # Method call: [field].method() (no arguments supported)
            if pos < len(tokens) and tokens[pos] == ".":
                pos += 1
                if pos >= len(tokens) or not tokens[pos].startswith("ID:"):
                    raise ValueError("Expected method name after '.'")
                method_name = tokens[pos][3:]  # Remove "ID:" prefix
                pos += 1

                if pos >= len(tokens) or tokens[pos] != "(":
                    raise ValueError(f"Expected '(' after method name '{method_name}'")
                pos += 1
                if pos >= len(tokens) or tokens[pos] != ")":
                    raise ValueError(f"Expected ')' after '{method_name}('")
                pos += 1

                return MethodCall(FieldRef(field_name), method_name, []), pos

            return FieldRef(field_name), pos

        # Number literal
        if token.startswith("NUMBER:"):
            num_str = token[7:]  # Remove "NUMBER:" prefix
            num_value: float | int = float(num_str) if "." in num_str else int(num_str)
            return Literal(num_value), pos + 1

        # String literal
        if token.startswith("STRING:"):
            return Literal(token[7:]), pos + 1  # Remove "STRING:" prefix

        # Function call (a bare identifier is an error)
        if token.startswith("ID:"):
            func_name = token[3:]  # Remove "ID:" prefix
            pos += 1

            if pos >= len(tokens) or tokens[pos] != "(":
                raise ValueError(f"Unexpected identifier: {func_name}")
            pos += 1  # Skip '('

            args: list[ExprNode] = []
            while True:
                if pos >= len(tokens):
                    raise ValueError(f"Expected ')' after function arguments for '{func_name}'")
                if tokens[pos] == ")":
                    break
                arg, pos = self._parse_expr(tokens, pos)
                args.append(arg)
                if pos < len(tokens) and tokens[pos] == ",":
                    pos += 1  # Skip ','; a trailing comma before ')' is tolerated
                elif pos < len(tokens) and tokens[pos] != ")":
                    # Reject "f([a] [b])" instead of silently splitting into two args.
                    raise ValueError(
                        f"Expected ',' or ')' after argument {len(args)} of "
                        f"'{func_name}', got {tokens[pos]!r}"
                    )

            return FunctionCall(func_name, args), pos + 1  # Skip ')'

        # Parenthesized expression
        if token == "(":  # noqa: S105
            expr, pos = self._parse_expr(tokens, pos + 1)

            if pos >= len(tokens) or tokens[pos] != ")":
                raise ValueError("Expected ')' after expression")

            return expr, pos + 1

        raise ValueError(f"Unexpected token: {token}")
589
+
590
+
591
+ # ============================================================================
592
+ # Dependency Resolution
593
+ # ============================================================================
594
+
595
+
596
def has_aggregation(expression: str) -> bool:
    """Check if an expression contains aggregation functions.

    Two forms count as aggregation, both matched case-insensitively:

    * a call to ``sum``, ``avg``, ``count``, ``min`` or ``max``. The name must
      be a whole word, and whitespace is allowed before ``(``, as the parser
      accepts ``sum ([x])``. The word boundary keeps names such as
      ``checksum(`` from matching.
    * an aggregation reference such as ``[Sum of X]``, ``[Average of X]``,
      ``[Count of X]``, ``[Count Distinct of X]``, ``[Min of X]`` or
      ``[Max of X]``.

    This is a textual check, so text inside string literals is scanned too.

    Args:
        expression: Expression string to check

    Returns:
        True if expression contains sum, avg, count, min, max, etc.
    """
    if re.search(r"\b(?:sum|avg|count|min|max)\s*\(", expression, re.IGNORECASE):
        return True

    # Aggregation reference syntax: [Sum of X], [Count Distinct of Y], etc.
    # These are field references that represent aggregated values.
    expr_lower = expression.lower()
    agg_patterns = (
        "[sum of ",
        "[average of ",
        "[count of ",
        "[count distinct of ",
        "[min of ",
        "[max of ",
    )
    return any(pattern in expr_lower for pattern in agg_patterns)
622
+
623
+
624
def resolve_calculated_fields(
    query_columns: list[dict[str, Any]],
    calculated_fields: list[dict[str, Any]],
    base_table_name: str | None = None,
) -> dict[str, tuple[str, bool]]:
    """Resolve calculated field dependencies and generate SQL expressions.

    Every calculated field is parsed. Its ``[references]`` are resolved
    depth-first, so dependencies are rendered before dependents, and the
    field's AST is then rendered to SQL with each reference substituted.

    Args:
        query_columns: Column selections from query (may reference calculated
            fields). Reads each entry's "column" and "aggregation" keys.
        calculated_fields: List of {name, expression} dicts
        base_table_name: Optional base table name to prefix unqualified column
            references. When there are JOINs, this prevents "ambiguous column"
            errors. Example: "account_custom_fields_view" ->
            "account_custom_fields_view"."column"

    Returns:
        Dict mapping field name to (sql_expression, has_aggregation) tuple

    Raises:
        ValueError: If an expression fails to parse, or calculated fields
            depend on each other circularly.
    """
    calc_field_map = {cf["name"]: cf["expression"] for cf in calculated_fields}

    if not calc_field_map:
        return {}

    # Calculated fields the query wraps in an outer aggregation.
    outer_agg_map: dict[str, Any] = {}
    for col in query_columns:
        col_name = col.get("column")
        col_agg = col.get("aggregation", "none")
        if col_name in calc_field_map and col_agg and col_agg != "none":
            outer_agg_map[col_name] = col_agg

    parser = ExpressionParser()
    parsed: dict[str, ExprNode] = {}
    for name, expr in calc_field_map.items():
        try:
            parsed[name] = parser.parse(expr)
        except ValueError as e:
            raise ValueError(f"Failed to parse calculated field '{name}': {e}") from e

    def split_agg_reference(field_name: str) -> tuple[str, str] | None:
        """Split 'Sum of X'-style names into (agg_func, X); None otherwise."""
        # Shares FieldRef's prefix table so both paths recognise the same names.
        for prefix, agg_func in FieldRef.AGG_PATTERNS.items():
            if field_name.startswith(prefix):
                return agg_func, field_name[len(prefix) :]
        return None

    def looks_like_db_column(field_name: str) -> bool:
        """Heuristic: a physical column name has no spaces or expression chars.

        Other names are most likely calculated fields that are not in
        calc_field_map (possibly defined in another widget), so they must not
        get the base-table prefix.
        """
        return " " not in field_name and not any(c in field_name for c in "()+-*/")

    resolved: dict[str, tuple[str, bool]] = {}  # name -> (SQL expression, has_aggregation)
    visiting: set[str] = set()  # For cycle detection

    def resolve(name: str) -> tuple[str, bool]:
        """Resolve a calculated field and its dependencies."""
        if name in resolved:
            return resolved[name]

        if name not in parsed:
            # Not a calculated field, return as-is (will be handled as column reference)
            return (f'"{name}"', False)

        if name in visiting:
            raise ValueError(f"Circular dependency detected in calculated field: {name}")

        visiting.add(name)

        ast = parsed[name]
        has_agg = has_aggregation(calc_field_map[name])

        dep_sql_map: dict[str, str] = {}
        # dict.fromkeys de-duplicates while keeping first-seen order.
        for dep in dict.fromkeys(_extract_field_refs(ast)):
            agg_ref = split_agg_reference(dep)
            if agg_ref is not None:
                # Aggregation reference like "Sum of pageview_cms".
                agg_func, inner_field = agg_ref
                if inner_field in calc_field_map:
                    inner_sql, _ = resolve(inner_field)
                else:
                    inner_sql = _quote_column_ref(inner_field, base_table_name)
                if agg_func == "COUNT_DISTINCT":
                    dep_sql_map[dep] = f"COUNT(DISTINCT {inner_sql})"
                else:
                    dep_sql_map[dep] = f"{agg_func}({inner_sql})"
            elif dep in calc_field_map:
                dep_sql, dep_has_agg = resolve(dep)
                # An aggregating expression that depends on a row-level field
                # wraps it in SUM so the result works with GROUP BY.
                if has_agg and not dep_has_agg:
                    dep_sql = f"SUM({dep_sql})"
                dep_sql_map[dep] = dep_sql
            else:
                # Plain column: "alias.column" stays as qualified. A bare name
                # is qualified with the base table only if it looks like a
                # real database column.
                qualifier = base_table_name if looks_like_db_column(dep) else None
                dep_sql_map[dep] = _quote_column_ref(dep, qualifier)

        # Window functions are needed only when this field aggregates
        # internally AND the query applies an outer aggregation on top of it.
        use_window_functions = has_agg and name in outer_agg_map
        sql = ast.to_sql(
            dep_sql_map,
            use_window_functions=use_window_functions,
            default_table_ref=base_table_name,
        )

        resolved[name] = (sql, has_agg)
        visiting.remove(name)
        return resolved[name]

    # Resolve every calculated field, even ones the query does not select
    # directly, so all dependencies are available.
    for name in calc_field_map:
        resolve(name)

    return resolved
793
+
794
+
795
def _extract_field_refs(node: ExprNode) -> list[str]:
    """Collect the names of all field references in an AST.

    The tree is walked with an explicit stack. Children are pushed in reverse
    so names come out in left-to-right source order, duplicates included.
    Literals and unknown node types contribute nothing.

    Args:
        node: AST node to extract from

    Returns:
        List of field names referenced
    """
    names: list[str] = []
    pending: list[ExprNode] = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, FieldRef):
            names.append(current.name)
        elif isinstance(current, FunctionCall):
            pending.extend(reversed(current.args))
        elif isinstance(current, MethodCall):
            # The receiver comes before the method's arguments.
            pending.extend(reversed([current.obj, *current.args]))
        elif isinstance(current, BinaryOp):
            pending.extend((current.right, current.left))
    return names