aetherdialect 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aetherdialect-0.1.0.dist-info/METADATA +197 -0
- aetherdialect-0.1.0.dist-info/RECORD +34 -0
- aetherdialect-0.1.0.dist-info/WHEEL +5 -0
- aetherdialect-0.1.0.dist-info/licenses/LICENSE +7 -0
- aetherdialect-0.1.0.dist-info/top_level.txt +1 -0
- text2sql/__init__.py +7 -0
- text2sql/config.py +1063 -0
- text2sql/contracts_base.py +952 -0
- text2sql/contracts_core.py +1890 -0
- text2sql/core_utils.py +834 -0
- text2sql/dialect.py +1134 -0
- text2sql/expansion_ops.py +1218 -0
- text2sql/expansion_rules.py +496 -0
- text2sql/intent_expr.py +1759 -0
- text2sql/intent_process.py +2133 -0
- text2sql/intent_repair.py +1733 -0
- text2sql/intent_resolve.py +1292 -0
- text2sql/live_testing.py +1117 -0
- text2sql/main_execution.py +799 -0
- text2sql/pipeline.py +1662 -0
- text2sql/qsim_ops.py +1286 -0
- text2sql/qsim_sample.py +609 -0
- text2sql/qsim_struct.py +569 -0
- text2sql/schema.py +973 -0
- text2sql/schema_profiling.py +2075 -0
- text2sql/simulator.py +970 -0
- text2sql/sql_gen.py +1537 -0
- text2sql/templates.py +1037 -0
- text2sql/text2sql.py +726 -0
- text2sql/utils.py +973 -0
- text2sql/validation_agg.py +1033 -0
- text2sql/validation_execute.py +1092 -0
- text2sql/validation_schema.py +1847 -0
- text2sql/validation_semantic.py +2122 -0
text2sql/dialect.py
ADDED
|
@@ -0,0 +1,1134 @@
|
|
|
1
|
+
"""Database dialect abstraction for SQL validation and introspection.
|
|
2
|
+
|
|
3
|
+
Provides AST-based and EXPLAIN-based SQL validation, join-pair extraction, CTE body parsing, and enum reflection for PostgreSQL and Databricks.
|
|
4
|
+
|
|
5
|
+
The module conditionally imports pglast for PostgreSQL and exposes a Dialect base class with dialect-specific implementations.
|
|
6
|
+
|
|
7
|
+
A factory function returns the active dialect based on EngineConfig.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import re
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from .config import DATABRICKS_TABLE_QUALIFY_SKIP_IDENTIFIERS, EngineConfig
|
|
16
|
+
from .core_utils import canonicalize_sql, debug
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _extract_join_pairs_sqlglot(sql: str) -> set[tuple[str, str]]:
    """Extract equality join pairs from Spark SQL using the sqlglot AST.

    Every ``JOIN ... ON`` condition is split into its AND-ed conjuncts, with
    parentheses unwrapped. Each conjunct of the form ``alias1.col1 = alias2.col2``
    contributes one pair, with both aliases resolved to their table names.
    Equalities nested under OR, NOT, or any other operator are ignored. This
    mirrors the pglast-based extractor, which only accepts AND conjunctions.

    Args:
        sql: Spark SQL query text.

    Returns:
        Set of ``("table.col", "table.col")`` tuples, each sorted so the
        lexicographically smaller element comes first. Empty when the SQL
        cannot be parsed.
    """
    import sqlglot
    from sqlglot import exp

    pairs: set[tuple[str, str]] = set()
    try:
        parsed = sqlglot.parse_one(sql, dialect="spark")
    except Exception:
        return pairs
    if parsed is None:
        return pairs

    # Map both aliases and bare table names to the underlying table name.
    alias_to_table: dict[str, str] = {}
    for table in parsed.find_all(exp.Table):
        alias_to_table[table.alias_or_name] = table.name
        alias_to_table[table.name] = table.name

    def _resolve(table_or_alias: str) -> str:
        return alias_to_table.get(table_or_alias, table_or_alias)

    def _conjuncts(node: exp.Expression):
        """Yield the leaf conditions of a possibly parenthesized AND chain."""
        if isinstance(node, exp.Paren):
            yield from _conjuncts(node.this)
        elif isinstance(node, exp.And):
            yield from _conjuncts(node.this)
            yield from _conjuncts(node.expression)
        else:
            yield node

    def _add_pair(left: exp.Column, right: exp.Column) -> None:
        lt = left.table or ""
        lc = left.name or ""
        rt = right.table or ""
        rc = right.name or ""
        # Only fully qualified column references on both sides form a join pair.
        if lt and lc and rt and rc:
            pair_a = f"{_resolve(lt)}.{lc}"
            pair_b = f"{_resolve(rt)}.{rc}"
            pairs.add(tuple(sorted([pair_a, pair_b])))

    for join in parsed.find_all(exp.Join):
        on_expr = join.args.get("on")
        if on_expr is None:
            continue
        for cond in _conjuncts(on_expr):
            if not isinstance(cond, exp.EQ):
                continue
            left, right = cond.this, cond.expression
            if isinstance(left, exp.Column) and isinstance(right, exp.Column):
                _add_pair(left, right)

    return pairs
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _extract_cte_bodies_sqlglot(sql: str) -> dict[str, str]:
    """Map each CTE name to its body SQL, parsing *sql* with sqlglot's Spark dialect.

    Args:
        sql: Spark SQL query text that may contain a WITH clause.

    Returns:
        Dictionary from lowercase CTE name to the CTE body re-rendered as Spark
        SQL. Empty when sqlglot cannot parse the input or no CTE is present.
    """
    import sqlglot
    from sqlglot import exp

    try:
        tree = sqlglot.parse_one(sql, dialect="spark")
    except Exception:
        return {}

    bodies: dict[str, str] = {}
    for node in tree.find_all(exp.CTE):
        name = node.alias
        body = node.this
        # Skip anonymous or body-less CTE nodes.
        if not (name and body):
            continue
        bodies[name.lower()] = body.sql(dialect="spark")
    return bodies
|
|
82
|
+
|
|
83
|
+
if EngineConfig.TYPE == "postgresql":
|
|
84
|
+
|
|
85
|
+
def _strip_schema(ident: str) -> str:
|
|
86
|
+
"""Strip schema prefix from identifier and return lowercase table name."""
|
|
87
|
+
s = (ident or "").strip().lower()
|
|
88
|
+
if "." in s:
|
|
89
|
+
s = s.split(".")[-1]
|
|
90
|
+
return s
|
|
91
|
+
|
|
92
|
+
def _colref_to_parts(n: Any) -> tuple[str, str] | None:
|
|
93
|
+
"""Parse AST ColumnRef node into a (table, column) tuple.
|
|
94
|
+
|
|
95
|
+
Args:
|
|
96
|
+
|
|
97
|
+
n: An AST node, expected to be a ColumnRef with one to three String fields.
|
|
98
|
+
|
|
99
|
+
Returns:
|
|
100
|
+
|
|
101
|
+
Tuple of (table_alias_or_name, column_name), or None if the node is not a valid ColumnRef or contains unsupported field types.
|
|
102
|
+
"""
|
|
103
|
+
tag = getattr(n, "__class__", type("x", (), {})).__name__
|
|
104
|
+
if tag != "ColumnRef":
|
|
105
|
+
return None
|
|
106
|
+
fields = getattr(n, "fields", None)
|
|
107
|
+
if not isinstance(fields, list | tuple) or not fields:
|
|
108
|
+
return None
|
|
109
|
+
parts: list[str] = []
|
|
110
|
+
for f in fields:
|
|
111
|
+
ftag = getattr(f, "__class__", type("x", (), {})).__name__
|
|
112
|
+
if ftag != "String":
|
|
113
|
+
return None
|
|
114
|
+
sval = getattr(f, "sval", None)
|
|
115
|
+
if isinstance(sval, str) and sval:
|
|
116
|
+
parts.append(sval)
|
|
117
|
+
if len(parts) == 1:
|
|
118
|
+
return ("", _strip_schema(parts[0]))
|
|
119
|
+
if len(parts) == 2:
|
|
120
|
+
return (_strip_schema(parts[0]), _strip_schema(parts[1]))
|
|
121
|
+
if len(parts) == 3:
|
|
122
|
+
return (_strip_schema(parts[1]), _strip_schema(parts[2]))
|
|
123
|
+
return None
|
|
124
|
+
|
|
125
|
+
def _extract_eq_pairs(expr: Any, alias_to_table: dict[str, str], out: set[tuple[str, str]]) -> bool:
|
|
126
|
+
"""Recursively extract equality join pairs from an AST expression.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
|
|
130
|
+
expr: AST expression node (A_Expr or BoolExpr).
|
|
131
|
+
alias_to_table: Mapping from table alias to canonical table name.
|
|
132
|
+
out: Set to accumulate discovered (tableA.col, tableB.col) join pairs.
|
|
133
|
+
|
|
134
|
+
Returns:
|
|
135
|
+
|
|
136
|
+
True if the expression was fully traversed without unsupported constructs, False if a non-equality or disallowed expression was found.
|
|
137
|
+
"""
|
|
138
|
+
tag = getattr(expr, "__class__", type("x", (), {})).__name__
|
|
139
|
+
if tag == "BoolExpr":
|
|
140
|
+
if getattr(expr, "boolop", None) != 0:
|
|
141
|
+
return False
|
|
142
|
+
args = getattr(expr, "args", None)
|
|
143
|
+
if not isinstance(args, list):
|
|
144
|
+
return False
|
|
145
|
+
for a in args:
|
|
146
|
+
if not _extract_eq_pairs(a, alias_to_table, out):
|
|
147
|
+
return False
|
|
148
|
+
return True
|
|
149
|
+
if tag != "A_Expr":
|
|
150
|
+
return True
|
|
151
|
+
k = getattr(expr, "kind", None)
|
|
152
|
+
if isinstance(k, int):
|
|
153
|
+
if k != 0:
|
|
154
|
+
return False
|
|
155
|
+
else:
|
|
156
|
+
if getattr(k, "name", "") != "AEXPR_OP":
|
|
157
|
+
return False
|
|
158
|
+
name = getattr(expr, "name", None)
|
|
159
|
+
if not isinstance(name, list | tuple) or len(name) != 1:
|
|
160
|
+
return False
|
|
161
|
+
op = name[0]
|
|
162
|
+
if getattr(op, "__class__", type("x", (), {})).__name__ != "String":
|
|
163
|
+
return False
|
|
164
|
+
op_str = getattr(op, "sval", None)
|
|
165
|
+
if op_str != "=":
|
|
166
|
+
return True
|
|
167
|
+
lhs = _colref_to_parts(getattr(expr, "lexpr", None))
|
|
168
|
+
r = _colref_to_parts(getattr(expr, "rexpr", None))
|
|
169
|
+
if not lhs or not r:
|
|
170
|
+
return True
|
|
171
|
+
lt, lc = lhs
|
|
172
|
+
rt, rc = r
|
|
173
|
+
lt = alias_to_table.get(lt, lt) if lt else ""
|
|
174
|
+
rt = alias_to_table.get(rt, rt) if rt else ""
|
|
175
|
+
if not lt or not rt:
|
|
176
|
+
return True
|
|
177
|
+
if not lc or not rc:
|
|
178
|
+
return True
|
|
179
|
+
a = f"{_strip_schema(lt)}.{lc}"
|
|
180
|
+
b = f"{_strip_schema(rt)}.{rc}"
|
|
181
|
+
out.add(tuple(sorted((a, b))))
|
|
182
|
+
return True
|
|
183
|
+
|
|
184
|
+
def _collect_from_items(
|
|
185
|
+
fr: Any,
|
|
186
|
+
) -> tuple[bool, dict[str, str], bool, bool, bool, bool, bool]:
|
|
187
|
+
"""Collect table aliases and detect forbidden structures from a FROM clause.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
|
|
191
|
+
fr: FROM clause node or list of FROM clause items from the AST.
|
|
192
|
+
|
|
193
|
+
Returns:
|
|
194
|
+
|
|
195
|
+
Tuple of (ok, alias_to_table, has_subquery, has_using, has_cross_join, has_self_join, has_items) where ok is False if an unsupported node type was encountered.
|
|
196
|
+
"""
|
|
197
|
+
alias_to_table: dict[str, str] = {}
|
|
198
|
+
has_subquery = False
|
|
199
|
+
has_using = False
|
|
200
|
+
has_cross_join = False
|
|
201
|
+
has_self_join = False
|
|
202
|
+
seen_tables: set[str] = set()
|
|
203
|
+
ok = True
|
|
204
|
+
|
|
205
|
+
def add_alias(relname: str, alias: Any):
|
|
206
|
+
nonlocal alias_to_table, has_self_join, seen_tables
|
|
207
|
+
t = _strip_schema(relname)
|
|
208
|
+
if t in seen_tables:
|
|
209
|
+
has_self_join = True
|
|
210
|
+
seen_tables.add(t)
|
|
211
|
+
if alias is None:
|
|
212
|
+
alias_to_table[t] = t
|
|
213
|
+
return
|
|
214
|
+
an = getattr(alias, "aliasname", None)
|
|
215
|
+
if isinstance(an, str) and an:
|
|
216
|
+
alias_to_table[_strip_schema(an)] = t
|
|
217
|
+
alias_to_table[t] = t
|
|
218
|
+
|
|
219
|
+
def walk(item: Any):
|
|
220
|
+
nonlocal has_subquery, has_using, has_cross_join, ok
|
|
221
|
+
if item is None:
|
|
222
|
+
ok = False
|
|
223
|
+
return False
|
|
224
|
+
tag = getattr(item, "__class__", type("x", (), {})).__name__
|
|
225
|
+
if tag == "RangeVar":
|
|
226
|
+
add_alias(getattr(item, "relname", "") or "", getattr(item, "alias", None))
|
|
227
|
+
return True
|
|
228
|
+
if tag == "JoinExpr":
|
|
229
|
+
if getattr(item, "usingClause", None) is not None or getattr(item, "isNatural", False):
|
|
230
|
+
has_using = True
|
|
231
|
+
join_type = getattr(item, "jointype", None)
|
|
232
|
+
if join_type is not None and str(join_type) == "JoinType.JOIN_INNER":
|
|
233
|
+
quals = getattr(item, "quals", None)
|
|
234
|
+
if quals is None:
|
|
235
|
+
has_cross_join = True
|
|
236
|
+
if not walk(getattr(item, "larg", None)):
|
|
237
|
+
return False
|
|
238
|
+
if not walk(getattr(item, "rarg", None)):
|
|
239
|
+
return False
|
|
240
|
+
return True
|
|
241
|
+
if tag in {
|
|
242
|
+
"RangeSubselect",
|
|
243
|
+
"RangeFunction",
|
|
244
|
+
"RangeTableFunc",
|
|
245
|
+
"RangeTableSample",
|
|
246
|
+
}:
|
|
247
|
+
has_subquery = True
|
|
248
|
+
ok = False
|
|
249
|
+
return False
|
|
250
|
+
ok = False
|
|
251
|
+
return False
|
|
252
|
+
|
|
253
|
+
if fr is None:
|
|
254
|
+
return False, {}, False, False, False, False, False
|
|
255
|
+
for it in fr if isinstance(fr, list | tuple) else [fr]:
|
|
256
|
+
if not walk(it):
|
|
257
|
+
ok = False
|
|
258
|
+
break
|
|
259
|
+
return (
|
|
260
|
+
ok,
|
|
261
|
+
alias_to_table,
|
|
262
|
+
has_subquery,
|
|
263
|
+
has_using,
|
|
264
|
+
has_cross_join,
|
|
265
|
+
has_self_join,
|
|
266
|
+
True,
|
|
267
|
+
)
|
|
268
|
+
return (
|
|
269
|
+
ok,
|
|
270
|
+
alias_to_table,
|
|
271
|
+
has_subquery,
|
|
272
|
+
has_using,
|
|
273
|
+
has_cross_join,
|
|
274
|
+
has_self_join,
|
|
275
|
+
True,
|
|
276
|
+
)
|
|
277
|
+
|
|
278
|
+
def _normalized_join_pairs_from_sql_ast(
    sql: str,
) -> tuple[bool, set[tuple[str, str]]]:
    """Extract normalized join pairs from SQL using AST parsing and handle CTEs.

    The SQL is canonicalized and parsed with pglast; only the first statement is
    inspected and it must be a SelectStmt. Equality pairs are gathered from:
    every CTE body's JOIN ... ON conditions and WHERE clause, then the main
    query's JOIN ... ON conditions and WHERE clause. Pair extraction itself is
    delegated to ``_extract_eq_pairs`` and FROM-clause alias resolution to
    ``_collect_from_items``. Any failure short-circuits to ``(False, set())``.

    Args:
        sql: SQL query string to parse.

    Returns:
        Tuple of (ok, pairs) where ok is False if parsing failed or a forbidden construct was found, and pairs is the set of (tableA.col, tableB.col) tuples sorted so the lexicographically smaller element comes first.
    """
    s = canonicalize_sql(sql)
    if not s:
        debug("[dialect.normalized_join_pairs_from_sql_ast] empty_sql")
        return False, set()
    try:
        import pglast
    except ImportError:
        debug("[dialect.normalized_join_pairs_from_sql_ast] pglast_unavailable")
        return False, set()
    parse_sql = getattr(pglast, "parse_sql", None)
    if parse_sql is None:
        debug("[dialect.normalized_join_pairs_from_sql_ast] pglast_unavailable")
        return False, set()
    try:
        stmts = parse_sql(s)
    except Exception as e:
        debug(f"[dialect.normalized_join_pairs_from_sql_ast] parse_failed: {e}")
        return False, set()
    if not stmts:
        debug("[dialect.normalized_join_pairs_from_sql_ast] no_statements")
        return False, set()
    # Only the first statement is examined; any further statements are ignored here.
    first = stmts[0]
    root = getattr(first, "stmt", None)
    if root is None:
        debug("[dialect.normalized_join_pairs_from_sql_ast] no_stmt_attr")
        return False, set()
    if getattr(root, "__class__", type("x", (), {})).__name__ != "SelectStmt":
        debug("[dialect.normalized_join_pairs_from_sql_ast] not_select_stmt")
        return False, set()

    with_clause = getattr(root, "withClause", None)
    # NOTE(review): cte_names is populated below but never read in this function.
    cte_names: set[str] = set()
    all_pairs: set[tuple[str, str]] = set()

    if with_clause is not None:
        ctes = getattr(with_clause, "ctes", [])
        for cte in ctes:
            cte_name = getattr(getattr(cte, "ctename", None), "sval", None) or getattr(cte, "ctename", "")
            if cte_name:
                cte_names.add(cte_name.lower())

            cte_query = getattr(cte, "ctequery", None)
            if cte_query is None:
                continue

            # Each CTE body gets its own alias map; subqueries and USING/NATURAL joins abort.
            cte_from = getattr(cte_query, "fromClause", None)
            ok_cte_from, cte_alias_map, has_subq, has_using, _, _, _ = _collect_from_items(cte_from)
            if not ok_cte_from or has_subq or has_using:
                return False, set()

            def walk_cte_joins(item: Any, alias_map: dict[str, str], pairs: set[tuple[str, str]]) -> bool:
                # Recurse through nested JoinExpr nodes; a join without ON quals fails.
                if getattr(item, "__class__", type("x", (), {})).__name__ == "JoinExpr":
                    if getattr(item, "quals", None) is None:
                        return False
                    if not _extract_eq_pairs(getattr(item, "quals", None), alias_map, pairs):
                        return False
                    if not walk_cte_joins(getattr(item, "larg", None), alias_map, pairs):
                        return False
                    if not walk_cte_joins(getattr(item, "rarg", None), alias_map, pairs):
                        return False
                return True

            if cte_from:
                for it in cte_from if isinstance(cte_from, list | tuple) else [cte_from]:
                    if not walk_cte_joins(it, cte_alias_map, all_pairs):
                        return False, set()

            # Implicit joins expressed as WHERE equalities are collected as well.
            cte_where = getattr(cte_query, "whereClause", None)
            if cte_where is not None:
                tmp: set[tuple[str, str]] = set()
                if not _extract_eq_pairs(cte_where, cte_alias_map, tmp):
                    return False, set()
                all_pairs |= tmp

    # Main query: a missing FROM clause also makes _collect_from_items report not-ok.
    fr = getattr(root, "fromClause", None)
    ok_from, alias_to_table, has_subq, has_using, _, _, _ = _collect_from_items(fr)
    if not ok_from:
        debug("[dialect.normalized_join_pairs_from_sql_ast] collect_from_failed")
        return False, set()
    if has_subq:
        debug("[dialect.normalized_join_pairs_from_sql_ast] has_subquery")
        return False, set()
    if has_using:
        debug("[dialect.normalized_join_pairs_from_sql_ast] has_using")
        return False, set()

    def walk_join_ons(item: Any) -> bool:
        # Same traversal as walk_cte_joins, but against the main query's alias map
        # and with debug output on each failure.
        if getattr(item, "__class__", type("x", (), {})).__name__ == "JoinExpr":
            quals = getattr(item, "quals", None)
            if quals is None:
                debug("[dialect.normalized_join_pairs_from_sql_ast] missing_quals")
                return False
            if not _extract_eq_pairs(quals, alias_to_table, all_pairs):
                debug("[dialect.normalized_join_pairs_from_sql_ast] extract_pairs_failed")
                return False
            if not walk_join_ons(getattr(item, "larg", None)):
                return False
            if not walk_join_ons(getattr(item, "rarg", None)):
                return False
        return True

    # NOTE(review): unlike the CTE branch, a non-list/tuple fromClause is not walked here.
    for it in fr if isinstance(fr, list | tuple) else []:
        if not walk_join_ons(it):
            return False, set()

    where = getattr(root, "whereClause", None)
    if where is not None:
        tmp: set[tuple[str, str]] = set()
        if not _extract_eq_pairs(where, alias_to_table, tmp):
            debug("[dialect.normalized_join_pairs_from_sql_ast] where_extract_failed")
            return False, set()
        all_pairs |= tmp

    # NOTE(review): these checks look for a bare SelectStmt directly inside GROUP BY /
    # ORDER BY. In pglast, subqueries are presumably wrapped in SubLink / SortBy nodes,
    # so these loops may never match — TODO confirm.
    if getattr(root, "groupClause", None) is not None:
        for g in getattr(root, "groupClause", []) or []:
            if getattr(g, "__class__", type("x", (), {})).__name__ == "SelectStmt":
                return False, set()
    if getattr(root, "sortClause", None) is not None:
        for so in getattr(root, "sortClause", []) or []:
            if getattr(so, "__class__", type("x", (), {})).__name__ == "SelectStmt":
                return False, set()
    return True, all_pairs
|
|
413
|
+
|
|
414
|
+
def _validate_cte_bodies(with_clause: Any) -> tuple[bool, str]:
|
|
415
|
+
"""Validate CTE bodies against structural restrictions.
|
|
416
|
+
|
|
417
|
+
Forbids recursive CTEs, subqueries, window functions, set operations, and CASE expressions inside any CTE body.
|
|
418
|
+
|
|
419
|
+
Args:
|
|
420
|
+
|
|
421
|
+
with_clause: AST WithClause node or None.
|
|
422
|
+
|
|
423
|
+
Returns:
|
|
424
|
+
|
|
425
|
+
Tuple of (ok, error_code). ok is True when all CTEs pass validation. error_code is an empty string on success or a short string token describing the first violation found.
|
|
426
|
+
"""
|
|
427
|
+
if with_clause is None:
|
|
428
|
+
return True, ""
|
|
429
|
+
|
|
430
|
+
if getattr(with_clause, "recursive", False):
|
|
431
|
+
return False, "cte_recursive"
|
|
432
|
+
|
|
433
|
+
ctes = getattr(with_clause, "ctes", [])
|
|
434
|
+
if not ctes:
|
|
435
|
+
return True, ""
|
|
436
|
+
|
|
437
|
+
for cte in ctes:
|
|
438
|
+
cte_query = getattr(cte, "ctequery", None)
|
|
439
|
+
if cte_query is None:
|
|
440
|
+
return False, "cte_malformed"
|
|
441
|
+
|
|
442
|
+
def walk_cte(n):
|
|
443
|
+
tag = getattr(n, "__class__", type("x", (), {})).__name__
|
|
444
|
+
if tag in {"RangeSubselect", "SubLink"}:
|
|
445
|
+
if tag == "SubLink":
|
|
446
|
+
sublink_type = getattr(n, "subLinkType", None)
|
|
447
|
+
if sublink_type is not None and sublink_type == 0:
|
|
448
|
+
return "cte_contains_exists"
|
|
449
|
+
return "cte_contains_subquery"
|
|
450
|
+
if tag == "SetOperationStmt":
|
|
451
|
+
return "cte_contains_set_op"
|
|
452
|
+
if tag == "WindowDef":
|
|
453
|
+
return "cte_contains_window"
|
|
454
|
+
if tag == "CaseExpr":
|
|
455
|
+
return "cte_contains_case"
|
|
456
|
+
try:
|
|
457
|
+
attrs = vars(n)
|
|
458
|
+
except TypeError:
|
|
459
|
+
return None
|
|
460
|
+
for attr in attrs.values():
|
|
461
|
+
if isinstance(attr, list):
|
|
462
|
+
for x in attr:
|
|
463
|
+
if hasattr(x, "__class__"):
|
|
464
|
+
err = walk_cte(x)
|
|
465
|
+
if err:
|
|
466
|
+
return err
|
|
467
|
+
elif hasattr(attr, "__class__"):
|
|
468
|
+
err = walk_cte(attr)
|
|
469
|
+
if err:
|
|
470
|
+
return err
|
|
471
|
+
return None
|
|
472
|
+
|
|
473
|
+
err = walk_cte(cte_query)
|
|
474
|
+
if err:
|
|
475
|
+
return False, err
|
|
476
|
+
|
|
477
|
+
return True, ""
|
|
478
|
+
|
|
479
|
+
def _ast_structural_valid(sql: str) -> tuple[bool, str]:
|
|
480
|
+
"""Validate SQL structure using the pglast AST.
|
|
481
|
+
|
|
482
|
+
Checks that the SQL is a single SELECT statement free of subqueries, window functions, CASE expressions, CROSS JOINs, self-joins, USING clauses, EXISTS, LATERAL, and set operations. Also validates any CTE bodies.
|
|
483
|
+
|
|
484
|
+
Args:
|
|
485
|
+
|
|
486
|
+
sql: SQL query string to validate.
|
|
487
|
+
|
|
488
|
+
Returns:
|
|
489
|
+
|
|
490
|
+
Tuple of (ok, error_code). ok is True when the SQL passes all structural checks. error_code is an empty string on success or a short token identifying the first violation.
|
|
491
|
+
"""
|
|
492
|
+
try:
|
|
493
|
+
import pglast
|
|
494
|
+
except ImportError:
|
|
495
|
+
return False, "ast_unavailable"
|
|
496
|
+
parse_sql = getattr(pglast, "parse_sql", None)
|
|
497
|
+
if parse_sql is None:
|
|
498
|
+
return False, "ast_unavailable"
|
|
499
|
+
|
|
500
|
+
try:
|
|
501
|
+
stmts = parse_sql(canonicalize_sql(sql))
|
|
502
|
+
except Exception:
|
|
503
|
+
return False, "ast_parse_failed"
|
|
504
|
+
|
|
505
|
+
if not stmts or len(stmts) != 1:
|
|
506
|
+
return False, "multiple_statements"
|
|
507
|
+
|
|
508
|
+
root = getattr(stmts[0], "stmt", None)
|
|
509
|
+
if root is None:
|
|
510
|
+
return False, "no_root"
|
|
511
|
+
|
|
512
|
+
if getattr(root, "__class__", type("x", (), {})).__name__ != "SelectStmt":
|
|
513
|
+
return False, "not_select"
|
|
514
|
+
|
|
515
|
+
with_clause = getattr(root, "withClause", None)
|
|
516
|
+
has_cte = with_clause is not None
|
|
517
|
+
|
|
518
|
+
if has_cte:
|
|
519
|
+
ok, err = _validate_cte_bodies(with_clause)
|
|
520
|
+
if not ok:
|
|
521
|
+
return False, err
|
|
522
|
+
|
|
523
|
+
fr = getattr(root, "fromClause", None)
|
|
524
|
+
if fr is not None:
|
|
525
|
+
_, _, has_subq, has_using, has_cross, has_self, _ = _collect_from_items(fr)
|
|
526
|
+
if has_subq:
|
|
527
|
+
return False, "subquery_not_allowed"
|
|
528
|
+
if has_using:
|
|
529
|
+
return False, "using_not_allowed"
|
|
530
|
+
if has_cross:
|
|
531
|
+
return False, "cross_join_not_allowed"
|
|
532
|
+
if has_self:
|
|
533
|
+
return False, "self_join_not_allowed"
|
|
534
|
+
|
|
535
|
+
has_window = False
|
|
536
|
+
has_case = False
|
|
537
|
+
has_exists = False
|
|
538
|
+
has_lateral = False
|
|
539
|
+
|
|
540
|
+
def walk(n, check_window=True):
|
|
541
|
+
nonlocal has_window, has_case, has_exists, has_lateral
|
|
542
|
+
tag = getattr(n, "__class__", type("x", (), {})).__name__
|
|
543
|
+
if tag in {
|
|
544
|
+
"RangeSubselect",
|
|
545
|
+
"SubLink",
|
|
546
|
+
"SetOperationStmt",
|
|
547
|
+
}:
|
|
548
|
+
if tag == "SubLink":
|
|
549
|
+
sublink_type = getattr(n, "subLinkType", None)
|
|
550
|
+
if sublink_type is not None and sublink_type == 0:
|
|
551
|
+
has_exists = True
|
|
552
|
+
return False
|
|
553
|
+
if tag == "CaseExpr":
|
|
554
|
+
has_case = True
|
|
555
|
+
return False
|
|
556
|
+
if tag == "RangeFunction":
|
|
557
|
+
is_lateral = getattr(n, "lateral", False)
|
|
558
|
+
if is_lateral:
|
|
559
|
+
has_lateral = True
|
|
560
|
+
return False
|
|
561
|
+
if check_window and tag == "WindowDef":
|
|
562
|
+
has_window = True
|
|
563
|
+
return False
|
|
564
|
+
if check_window and tag == "FuncCall":
|
|
565
|
+
over = getattr(n, "over", None)
|
|
566
|
+
if over is not None:
|
|
567
|
+
has_window = True
|
|
568
|
+
return False
|
|
569
|
+
|
|
570
|
+
try:
|
|
571
|
+
attrs = vars(n)
|
|
572
|
+
except TypeError:
|
|
573
|
+
return True
|
|
574
|
+
|
|
575
|
+
for attr in attrs.values():
|
|
576
|
+
if isinstance(attr, list):
|
|
577
|
+
for x in attr:
|
|
578
|
+
if hasattr(x, "__class__") and not walk(x, check_window):
|
|
579
|
+
return False
|
|
580
|
+
elif hasattr(attr, "__class__"):
|
|
581
|
+
if not walk(attr, check_window):
|
|
582
|
+
return False
|
|
583
|
+
return True
|
|
584
|
+
|
|
585
|
+
if not walk(root):
|
|
586
|
+
if has_window:
|
|
587
|
+
return False, "window_not_allowed"
|
|
588
|
+
if has_case:
|
|
589
|
+
return False, "case_when_not_allowed"
|
|
590
|
+
if has_exists:
|
|
591
|
+
return False, "exists_not_allowed"
|
|
592
|
+
if has_lateral:
|
|
593
|
+
return False, "lateral_not_allowed"
|
|
594
|
+
return False, "forbidden_structure"
|
|
595
|
+
|
|
596
|
+
return True, ""
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
class Dialect:
|
|
600
|
+
"""Base dialect interface for database-specific operations."""
|
|
601
|
+
|
|
602
|
+
name: str = "base"
|
|
603
|
+
|
|
604
|
+
def __init__(self, config):
|
|
605
|
+
"""Initialize the dialect with a runtime configuration object.
|
|
606
|
+
|
|
607
|
+
Args:
|
|
608
|
+
|
|
609
|
+
config: Runtime config instance, such as PostgresRuntimeConfig or DatabricksRuntimeConfig, providing connection details.
|
|
610
|
+
"""
|
|
611
|
+
self.config = config
|
|
612
|
+
|
|
613
|
+
def ast_validate(self, sql: str) -> tuple[bool, str]:
|
|
614
|
+
"""Validate SQL structure via AST parsing without a database connection.
|
|
615
|
+
|
|
616
|
+
Args:
|
|
617
|
+
|
|
618
|
+
sql: SQL query string to validate.
|
|
619
|
+
|
|
620
|
+
Returns:
|
|
621
|
+
|
|
622
|
+
Tuple of (ok, error_code). ok is True when the SQL passes structural checks.
|
|
623
|
+
"""
|
|
624
|
+
raise NotImplementedError
|
|
625
|
+
|
|
626
|
+
def extract_join_pairs(self, sql: str) -> tuple[bool, set[tuple[str, str]]]:
|
|
627
|
+
"""Extract normalized join column pairs from the SQL AST.
|
|
628
|
+
|
|
629
|
+
Args:
|
|
630
|
+
|
|
631
|
+
sql: SQL query string to parse.
|
|
632
|
+
|
|
633
|
+
Returns:
|
|
634
|
+
|
|
635
|
+
Tuple of (ok, pairs) where pairs is a set of sorted (tableA.col, tableB.col) tuples found in ON conditions.
|
|
636
|
+
"""
|
|
637
|
+
raise NotImplementedError
|
|
638
|
+
|
|
639
|
+
def extract_cte_bodies(self, sql: str) -> dict[str, str]:
|
|
640
|
+
"""Extract the SQL body for each named CTE from a WITH clause.
|
|
641
|
+
|
|
642
|
+
Args:
|
|
643
|
+
|
|
644
|
+
sql: SQL query string that may contain a WITH clause.
|
|
645
|
+
|
|
646
|
+
Returns:
|
|
647
|
+
|
|
648
|
+
Dictionary mapping lowercase CTE name to its inner SELECT SQL string.
|
|
649
|
+
"""
|
|
650
|
+
raise NotImplementedError
|
|
651
|
+
|
|
652
|
+
def explain_sql(self, engine: Any, sql: str, params: dict[str, Any] | None = None) -> tuple[bool, str]:
|
|
653
|
+
"""Test SQL executability by running EXPLAIN against the database.
|
|
654
|
+
|
|
655
|
+
Args:
|
|
656
|
+
|
|
657
|
+
engine: SQLAlchemy Engine connected to the target database.
|
|
658
|
+
sql: SQL query string to explain.
|
|
659
|
+
params: Optional bind-parameter values keyed by placeholder name.
|
|
660
|
+
|
|
661
|
+
Returns:
|
|
662
|
+
|
|
663
|
+
Tuple of (ok, error_message). ok is True when EXPLAIN succeeds.
|
|
664
|
+
"""
|
|
665
|
+
raise NotImplementedError
|
|
666
|
+
|
|
667
|
+
def reflect_enums(self, engine: Any, schema_name: str) -> dict[str, list[str]]:
|
|
668
|
+
"""Discover enum types and their allowed values from the database.
|
|
669
|
+
|
|
670
|
+
Args:
|
|
671
|
+
|
|
672
|
+
engine: SQLAlchemy Engine connected to the target database.
|
|
673
|
+
schema_name: Database schema name to introspect.
|
|
674
|
+
|
|
675
|
+
Returns:
|
|
676
|
+
|
|
677
|
+
Dictionary mapping enum type name to an ordered list of its values.
|
|
678
|
+
"""
|
|
679
|
+
raise NotImplementedError
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
class PostgresDialect(Dialect):
|
|
683
|
+
"""PostgreSQL-specific dialect implementation."""
|
|
684
|
+
|
|
685
|
+
name: str = "postgresql"
|
|
686
|
+
|
|
687
|
+
def __init__(self, config):
    """Store *config* and build a SQLAlchemy engine from ``config.db_url()``.

    Args:
        config: PostgresRuntimeConfig instance providing connection credentials.
    """
    super().__init__(config)
    from sqlalchemy import create_engine

    url = config.db_url()
    self.engine = create_engine(url)
|
|
698
|
+
|
|
699
|
+
def ast_validate(self, sql: str) -> tuple[bool, str]:
    """Run the pglast-based structural checks on *sql* without touching the database."""
    ok, error_code = _ast_structural_valid(sql)
    return ok, error_code
|
|
702
|
+
|
|
703
|
+
def extract_join_pairs(self, sql: str) -> tuple[bool, set[tuple[str, str]]]:
    """Return ``(ok, pairs)`` of normalized join column pairs read from the pglast AST."""
    ok, pairs = _normalized_join_pairs_from_sql_ast(sql)
    return ok, pairs
|
|
706
|
+
|
|
707
|
+
def extract_cte_bodies(self, sql: str) -> dict[str, str]:
    """Extract the SQL body of each CTE in a WITH clause using the pglast AST.

    Each CTE's ``ctequery`` node is rendered back to SQL with pglast's
    ``IndentedStream``, which accepts an AST node directly. ``pglast.prettify``
    expects SQL text, so it cannot be used here.

    Fallbacks:
    - CTEs that fail to render are filled in from the depth-tracking scanner.
    - If pglast is unavailable, the text yields no statement, or
      parsing/extraction raises, the whole result comes from that scanner.

    Args:
        sql: SQL query string that may contain a WITH clause.

    Returns:
        Dictionary mapping lowercase CTE name to its body SQL. Empty when the
        parsed statement has no WITH clause.
    """
    try:
        import pglast
    except ImportError:
        return self._extract_cte_bodies_depth_tracking(sql)
    parse_sql = getattr(pglast, "parse_sql", None)
    if parse_sql is None:
        return self._extract_cte_bodies_depth_tracking(sql)

    def render(node: Any) -> str:
        # Importing pglast.printers registers the node printers used by the stream.
        from pglast import printers  # noqa: F401
        from pglast.stream import IndentedStream

        return IndentedStream()(node)

    cte_bodies: dict[str, str] = {}
    try:
        stmts = parse_sql(canonicalize_sql(sql))
        if not stmts:
            return self._extract_cte_bodies_depth_tracking(sql)
        root = getattr(stmts[0], "stmt", None)
        if root is None:
            return self._extract_cte_bodies_depth_tracking(sql)
        with_clause = getattr(root, "withClause", None)
        if with_clause is None:
            return cte_bodies
        unrendered: list[str] = []
        for cte in getattr(with_clause, "ctes", None) or ():
            cte_name = getattr(getattr(cte, "ctename", None), "sval", None) or getattr(cte, "ctename", "")
            cte_query = getattr(cte, "ctequery", None)
            if not cte_name or cte_query is None:
                continue
            try:
                cte_bodies[cte_name.lower()] = render(cte_query)
            except Exception as exc:
                debug(f"[dialect.extract_cte_bodies] prettify failed for CTE '{cte_name}': {exc}")
                unrendered.append(cte_name.lower())
        if unrendered:
            # Fill CTEs that pglast could not render from the character-level scanner.
            fallback = self._extract_cte_bodies_depth_tracking(sql)
            for name in unrendered:
                if name in fallback:
                    cte_bodies[name] = fallback[name]
        if cte_bodies:
            debug(
                f"[dialect.extract_cte_bodies] pglast extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}"
            )
        return cte_bodies
    except Exception as e:
        debug(f"[dialect.extract_cte_bodies] pglast failed: {e}, falling back to depth-tracking")
        return self._extract_cte_bodies_depth_tracking(sql)
|
|
747
|
+
|
|
748
|
+
def _extract_cte_bodies_depth_tracking(self, sql: str) -> dict[str, str]:
    """Extract CTE bodies with a parenthesis-depth-tracking character scanner.

    Text-based fallback for when the AST path is unavailable or fails. It
    understands ``WITH [RECURSIVE]``, optional CTE column lists
    (``name (a, b) AS (``), the ``[NOT] MATERIALIZED`` modifier, quoted
    strings/identifiers with doubled-quote escapes, and ``--`` / ``/* */``
    comments. Parentheses inside quotes or comments do not affect depth.

    Args:
        sql: Full SQL text that may begin with a WITH clause.

    Returns:
        Mapping of lower-cased CTE name to its stripped body text, in
        definition order. Empty when ``sql`` does not start with WITH.
    """
    cte_bodies: dict[str, str] = {}
    # The lookahead keeps a CTE literally named "recursive" (``WITH recursive AS (``) parseable.
    with_match = re.match(r"\s*WITH\s+(?:RECURSIVE\s+(?![\s(]|AS\b))?", sql, re.IGNORECASE)
    if not with_match:
        return cte_bodies
    header_re = re.compile(
        r"(\w+)(?:\s*\([^()]*\))?\s+AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(",
        re.IGNORECASE,
    )
    quote_chars = ("'", '"')
    length = len(sql)

    def comment_end(index: int) -> int | None:
        """Return where a comment starting at ``index`` ends, or None if none starts there."""
        if sql.startswith("--", index):
            newline = sql.find("\n", index)
            return length if newline == -1 else newline
        if sql.startswith("/*", index):
            close = sql.find("*/", index + 2)
            return length if close == -1 else close + 2
        return None

    pos = with_match.end()
    while pos < length:
        # Skip separators and comments between CTE definitions.
        while pos < length:
            if sql[pos] in " \t\n\r,":
                pos += 1
                continue
            after_comment = comment_end(pos)
            if after_comment is None:
                break
            pos = after_comment
        if pos >= length:
            break
        name_match = header_re.match(sql, pos)
        if not name_match:
            break
        cte_name = name_match.group(1).lower()
        pos = name_match.end()
        start_pos = pos
        depth = 1
        quote: str | None = None
        while pos < length and depth > 0:
            c = sql[pos]
            if quote is not None:
                if c == quote:
                    if pos + 1 < length and sql[pos + 1] == quote:
                        pos += 1  # doubled quote is an escaped quote, stay inside
                    else:
                        quote = None
                pos += 1
                continue
            after_comment = comment_end(pos)
            if after_comment is not None:
                pos = after_comment
                continue
            if c in quote_chars:
                quote = c
            elif c == "(":
                depth += 1
            elif c == ")":
                depth -= 1
            pos += 1
        if depth == 0:
            # ``pos - 1`` is the CTE's closing parenthesis.
            cte_bodies[cte_name] = sql[start_pos : pos - 1].strip()
        rest = sql[pos:].strip().upper()
        if rest.startswith("SELECT") or rest.startswith("(SELECT"):
            break
    debug(
        f"[dialect.extract_cte_bodies] depth-tracking extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}"
    )
    return cte_bodies
|
|
798
|
+
|
|
799
|
+
def explain_sql(self, engine: Any, sql: str, params: dict[str, Any] | None = None) -> tuple[bool, str]:
    """Check whether ``sql`` is executable by asking PostgreSQL to EXPLAIN it.

    Args:
        engine: SQLAlchemy engine; a connection is opened via ``engine.connect()``.
        sql: Statement to check; it is prefixed with ``EXPLAIN`` verbatim.
        params: Bind parameters for the statement; ``None`` means no parameters.

    Returns:
        ``(True, "")`` when EXPLAIN succeeds, otherwise ``(False, <error text>)``.
    """
    from sqlalchemy import text

    try:
        statement = text(f"EXPLAIN {sql}")
        with engine.connect() as connection:
            connection.execute(statement, params or {})
    except Exception as exc:
        return False, str(exc)
    else:
        return True, ""
|
|
809
|
+
|
|
810
|
+
def reflect_enums(self, engine: Any, schema_name: str) -> dict[str, list[str]]:
    """Map each PostgreSQL enum type in ``schema_name`` to its ordered labels.

    Returns an empty dict unless ``EngineConfig.TYPE`` is ``"postgresql"``.
    Labels are listed in ``enumsortorder`` order. A failure during reflection
    is logged via ``debug``, and whatever was collected up to that point is
    returned.

    Args:
        engine: SQLAlchemy engine used to open a connection.
        schema_name: Namespace (``pg_namespace.nspname``) to inspect.

    Returns:
        Dict of enum type name to its label list.
    """
    if EngineConfig.TYPE != "postgresql":
        return {}

    from sqlalchemy import text

    labels_by_type: dict[str, list[str]] = {}
    try:
        with engine.connect() as connection:
            enum_query = text("""
SELECT t.typname, e.enumlabel
FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
JOIN pg_namespace n ON t.typnamespace = n.oid
WHERE n.nspname = :schema
ORDER BY t.typname, e.enumsortorder
""")
            result = connection.execute(enum_query, {"schema": schema_name})
            for row in result:
                type_name, label = row[0], row[1]
                labels_by_type.setdefault(type_name, []).append(label)
    except Exception as exc:
        debug(f"[dialect.reflect_enums] failed to reflect enums: {exc}")
    return labels_by_type
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
class DatabricksDialect(Dialect):
    """Databricks dialect that validates and executes SQL through Spark.

    Statements run either over a ``databricks-sql-connector`` connection (when
    the runtime config reports one) or through a PySpark ``SparkSession``.
    After construction exactly one of ``self.connection`` / ``self.spark`` is
    set; every method below routes to whichever one is active.
    """

    name: str = "databricks"

    # ``FROM``/``JOIN`` followed by one optionally quoted identifier. The
    # lookahead rejects names that continue with ``.`` (already qualified) or
    # more word characters. It replaces a trailing ``\b``, which could never
    # match after a closing quote.
    _TABLE_REF_RE = re.compile(r'\b(FROM|JOIN)\s+(["`]?)(\w+)\2(?![\w.])', re.IGNORECASE)
    # Single-quoted string literal; Spark SQL strings use backslash escapes.
    _STRING_LITERAL_RE = re.compile(r"'(?:[^'\\]|\\.)*'", re.DOTALL)
    # CTE definitions: ``WITH [RECURSIVE] name [(cols)] AS (`` or ``, name [(cols)] AS (``.
    _CTE_NAME_RE = re.compile(
        r"(?:\bWITH\s+(?:RECURSIVE\s+)?|,\s*)(\w+)(?:\s*\([^()]*\))?\s+AS\s*\(",
        re.IGNORECASE,
    )

    def __init__(self, config):
        """Initialize the Databricks dialect with optional native connection.

        When ``config.has_native_connection()`` is True, creates a connection via
        ``databricks-sql-connector`` and skips SparkSession. Otherwise falls
        back to PySpark.

        Args:
            config: DatabricksRuntimeConfig instance providing catalog and schema.

        Raises:
            RuntimeError: If neither a native connection nor a SparkSession can be obtained.
        """
        super().__init__(config)

        self.connection = None
        self.spark = None

        if config.has_native_connection():
            try:
                from databricks import sql as dbsql

                self.connection = dbsql.connect(
                    server_hostname=config.SERVER_HOSTNAME,
                    http_path=config.HTTP_PATH,
                    access_token=config.ACCESS_TOKEN,
                )
            except Exception as e:
                raise RuntimeError(f"Failed to connect via databricks-sql-connector: {e}") from e
        else:
            try:
                from pyspark.sql import SparkSession

                self.spark = SparkSession.builder.getOrCreate()
            except Exception as e:
                raise RuntimeError(f"Failed to get SparkSession: {e}") from e

    def _run_explain(self, qualified_sql: str) -> None:
        """Run ``EXPLAIN`` on the active backend; backend exceptions propagate.

        The native cursor is closed even when ``execute``/``fetchall`` raise.
        """
        statement = f"EXPLAIN {qualified_sql}"
        if self.connection is not None:
            cursor = self.connection.cursor()
            try:
                cursor.execute(statement)
                cursor.fetchall()
            finally:
                cursor.close()
            return
        self.spark.sql(statement).collect()

    def ast_validate(self, sql: str) -> tuple[bool, str]:
        """Validate SQL syntax via Spark EXPLAIN on native connection or PySpark.

        Only errors whose message mentions "syntax" or "parse" fail validation.
        Any other error (e.g. a missing table) is logged via ``debug``, and the
        SQL is treated as syntactically valid.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, message)``.
        """
        qualified = self._qualify_table_references(sql)
        backend = "Databricks" if self.connection is not None else "Spark"
        try:
            self._run_explain(qualified)
        except Exception as e:
            err_str = str(e).lower()
            if "syntax" in err_str or "parse" in err_str:
                return False, f"{backend} syntax error: {e}"
            debug(f"[dialect.ast_validate] non-syntax {backend} error (treating as valid): {e}")
        return True, ""

    def extract_join_pairs(self, sql: str) -> tuple[bool, set[tuple[str, str]]]:
        """Extract join pairs from Spark SQL using sqlglot.

        Returns:
            ``(True, pairs)`` on success, ``(False, set())`` if extraction raised.
        """
        try:
            pairs = _extract_join_pairs_sqlglot(sql)
            return True, pairs
        except Exception as e:
            debug(f"[dialect.extract_join_pairs] failed: {e}")
            return False, set()

    def extract_cte_bodies(self, sql: str) -> dict[str, str]:
        """Extract SQL body for each CTE from a WITH clause using sqlglot.

        Falls back to the depth-tracking scanner when sqlglot yields nothing.
        """
        cte_bodies = _extract_cte_bodies_sqlglot(sql)
        if cte_bodies:
            debug(
                f"[dialect.extract_cte_bodies] sqlglot extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}"
            )
            return cte_bodies
        return self._extract_cte_bodies_depth_tracking(sql)

    def _extract_cte_bodies_depth_tracking(self, sql: str) -> dict[str, str]:
        """Extract CTE bodies with a parenthesis-depth-tracking character scanner.

        Handles ``WITH [RECURSIVE]``, optional CTE column lists, the
        ``[NOT] MATERIALIZED`` modifier, ``--`` / ``/* */`` comments, and
        quoting in Spark SQL style: single- and double-quoted strings with
        backslash escapes, and backtick-quoted identifiers. Parentheses inside
        quotes or comments do not affect depth.

        Returns:
            Mapping of lower-cased CTE name to its stripped body text; empty
            when ``sql`` does not start with WITH.
        """
        cte_bodies: dict[str, str] = {}
        with_match = re.match(r"\s*WITH\s+(?:RECURSIVE\s+(?![\s(]|AS\b))?", sql, re.IGNORECASE)
        if not with_match:
            return cte_bodies
        header_re = re.compile(
            r"(\w+)(?:\s*\([^()]*\))?\s+AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(",
            re.IGNORECASE,
        )
        quote_chars = ("'", '"', "`")
        length = len(sql)

        def comment_end(index: int) -> int | None:
            """Return where a comment starting at ``index`` ends, or None if none starts there."""
            if sql.startswith("--", index):
                newline = sql.find("\n", index)
                return length if newline == -1 else newline
            if sql.startswith("/*", index):
                close = sql.find("*/", index + 2)
                return length if close == -1 else close + 2
            return None

        pos = with_match.end()
        while pos < length:
            while pos < length:
                if sql[pos] in " \t\n\r,":
                    pos += 1
                    continue
                after_comment = comment_end(pos)
                if after_comment is None:
                    break
                pos = after_comment
            if pos >= length:
                break
            name_match = header_re.match(sql, pos)
            if not name_match:
                break
            cte_name = name_match.group(1).lower()
            pos = name_match.end()
            start_pos = pos
            depth = 1
            quote: str | None = None
            while pos < length and depth > 0:
                c = sql[pos]
                if quote is not None:
                    if c == "\\" and quote != "`":
                        pos += 2  # backslash escapes the next character in Spark strings
                        continue
                    if c == quote:
                        if pos + 1 < length and sql[pos + 1] == quote:
                            pos += 1  # doubled quote stays inside the quoted span
                        else:
                            quote = None
                    pos += 1
                    continue
                after_comment = comment_end(pos)
                if after_comment is not None:
                    pos = after_comment
                    continue
                if c in quote_chars:
                    quote = c
                elif c == "(":
                    depth += 1
                elif c == ")":
                    depth -= 1
                pos += 1
            if depth == 0:
                cte_bodies[cte_name] = sql[start_pos : pos - 1].strip()
            rest = sql[pos:].strip().upper()
            if rest.startswith("SELECT") or rest.startswith("(SELECT"):
                break
        debug(
            f"[dialect.extract_cte_bodies] depth-tracking extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}"
        )
        return cte_bodies

    def explain_sql(self, engine: Any, sql: str, params: dict[str, Any] | None = None) -> tuple[bool, str]:
        """Test SQL executability via Spark EXPLAIN on the active backend.

        Args:
            engine: Unused; kept for interface parity with other dialects.
            sql: SQL to check; table references are qualified first.
            params: Unused; kept for interface parity with other dialects.

        Returns:
            ``(True, "")`` when EXPLAIN succeeds, otherwise ``(False, error text)``.
        """
        try:
            # Routed through _run_explain so the native connection is used when
            # present; the SparkSession is None in that mode.
            self._run_explain(self._qualify_table_references(sql))
            return True, ""
        except Exception as e:
            return False, str(e)

    def reflect_enums(self, engine: Any, schema_name: str) -> dict[str, list[str]]:
        """Databricks does not support native enums."""
        return {}

    def prepare_for_execution(self, sql: str) -> str:
        """Qualify table references with catalog.schema for Spark execution."""
        return self._qualify_table_references(sql)

    def _qualify_table_references(self, sql: str) -> str:
        """Add catalog.schema prefix to physical table references only.

        Rewrites ``FROM t`` / ``JOIN t`` (optionally quoted with ``"`` or a
        backtick) to a fully qualified reference, quoting each part with the
        original quote character (a backtick by default). The following are
        left unchanged:
          - names in ``DATABRICKS_TABLE_QUALIFY_SKIP_IDENTIFIERS``;
          - names defined as CTEs in this statement;
          - names already followed by ``.`` (already qualified);
          - anything inside single-quoted string literals.

        Only the first table after ``FROM`` is matched (comma-separated FROM
        lists are not expanded).

        NOTE(review): a CTE that shadows a physical table of the same name
        (e.g. ``WITH orders AS (SELECT * FROM orders)``) leaves the inner
        reference unqualified too.
        """
        catalog = self.config.CATALOG
        schema_name = self.config.SCHEMA
        cte_names = {m.group(1).lower() for m in self._CTE_NAME_RE.finditer(sql)}

        def replace_table(match: re.Match[str]) -> str:
            keyword = match.group(1)
            quote = match.group(2) or "`"
            table = match.group(3)
            lowered = table.lower()
            if lowered in DATABRICKS_TABLE_QUALIFY_SKIP_IDENTIFIERS or lowered in cte_names:
                return match.group(0)
            return (
                f"{keyword} {quote}{catalog}{quote}.{quote}{schema_name}{quote}."
                f"{quote}{table}{quote}"
            )

        # Rewrite only the code between string literals; literals are copied verbatim.
        pieces: list[str] = []
        cursor = 0
        for literal in self._STRING_LITERAL_RE.finditer(sql):
            pieces.append(self._TABLE_REF_RE.sub(replace_table, sql[cursor : literal.start()]))
            pieces.append(literal.group(0))
            cursor = literal.end()
        pieces.append(self._TABLE_REF_RE.sub(replace_table, sql[cursor:]))
        return "".join(pieces)

    def execute_sql_spark(self, sql: str, sql_already_spark: bool = False) -> list[tuple]:
        """Execute SQL on Databricks via native connector or Spark.

        Args:
            sql: SQL to run; table references are qualified first.
            sql_already_spark: Accepted for compatibility; not used by this method.

        Returns:
            Result rows converted to tuples.
        """
        spark_sql = self._qualify_table_references(sql)

        if self.config.DEBUG:
            print(f"[DEBUG] Executing on Databricks:\n{spark_sql}")

        if self.connection is not None:
            cursor = self.connection.cursor()
            try:
                cursor.execute(spark_sql)
                rows = cursor.fetchall()
            finally:
                cursor.close()
        else:
            rows = self.spark.sql(spark_sql).collect()

        return [tuple(row) for row in rows]
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
def render_date_diff_expr(
    dialect_type: str, left_expr: str, op: str, unit: str, amount: int
) -> str:
    """Build a predicate that compares a date difference against a threshold.

    For PostgreSQL the threshold is an ``INTERVAL`` literal (the unit is
    pluralised unless ``amount`` is 1). Every other dialect compares against a
    plain day count, using 7/30/365 days for week/month/year. Any other unit
    is treated as days.

    Args:
        dialect_type: One of 'postgresql' or 'databricks'.
        left_expr: SQL expression for date subtraction (e.g. col1 - col2).
        op: Comparison operator.
        unit: 'day', 'week', 'month', or 'year'.
        amount: Numeric amount for the interval.

    Returns:
        SQL predicate string.
    """
    if dialect_type == "postgresql":
        plural = "" if amount == 1 else "s"
        return f"({left_expr}) {op} INTERVAL '{amount} {unit}{plural}'"

    days_per_unit = {"week": 7, "month": 30, "year": 365}
    factor = days_per_unit.get(unit)
    threshold = amount if factor is None else amount * factor
    return f"({left_expr}) {op} {threshold}"
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
def render_date_window_expr(dialect_type: str, column: str, op: str, unit: str, offset: int) -> str:
    """Render a dialect-specific date window filter expression.

    Args:
        dialect_type: One of 'postgresql' or 'databricks'. Any value other than
            'postgresql' is rendered with Databricks (Spark SQL) functions.
        column: Column reference string to compare against the date boundary.
        op: Comparison operator, for example '>=' or '<'.
        unit: Date unit: 'day', 'week', 'month', or 'year'. The sub-day units
            'hour', 'minute' and 'second' are anchored on the current timestamp.
        offset: Number of units in the past. 0 means the start of the current
            unit; a negative value yields a boundary in the future.

    Returns:
        SQL expression string suitable for use in a WHERE clause.
    """
    # Sub-day units need a timestamp anchor: CURRENT_DATE / current_date() is
    # midnight, so truncating to the hour or subtracting hours from it would
    # land at the wrong instant.
    sub_day = unit in ("hour", "minute", "second")

    if dialect_type == "postgresql":
        anchor = "CURRENT_TIMESTAMP" if sub_day else "CURRENT_DATE"
        if offset == 0:
            return f"{column} {op} DATE_TRUNC('{unit}', {anchor})"
        suffix = "s" if offset != 1 else ""
        return f"{column} {op} {anchor} - INTERVAL '{offset} {unit}{suffix}'"

    if offset == 0:
        anchor = "current_timestamp()" if sub_day else "current_date()"
        return f"{column} {op} date_trunc('{unit}', {anchor})"
    # Negations are rendered as ``{-n}`` rather than ``-{n}`` so a negative
    # offset never produces ``--n``, which SQL parses as a line comment.
    if unit == "day":
        return f"{column} {op} date_sub(current_date(), {offset})"
    if unit == "week":
        return f"{column} {op} date_sub(current_date(), {offset * 7})"
    if unit == "month":
        return f"{column} {op} add_months(current_date(), {-offset})"
    if unit == "year":
        return f"{column} {op} add_months(current_date(), {-offset * 12})"
    if unit == "hour":
        return f"{column} {op} (current_timestamp() - INTERVAL '{offset} HOURS')"
    if unit == "minute":
        return f"{column} {op} (current_timestamp() - INTERVAL '{offset} MINUTES')"
    if unit == "second":
        return f"{column} {op} (current_timestamp() - INTERVAL '{offset} SECONDS')"
    return f"{column} {op} date_sub(current_date(), {offset})"
|
|
1106
|
+
|
|
1107
|
+
|
|
1108
|
+
def get_dialect(engine_type: str | None = None, config=None) -> Dialect:
    """Return a Dialect instance for the specified engine type.

    Args:
        engine_type: One of 'postgresql' or 'databricks'. Defaults to EngineConfig.TYPE.
        config: Runtime config object. Defaults to EngineConfig.RUNTIME.

    Returns:
        A new PostgresDialect or DatabricksDialect constructed with ``config``.

    Raises:
        ValueError: If engine_type is not a supported dialect name.
        RuntimeError: Propagated from DatabricksDialect when neither a native
            connection nor a SparkSession can be obtained.
    """
    if engine_type is None:
        engine_type = EngineConfig.TYPE
    if config is None:
        config = EngineConfig.RUNTIME

    if engine_type == "postgresql":
        return PostgresDialect(config)
    if engine_type == "databricks":
        return DatabricksDialect(config)
    raise ValueError(f"Unsupported dialect: {engine_type}")
|