prismiq 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,231 @@
1
+ """Calculated field preprocessing for query building.
2
+
3
+ This module provides functions to preprocess queries with calculated fields,
4
+ applying SQL expressions to columns and filters before SQL generation.
5
+
6
+ Usage:
7
+ from prismiq.calculated_field_processor import preprocess_calculated_fields
8
+
9
+ # Apply calculated fields to a query dict
10
+ processed_query = preprocess_calculated_fields(query)
11
+
12
+ # Then build SQL
13
+ sql, params = build_sql_from_dict(processed_query)
14
+ """
15
+
16
from __future__ import annotations

import logging
import re
from typing import Any

from .calculated_fields import resolve_calculated_fields
22
+
23
+
24
+ def _has_special_characters(column_name: str) -> bool:
25
+ """Check if column name contains special characters.
26
+
27
+ Special characters indicate a calculated field name (e.g., "Total
28
+ Revenue %") rather than a regular database column (e.g.,
29
+ "account_id").
30
+ """
31
+ if not column_name:
32
+ return False
33
+ if column_name == "*":
34
+ return False
35
+ # Allow alphanumeric, underscore, and dot (for table.column refs)
36
+ # Everything else is considered special
37
+ return bool(re.search(r"[^a-zA-Z0-9_.]", column_name))
38
+
39
+
40
+ def _apply_calculated_fields_to_columns(
41
+ columns: list[dict[str, Any]],
42
+ calc_field_sql_map: dict[str, tuple[str, bool]],
43
+ ) -> tuple[list[dict[str, Any]], bool]:
44
+ """Replace calculated field references with SQL expressions in columns.
45
+
46
+ This modifies the column definitions to include SQL expressions for calculated fields.
47
+ The SQL builder will then use the sql_expression field instead of building table.column.
48
+
49
+ Args:
50
+ columns: Column definitions from query
51
+ calc_field_sql_map: Mapping of field names to (SQL expression, has_aggregation) tuples
52
+
53
+ Returns:
54
+ Tuple of (modified_columns, uses_window_functions):
55
+ - modified_columns: Column definitions with calculated fields resolved
56
+ - uses_window_functions: True if any column uses window functions (OVER ()),
57
+ indicating GROUP BY should be cleared
58
+ """
59
+ modified_columns = []
60
+
61
+ # First pass: check if any calculated field uses window functions (OVER ())
62
+ # If so, we need to convert all aggregations to window functions to avoid conflicts
63
+ has_window_function = False
64
+ for col in columns:
65
+ column_name = col.get("column", "")
66
+ if column_name in calc_field_sql_map:
67
+ expr, _ = calc_field_sql_map[column_name]
68
+ if " OVER " in expr.upper():
69
+ has_window_function = True
70
+ break
71
+
72
+ for col in columns:
73
+ col_copy = col.copy()
74
+ column_name = col.get("column", "")
75
+ aggregation = col.get("aggregation", "none")
76
+
77
+ # Check if this is a calculated field
78
+ if column_name in calc_field_sql_map:
79
+ expr, has_aggregation = calc_field_sql_map[column_name]
80
+
81
+ # Use sql_expression field (SQL builder will use this instead of building table.column)
82
+ col_copy["sql_expression"] = expr
83
+ col_copy["_has_aggregation"] = has_aggregation
84
+ elif has_window_function and aggregation and aggregation != "none":
85
+ # Convert regular aggregations to window functions to match calculated fields
86
+ # This prevents the "column must appear in GROUP BY" error when mixing
87
+ # window functions with regular aggregates
88
+ # Escape double quotes to prevent SQL injection
89
+ safe_column_name = column_name.replace('"', '""')
90
+ if column_name == "*" and aggregation == "count":
91
+ # COUNT(*) -> COUNT(*) OVER ()
92
+ col_copy["sql_expression"] = "COUNT(*) OVER ()"
93
+ col_copy["_has_aggregation"] = True
94
+ col_copy["aggregation"] = "none" # Don't double-wrap
95
+ elif aggregation == "count":
96
+ col_copy["sql_expression"] = f'COUNT("{safe_column_name}") OVER ()'
97
+ col_copy["_has_aggregation"] = True
98
+ col_copy["aggregation"] = "none"
99
+ elif aggregation == "count_distinct":
100
+ col_copy["sql_expression"] = f'COUNT(DISTINCT "{safe_column_name}") OVER ()'
101
+ col_copy["_has_aggregation"] = True
102
+ col_copy["aggregation"] = "none"
103
+ elif aggregation == "sum":
104
+ col_copy["sql_expression"] = f'SUM("{safe_column_name}") OVER ()'
105
+ col_copy["_has_aggregation"] = True
106
+ col_copy["aggregation"] = "none"
107
+ elif aggregation == "avg":
108
+ col_copy["sql_expression"] = f'AVG("{safe_column_name}") OVER ()'
109
+ col_copy["_has_aggregation"] = True
110
+ col_copy["aggregation"] = "none"
111
+ elif aggregation == "min":
112
+ col_copy["sql_expression"] = f'MIN("{safe_column_name}") OVER ()'
113
+ col_copy["_has_aggregation"] = True
114
+ col_copy["aggregation"] = "none"
115
+ elif aggregation == "max":
116
+ col_copy["sql_expression"] = f'MAX("{safe_column_name}") OVER ()'
117
+ col_copy["_has_aggregation"] = True
118
+ col_copy["aggregation"] = "none"
119
+ elif column_name and _has_special_characters(column_name):
120
+ # Column name has special characters (spaces, %, etc.), which typically indicates
121
+ # a calculated field or custom name. If it's not found in calc_field_sql_map,
122
+ # it might be a calculated field defined elsewhere.
123
+ # Use the column name as a raw expression without table qualification.
124
+ # Escape double quotes to prevent SQL injection
125
+ safe_name = column_name.replace('"', '""')
126
+ col_copy["sql_expression"] = f'"{safe_name}"'
127
+ col_copy["_has_aggregation"] = False
128
+
129
+ modified_columns.append(col_copy)
130
+
131
+ return modified_columns, has_window_function
132
+
133
+
134
+ def _apply_calculated_fields_to_filters(
135
+ filters: list[dict[str, Any]],
136
+ calc_field_sql_map: dict[str, tuple[str, bool]],
137
+ ) -> list[dict[str, Any]]:
138
+ """Replace calculated field references in filters with SQL expressions.
139
+
140
+ When a filter references a calculated field, we need to use the SQL expression
141
+ instead of the field name.
142
+
143
+ Args:
144
+ filters: Filter definitions from query
145
+ calc_field_sql_map: Mapping of field names to (SQL expression, has_aggregation) tuples
146
+
147
+ Returns:
148
+ Modified filter definitions with sql_expression added where needed
149
+ """
150
+ if not filters:
151
+ return filters
152
+
153
+ modified_filters = []
154
+ for f in filters:
155
+ f_copy = f.copy()
156
+ column = f.get("column", "")
157
+
158
+ if column in calc_field_sql_map:
159
+ expr, _ = calc_field_sql_map[column]
160
+ f_copy["sql_expression"] = expr
161
+
162
+ modified_filters.append(f_copy)
163
+
164
+ return modified_filters
165
+
166
+
167
def preprocess_calculated_fields(
    query: dict[str, Any],
    base_table_name: str | None = None,
) -> dict[str, Any]:
    """Preprocess a query dict to resolve calculated fields.

    This is the main entry point for calculated field processing. It:

    1. Resolves calculated field expressions to SQL using
       ``resolve_calculated_fields()``.
    2. Applies the resolved expressions to columns and filters.
    3. Handles window-function conflicts: regular aggregations are converted
       to window aggregates and ``group_by`` is cleared.

    Resolution is best-effort. If ``resolve_calculated_fields()`` raises, a
    warning with the traceback is logged and processing continues as if no
    calculated fields were defined.

    Args:
        query: Query dict with columns, filters, calculated_fields, etc.
        base_table_name: Optional base table name to prefix unqualified column
            references. If not provided, it is taken from the first entry of
            ``query["tables"]``: its ``table_id`` (falling back to ``name``)
            when the entry is a dict, or the entry itself when it is a string.

    Returns:
        A shallow copy of ``query`` whose ``columns`` and ``filters`` carry
        ``sql_expression`` fields for calculated fields. The original query is
        not mutated.
    """
    calculated_fields = query.get("calculated_fields", [])
    # An explicit ``"columns": None`` is treated like a missing key; iterating
    # None would otherwise raise TypeError in the column pass.
    query_columns = query.get("columns") or []

    # Extract base table name from first table if not provided
    if base_table_name is None:
        tables = query.get("tables", [])
        if tables and isinstance(tables[0], dict):
            base_table_name = tables[0].get("table_id") or tables[0].get("name")
        elif tables and isinstance(tables[0], str):
            base_table_name = tables[0]

    # Resolve calculated field expressions to SQL
    calc_field_sql_map: dict[str, tuple[str, bool]] = {}
    if calculated_fields:
        try:
            calc_field_sql_map = resolve_calculated_fields(
                query_columns=query_columns,
                calculated_fields=calculated_fields,
                base_table_name=base_table_name,
            )
        except Exception:
            # Deliberately best-effort: continue without calculated fields
            # rather than failing the whole query, but record why. Callers
            # never see the exception, so it must not be silent.
            logging.getLogger(__name__).warning(
                "Failed to resolve calculated fields; continuing without them",
                exc_info=True,
            )
            calc_field_sql_map = {}

    # Always process columns and filters: even with an empty map, names with
    # special characters (e.g. calculated fields not defined on this widget)
    # still need handling.
    result = query.copy()

    columns, uses_window_functions = _apply_calculated_fields_to_columns(
        query_columns, calc_field_sql_map
    )
    result["columns"] = columns

    # Window functions (OVER ()) operate on all rows and don't need grouping;
    # keeping GROUP BY would conflict with them.
    if uses_window_functions:
        result["group_by"] = []

    result["filters"] = _apply_calculated_fields_to_filters(
        query.get("filters", []), calc_field_sql_map
    )

    return result