aetherdialect 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
text2sql/dialect.py ADDED
@@ -0,0 +1,1134 @@
1
+ """Database dialect abstraction for SQL validation and introspection.
2
+
3
+ Provides AST-based and EXPLAIN-based SQL validation, join-pair extraction, CTE body parsing, and enum reflection for PostgreSQL and Databricks.
4
+
5
+ The module conditionally imports pglast for PostgreSQL and exposes a Dialect base class with dialect-specific implementations.
6
+
7
+ A factory function returns the active dialect based on EngineConfig.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import re
13
+ from typing import Any
14
+
15
+ from .config import DATABRICKS_TABLE_QUALIFY_SKIP_IDENTIFIERS, EngineConfig
16
+ from .core_utils import canonicalize_sql, debug
17
+
18
+
19
def _extract_join_pairs_sqlglot(sql: str) -> set[tuple[str, str]]:
    """Extract equality join pairs from Spark SQL using the sqlglot AST.

    Every ``x.col = y.col`` equality found inside a ``JOIN ... ON`` condition is
    recorded as a sorted ``("table.col", "table.col")`` tuple, with table aliases
    resolved to table names. Alias lookup is case-insensitive, matching Spark's
    identifier resolution; table and column names keep the spelling used in the
    query.

    Args:

        sql: Spark SQL query text.

    Returns:

        Set of sorted join-column pairs; empty when the SQL cannot be parsed.

    Raises:

        ImportError: If sqlglot is not installed (the import is not guarded).
    """
    import sqlglot
    from sqlglot import exp

    pairs: set[tuple[str, str]] = set()
    try:
        parsed = sqlglot.parse_one(sql, dialect="spark")
    except Exception:
        return pairs

    # Keys are lower-cased: Spark resolves identifiers case-insensitively, so
    # "FROM orders O ... ON o.id = ..." must still map "o" back to "orders".
    alias_to_table: dict[str, str] = {}
    for table in parsed.find_all(exp.Table):
        name = table.name
        alias_to_table[table.alias_or_name.lower()] = name
        alias_to_table[name.lower()] = name

    def _resolve(table_or_alias: str) -> str:
        return alias_to_table.get(table_or_alias.lower(), table_or_alias)

    def _add_pair(left: exp.Column, right: exp.Column) -> None:
        lt = left.table or ""
        lc = left.name or ""
        rt = right.table or ""
        rc = right.name or ""
        # Unqualified columns cannot be attributed to a table; skip them.
        if lt and lc and rt and rc:
            pair_a = f"{_resolve(lt)}.{lc}"
            pair_b = f"{_resolve(rt)}.{rc}"
            pairs.add(tuple(sorted([pair_a, pair_b])))

    for join in parsed.find_all(exp.Join):
        on_expr = join.args.get("on")
        if on_expr is None:
            continue
        for eq in on_expr.find_all(exp.EQ):
            left = eq.this
            right = eq.expression
            if isinstance(left, exp.Column) and isinstance(right, exp.Column):
                _add_pair(left, right)

    return pairs
63
+
64
+
65
def _extract_cte_bodies_sqlglot(sql: str) -> dict[str, str]:
    """Map each CTE name in a Spark SQL query to its body SQL, using sqlglot.

    Args:

        sql: Spark SQL query text, possibly starting with a WITH clause.

    Returns:

        Dictionary from lower-cased CTE alias to the CTE body re-rendered in
        the Spark dialect. Empty when parsing fails or there are no CTEs.
    """
    import sqlglot
    from sqlglot import exp

    try:
        tree = sqlglot.parse_one(sql, dialect="spark")
    except Exception:
        return {}

    # Later CTEs with the same (case-folded) name overwrite earlier ones.
    return {
        node.alias.lower(): node.this.sql(dialect="spark")
        for node in tree.find_all(exp.CTE)
        if node.alias and node.this
    }
82
+
83
+ if EngineConfig.TYPE == "postgresql":
84
+
85
+ def _strip_schema(ident: str) -> str:
86
+ """Strip schema prefix from identifier and return lowercase table name."""
87
+ s = (ident or "").strip().lower()
88
+ if "." in s:
89
+ s = s.split(".")[-1]
90
+ return s
91
+
92
+ def _colref_to_parts(n: Any) -> tuple[str, str] | None:
93
+ """Parse AST ColumnRef node into a (table, column) tuple.
94
+
95
+ Args:
96
+
97
+ n: An AST node, expected to be a ColumnRef with one to three String fields.
98
+
99
+ Returns:
100
+
101
+ Tuple of (table_alias_or_name, column_name), or None if the node is not a valid ColumnRef or contains unsupported field types.
102
+ """
103
+ tag = getattr(n, "__class__", type("x", (), {})).__name__
104
+ if tag != "ColumnRef":
105
+ return None
106
+ fields = getattr(n, "fields", None)
107
+ if not isinstance(fields, list | tuple) or not fields:
108
+ return None
109
+ parts: list[str] = []
110
+ for f in fields:
111
+ ftag = getattr(f, "__class__", type("x", (), {})).__name__
112
+ if ftag != "String":
113
+ return None
114
+ sval = getattr(f, "sval", None)
115
+ if isinstance(sval, str) and sval:
116
+ parts.append(sval)
117
+ if len(parts) == 1:
118
+ return ("", _strip_schema(parts[0]))
119
+ if len(parts) == 2:
120
+ return (_strip_schema(parts[0]), _strip_schema(parts[1]))
121
+ if len(parts) == 3:
122
+ return (_strip_schema(parts[1]), _strip_schema(parts[2]))
123
+ return None
124
+
125
def _extract_eq_pairs(expr: Any, alias_to_table: dict[str, str], out: set[tuple[str, str]]) -> bool:
    """Recursively extract equality join pairs from an AST expression.

    Args:

        expr: AST expression node. BoolExpr and A_Expr nodes are inspected;
            any other node type is accepted and ignored.
        alias_to_table: Mapping from table alias to canonical table name.
        out: Set to accumulate discovered (tableA.col, tableB.col) join pairs.

    Returns:

        True if the expression was fully traversed without unsupported
        constructs. False if an OR/NOT boolean expression, a non-operator
        A_Expr, or a malformed operator name was found.
    """
    tag = expr.__class__.__name__
    if tag == "BoolExpr":
        boolop = getattr(expr, "boolop", None)
        boolop_name = getattr(boolop, "name", None)
        # pglast exposes boolop as an enum member; a plain int is also accepted.
        # Only AND (AND_EXPR == 0) is allowed.
        is_and = boolop_name == "AND_EXPR" if isinstance(boolop_name, str) else boolop == 0
        if not is_and:
            return False
        args = getattr(expr, "args", None)
        # pglast stores node sequences as tuples; accepting only lists would
        # reject every AND condition.
        if not isinstance(args, list | tuple):
            return False
        return all(_extract_eq_pairs(a, alias_to_table, out) for a in args)
    if tag != "A_Expr":
        return True

    k = getattr(expr, "kind", None)
    if isinstance(k, int):
        # Covers plain ints and IntEnum members; AEXPR_OP == 0.
        if k != 0:
            return False
    elif getattr(k, "name", "") != "AEXPR_OP":
        return False

    name = getattr(expr, "name", None)
    if not isinstance(name, list | tuple) or len(name) != 1:
        return False
    op = name[0]
    if op.__class__.__name__ != "String":
        return False
    if getattr(op, "sval", None) != "=":
        # Non-equality operators are tolerated but contribute no pair.
        return True

    lhs = _colref_to_parts(getattr(expr, "lexpr", None))
    rhs = _colref_to_parts(getattr(expr, "rexpr", None))
    if not lhs or not rhs:
        return True
    lt, lc = lhs
    rt, rc = rhs
    lt = alias_to_table.get(lt, lt) if lt else ""
    rt = alias_to_table.get(rt, rt) if rt else ""
    # Unqualified columns cannot be attributed to a table.
    if not lt or not rt or not lc or not rc:
        return True
    a = f"{_strip_schema(lt)}.{lc}"
    b = f"{_strip_schema(rt)}.{rc}"
    out.add(tuple(sorted((a, b))))
    return True
183
+
184
def _collect_from_items(
    fr: Any,
) -> tuple[bool, dict[str, str], bool, bool, bool, bool, bool]:
    """Collect table aliases and detect forbidden structures from a FROM clause.

    Args:

        fr: FROM clause node, or a list/tuple of FROM clause items, from the
            pglast AST.

    Returns:

        Tuple of (ok, alias_to_table, has_subquery, has_using, has_cross_join,
        has_self_join, has_items).
        - ok is False when ``fr`` is None or an unsupported FROM item (subquery,
          function, table sample, or any other non-table, non-join node) was
          encountered.
        - has_items is False only when ``fr`` is None.
        Every FROM item is walked, not just the first.
    """
    alias_to_table: dict[str, str] = {}
    seen_tables: set[str] = set()
    has_subquery = False
    has_using = False
    has_cross_join = False
    has_self_join = False
    ok = True

    def is_inner_join(join_type: Any) -> bool:
        """Return True when ``join_type`` denotes JOIN_INNER.

        ``str(IntEnum member)`` returns the integer value on Python >= 3.11, so
        a ``str(...) == "JoinType.JOIN_INNER"`` comparison alone is not
        portable. Check the enum name first, then the raw value
        (JOIN_INNER == 0).
        """
        if join_type is None:
            return False
        label = getattr(join_type, "name", None)
        if isinstance(label, str):
            return label == "JOIN_INNER"
        if isinstance(join_type, int):
            return join_type == 0
        return str(join_type) == "JoinType.JOIN_INNER"

    def add_alias(relname: str, alias: Any) -> None:
        nonlocal has_self_join
        t = _strip_schema(relname)
        if t in seen_tables:
            # The same physical table appears more than once in this FROM clause.
            has_self_join = True
        seen_tables.add(t)
        if alias is None:
            alias_to_table[t] = t
            return
        an = getattr(alias, "aliasname", None)
        if isinstance(an, str) and an:
            alias_to_table[_strip_schema(an)] = t
        alias_to_table[t] = t

    def walk(item: Any) -> bool:
        nonlocal has_subquery, has_using, has_cross_join, ok
        if item is None:
            ok = False
            return False
        tag = item.__class__.__name__
        if tag == "RangeVar":
            add_alias(getattr(item, "relname", "") or "", getattr(item, "alias", None))
            return True
        if tag == "JoinExpr":
            if getattr(item, "usingClause", None) is not None or getattr(item, "isNatural", False):
                has_using = True
            # An INNER join with no ON qualifiers is what CROSS JOIN parses to.
            if is_inner_join(getattr(item, "jointype", None)) and getattr(item, "quals", None) is None:
                has_cross_join = True
            return walk(getattr(item, "larg", None)) and walk(getattr(item, "rarg", None))
        if tag in {
            "RangeSubselect",
            "RangeFunction",
            "RangeTableFunc",
            "RangeTableSample",
        }:
            has_subquery = True
        ok = False
        return False

    if fr is None:
        return False, {}, False, False, False, False, False
    for it in fr if isinstance(fr, list | tuple) else [fr]:
        if not walk(it):
            ok = False
            break
    return (
        ok,
        alias_to_table,
        has_subquery,
        has_using,
        has_cross_join,
        has_self_join,
        True,
    )
277
+
278
def _normalized_join_pairs_from_sql_ast(
    sql: str,
) -> tuple[bool, set[tuple[str, str]]]:
    """Extract normalized join pairs from SQL using pglast, including CTE bodies.

    Pairs are collected from every JOIN ... ON condition and from WHERE
    equalities, both in each CTE body and in the main SELECT. Any JOIN without
    an ON condition, a non-AND boolean combination, a subquery, or USING in a
    FROM clause makes extraction fail.

    Args:

        sql: SQL query string to parse.

    Returns:

        Tuple of (ok, pairs). ok is False if parsing failed or a forbidden
        construct was found. pairs is the set of (tableA.col, tableB.col)
        tuples, each sorted so the lexicographically smaller element comes
        first.
    """
    s = canonicalize_sql(sql)
    if not s:
        debug("[dialect.normalized_join_pairs_from_sql_ast] empty_sql")
        return False, set()
    try:
        import pglast
    except ImportError:
        debug("[dialect.normalized_join_pairs_from_sql_ast] pglast_unavailable")
        return False, set()
    parse_sql = getattr(pglast, "parse_sql", None)
    if parse_sql is None:
        debug("[dialect.normalized_join_pairs_from_sql_ast] pglast_unavailable")
        return False, set()
    try:
        stmts = parse_sql(s)
    except Exception as e:
        debug(f"[dialect.normalized_join_pairs_from_sql_ast] parse_failed: {e}")
        return False, set()
    if not stmts:
        debug("[dialect.normalized_join_pairs_from_sql_ast] no_statements")
        return False, set()
    root = getattr(stmts[0], "stmt", None)
    if root is None:
        debug("[dialect.normalized_join_pairs_from_sql_ast] no_stmt_attr")
        return False, set()
    if root.__class__.__name__ != "SelectStmt":
        debug("[dialect.normalized_join_pairs_from_sql_ast] not_select_stmt")
        return False, set()

    def as_items(node: Any) -> tuple[Any, ...] | list[Any]:
        """Return FROM items as a sequence, wrapping a lone node."""
        if node is None:
            return ()
        return node if isinstance(node, list | tuple) else (node,)

    def walk_joins(item: Any, alias_map: dict[str, str], pairs: set[tuple[str, str]], log: bool) -> bool:
        """Collect ON-clause pairs from a (possibly nested) JoinExpr tree."""
        if item.__class__.__name__ != "JoinExpr":
            return True
        quals = getattr(item, "quals", None)
        if quals is None:
            if log:
                debug("[dialect.normalized_join_pairs_from_sql_ast] missing_quals")
            return False
        if not _extract_eq_pairs(quals, alias_map, pairs):
            if log:
                debug("[dialect.normalized_join_pairs_from_sql_ast] extract_pairs_failed")
            return False
        return walk_joins(getattr(item, "larg", None), alias_map, pairs, log) and walk_joins(
            getattr(item, "rarg", None), alias_map, pairs, log
        )

    all_pairs: set[tuple[str, str]] = set()

    with_clause = getattr(root, "withClause", None)
    if with_clause is not None:
        for cte in getattr(with_clause, "ctes", []):
            cte_query = getattr(cte, "ctequery", None)
            if cte_query is None:
                continue
            cte_from = getattr(cte_query, "fromClause", None)
            ok_cte_from, cte_alias_map, has_subq, has_using, _, _, _ = _collect_from_items(cte_from)
            if not ok_cte_from or has_subq or has_using:
                return False, set()
            for it in as_items(cte_from):
                if not walk_joins(it, cte_alias_map, all_pairs, log=False):
                    return False, set()
            cte_where = getattr(cte_query, "whereClause", None)
            if cte_where is not None:
                tmp: set[tuple[str, str]] = set()
                if not _extract_eq_pairs(cte_where, cte_alias_map, tmp):
                    return False, set()
                all_pairs |= tmp

    fr = getattr(root, "fromClause", None)
    ok_from, alias_to_table, has_subq, has_using, _, _, _ = _collect_from_items(fr)
    if not ok_from:
        debug("[dialect.normalized_join_pairs_from_sql_ast] collect_from_failed")
        return False, set()
    if has_subq:
        debug("[dialect.normalized_join_pairs_from_sql_ast] has_subquery")
        return False, set()
    if has_using:
        debug("[dialect.normalized_join_pairs_from_sql_ast] has_using")
        return False, set()

    for it in as_items(fr):
        if not walk_joins(it, alias_to_table, all_pairs, log=True):
            return False, set()

    where = getattr(root, "whereClause", None)
    if where is not None:
        tmp_where: set[tuple[str, str]] = set()
        if not _extract_eq_pairs(where, alias_to_table, tmp_where):
            debug("[dialect.normalized_join_pairs_from_sql_ast] where_extract_failed")
            return False, set()
        all_pairs |= tmp_where

    # Defensive: reject GROUP BY / ORDER BY items that are whole SELECT statements.
    for clause_name in ("groupClause", "sortClause"):
        for entry in getattr(root, clause_name, None) or []:
            if entry.__class__.__name__ == "SelectStmt":
                return False, set()
    return True, all_pairs
413
+
414
def _validate_cte_bodies(with_clause: Any) -> tuple[bool, str]:
    """Validate CTE bodies against structural restrictions.

    Forbids recursive CTEs, subqueries, EXISTS, window functions, set
    operations (UNION/INTERSECT/EXCEPT) and CASE expressions anywhere inside a
    CTE body.

    Args:

        with_clause: AST WithClause node or None.

    Returns:

        Tuple of (ok, error_code). ok is True when all CTEs pass validation.
        error_code is an empty string on success, or a short token describing
        the first violation found.
    """
    import enum

    if with_clause is None:
        return True, ""
    if getattr(with_clause, "recursive", False):
        return False, "cte_recursive"
    ctes = getattr(with_clause, "ctes", [])
    if not ctes:
        return True, ""

    def enum_is(value: Any, label: str, number: int) -> bool:
        """Compare an enum-or-int AST field against a named enum member."""
        if value is None:
            return False
        name = getattr(value, "name", None)
        if isinstance(name, str):
            return name == label
        return value == number

    def children(node: Any) -> list[Any]:
        """Return the direct child values of an AST node or node sequence."""
        if isinstance(node, (list, tuple)):
            return list(node)
        # Leaves are not descended into. Enum members and classes are excluded
        # because their __dict__ leads back into the enum class.
        if node is None or isinstance(node, (str, bytes, int, float, enum.Enum, type)):
            return []
        attrs = getattr(node, "__dict__", None)
        if isinstance(attrs, dict):
            return list(attrs.values())
        # pglast nodes declare their fields in __slots__ and have no __dict__,
        # so vars() cannot be used on them.
        values: list[Any] = []
        for klass in type(node).__mro__:
            slots = vars(klass).get("__slots__", ())
            if isinstance(slots, str):
                slots = (slots,)
            values.extend(getattr(node, slot, None) for slot in slots)
        return values

    def first_violation(node: Any) -> str | None:
        tag = node.__class__.__name__
        if tag == "SubLink":
            if enum_is(getattr(node, "subLinkType", None), "EXISTS_SUBLINK", 0):
                return "cte_contains_exists"
            return "cte_contains_subquery"
        if tag == "RangeSubselect":
            return "cte_contains_subquery"
        if tag == "SetOperationStmt":
            return "cte_contains_set_op"
        # In raw parse trees a UNION is a SelectStmt whose op is not SETOP_NONE.
        if tag == "SelectStmt":
            op = getattr(node, "op", None)
            if op is not None and not enum_is(op, "SETOP_NONE", 0):
                return "cte_contains_set_op"
        if tag == "WindowDef":
            return "cte_contains_window"
        if tag == "CaseExpr":
            return "cte_contains_case"
        for child in children(node):
            err = first_violation(child)
            if err:
                return err
        return None

    for cte in ctes:
        cte_query = getattr(cte, "ctequery", None)
        if cte_query is None:
            return False, "cte_malformed"
        err = first_violation(cte_query)
        if err:
            return False, err

    return True, ""
478
+
479
def _ast_structural_valid(sql: str) -> tuple[bool, str]:
    """Validate SQL structure using the pglast AST.

    Checks that the SQL is a single SELECT statement free of subqueries,
    window functions, CASE expressions, CROSS JOINs, self-joins, USING
    clauses, EXISTS, LATERAL and set operations. Also validates any CTE
    bodies.

    Args:

        sql: SQL query string to validate.

    Returns:

        Tuple of (ok, error_code). ok is True when the SQL passes all
        structural checks. error_code is an empty string on success, or a
        short token identifying the first violation.
    """
    import enum

    try:
        import pglast
    except ImportError:
        return False, "ast_unavailable"
    parse_sql = getattr(pglast, "parse_sql", None)
    if parse_sql is None:
        return False, "ast_unavailable"

    try:
        stmts = parse_sql(canonicalize_sql(sql))
    except Exception:
        return False, "ast_parse_failed"

    if not stmts or len(stmts) != 1:
        return False, "multiple_statements"

    root = getattr(stmts[0], "stmt", None)
    if root is None:
        return False, "no_root"
    if root.__class__.__name__ != "SelectStmt":
        return False, "not_select"

    with_clause = getattr(root, "withClause", None)
    if with_clause is not None:
        ok, err = _validate_cte_bodies(with_clause)
        if not ok:
            return False, err

    fr = getattr(root, "fromClause", None)
    if fr is not None:
        _, _, has_subq, has_using, has_cross, has_self, _ = _collect_from_items(fr)
        if has_subq:
            return False, "subquery_not_allowed"
        if has_using:
            return False, "using_not_allowed"
        if has_cross:
            return False, "cross_join_not_allowed"
        if has_self:
            return False, "self_join_not_allowed"

    def enum_is(value: Any, label: str, number: int) -> bool:
        """Compare an enum-or-int AST field against a named enum member."""
        if value is None:
            return False
        name = getattr(value, "name", None)
        if isinstance(name, str):
            return name == label
        return value == number

    def children(node: Any) -> list[Any]:
        """Return the direct child values of an AST node or node sequence."""
        if isinstance(node, (list, tuple)):
            return list(node)
        # Leaves are not descended into. Enum members and classes are excluded
        # because their __dict__ leads back into the enum class.
        if node is None or isinstance(node, (str, bytes, int, float, enum.Enum, type)):
            return []
        attrs = getattr(node, "__dict__", None)
        if isinstance(attrs, dict):
            return list(attrs.values())
        # pglast nodes declare their fields in __slots__ and have no __dict__,
        # so vars() cannot be used on them.
        values: list[Any] = []
        for klass in type(node).__mro__:
            slots = vars(klass).get("__slots__", ())
            if isinstance(slots, str):
                slots = (slots,)
            values.extend(getattr(node, slot, None) for slot in slots)
        return values

    has_window = False
    has_case = False
    has_exists = False
    has_lateral = False

    def walk(n: Any, check_window: bool = True) -> bool:
        """Return False as soon as a forbidden node is found, recording its kind."""
        nonlocal has_window, has_case, has_exists, has_lateral
        tag = n.__class__.__name__
        if tag in {"RangeSubselect", "SubLink", "SetOperationStmt"}:
            if tag == "SubLink" and enum_is(getattr(n, "subLinkType", None), "EXISTS_SUBLINK", 0):
                has_exists = True
            return False
        # Raw parse trees represent UNION/INTERSECT/EXCEPT as a SelectStmt with op set.
        if tag == "SelectStmt":
            op = getattr(n, "op", None)
            if op is not None and not enum_is(op, "SETOP_NONE", 0):
                return False
        if tag == "CaseExpr":
            has_case = True
            return False
        if tag == "RangeFunction":
            if getattr(n, "lateral", False):
                has_lateral = True
            return False
        if check_window and tag == "WindowDef":
            has_window = True
            return False
        if check_window and tag == "FuncCall" and getattr(n, "over", None) is not None:
            has_window = True
            return False
        return all(walk(child, check_window) for child in children(n))

    if not walk(root):
        if has_window:
            return False, "window_not_allowed"
        if has_case:
            return False, "case_when_not_allowed"
        if has_exists:
            return False, "exists_not_allowed"
        if has_lateral:
            return False, "lateral_not_allowed"
        return False, "forbidden_structure"

    return True, ""
597
+
598
+
599
class Dialect:
    """Common interface that every database dialect implements.

    Holds the runtime configuration and declares the operations a concrete
    dialect must provide. Each operation here raises ``NotImplementedError``
    and is meant to be overridden by subclasses.
    """

    name: str = "base"

    def __init__(self, config: Any):
        """Keep a reference to the runtime configuration.

        Args:

            config: Runtime config object for the concrete dialect, for example
                PostgresRuntimeConfig or DatabricksRuntimeConfig.
        """
        self.config = config

    def ast_validate(self, sql: str) -> tuple[bool, str]:
        """Check the structure of ``sql`` without touching a database.

        Args:

            sql: SQL text to check.

        Returns:

            ``(ok, error_code)``; ``ok`` is True when the statement satisfies
            the dialect's structural rules.

        Raises:

            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def extract_join_pairs(self, sql: str) -> tuple[bool, set[tuple[str, str]]]:
        """Collect the column pairs joined on in ``sql``.

        Args:

            sql: SQL text to inspect.

        Returns:

            ``(ok, pairs)`` where each pair is a sorted
            ``("table.col", "table.col")`` tuple.

        Raises:

            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def extract_cte_bodies(self, sql: str) -> dict[str, str]:
        """Return the body SQL of each CTE declared in a WITH clause.

        Args:

            sql: SQL text that may start with a WITH clause.

        Returns:

            Mapping of lower-cased CTE name to its inner SELECT text.

        Raises:

            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def explain_sql(self, engine: Any, sql: str, params: dict[str, Any] | None = None) -> tuple[bool, str]:
        """Ask the database to EXPLAIN ``sql`` to see whether it would run.

        Args:

            engine: Database engine/connection handle used by the dialect.
            sql: SQL text to explain.
            params: Optional bind-parameter values keyed by placeholder name.

        Returns:

            ``(ok, error_message)``; ``ok`` is True when EXPLAIN succeeds.

        Raises:

            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def reflect_enums(self, engine: Any, schema_name: str) -> dict[str, list[str]]:
        """List the enum types in ``schema_name`` with their ordered values.

        Args:

            engine: Database engine/connection handle used by the dialect.
            schema_name: Schema to introspect.

        Returns:

            Mapping of enum type name to its values in declaration order.

        Raises:

            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
680
+
681
+
682
class PostgresDialect(Dialect):
    """PostgreSQL-specific dialect implementation.

    Structural validation and join-pair extraction use the module's pglast
    helpers. CTE extraction prefers pglast and falls back to a character-level
    parser.
    """

    name: str = "postgresql"

    def __init__(self, config):
        """Initialize the PostgreSQL dialect and create a SQLAlchemy engine.

        Args:

            config: PostgresRuntimeConfig instance providing connection
                credentials via ``db_url()``.
        """
        super().__init__(config)
        from sqlalchemy import create_engine

        self.engine = create_engine(config.db_url())

    def ast_validate(self, sql: str) -> tuple[bool, str]:
        """Validate SQL structure via AST parsing."""
        return _ast_structural_valid(sql)

    def extract_join_pairs(self, sql: str) -> tuple[bool, set[tuple[str, str]]]:
        """Extract normalized join pairs from SQL AST."""
        return _normalized_join_pairs_from_sql_ast(sql)

    def extract_cte_bodies(self, sql: str) -> dict[str, str]:
        """Extract the SQL body for each CTE in a WITH clause.

        Uses pglast when available. Falls back to the depth-tracking parser
        when pglast is missing, parsing fails, or a CTE cannot be rendered back
        to SQL; in the last case only the missing CTEs are filled from the
        fallback.

        Args:

            sql: SQL text that may start with a WITH clause.

        Returns:

            Mapping of lower-cased CTE name to its body SQL.
        """
        try:
            import pglast
        except ImportError:
            return self._extract_cte_bodies_depth_tracking(sql)
        parse_sql = getattr(pglast, "parse_sql", None)
        if parse_sql is None:
            return self._extract_cte_bodies_depth_tracking(sql)
        cte_bodies: dict[str, str] = {}
        try:
            stmts = parse_sql(canonicalize_sql(sql))
            if not stmts:
                return self._extract_cte_bodies_depth_tracking(sql)
            root = getattr(stmts[0], "stmt", None)
            if root is None:
                return self._extract_cte_bodies_depth_tracking(sql)
            with_clause = getattr(root, "withClause", None)
            if with_clause is None:
                return cte_bodies
            render_failed = False
            for cte in getattr(with_clause, "ctes", []):
                cte_name = getattr(getattr(cte, "ctename", None), "sval", None) or getattr(cte, "ctename", "")
                cte_query = getattr(cte, "ctequery", None)
                if not (cte_name and cte_query):
                    continue
                try:
                    from pglast import prettify

                    cte_bodies[cte_name.lower()] = prettify(cte_query)
                except Exception as exc:
                    render_failed = True
                    debug(f"[dialect.extract_cte_bodies] prettify failed for CTE '{cte_name}': {exc}")
            if render_failed:
                # Fill in CTEs that pglast could not render from the text-level
                # parser, instead of silently returning them missing.
                for fallback_name, fallback_body in self._extract_cte_bodies_depth_tracking(sql).items():
                    cte_bodies.setdefault(fallback_name, fallback_body)
            if cte_bodies:
                debug(
                    f"[dialect.extract_cte_bodies] pglast extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}"
                )
            return cte_bodies
        except Exception as e:
            debug(f"[dialect.extract_cte_bodies] pglast failed: {e}, falling back to depth-tracking")
            return self._extract_cte_bodies_depth_tracking(sql)

    def _extract_cte_bodies_depth_tracking(self, sql: str) -> dict[str, str]:
        """Extract CTE bodies using a depth-tracking character parser.

        Only handles ``name AS ( ... )`` definitions. Recursive CTEs and CTEs
        with column lists stop the scan. Single- and double-quoted strings,
        including doubled-quote escapes, are skipped when counting parentheses.
        """
        cte_bodies: dict[str, str] = {}
        if not sql.upper().strip().startswith("WITH"):
            return cte_bodies
        with_match = re.match(r"\s*WITH\s+", sql, re.IGNORECASE)
        if not with_match:
            return cte_bodies
        pos = with_match.end()
        while pos < len(sql):
            while pos < len(sql) and sql[pos] in " \t\n\r,":
                pos += 1
            if pos >= len(sql):
                break
            name_match = re.match(r"(\w+)\s+AS\s*\(", sql[pos:], re.IGNORECASE)
            if not name_match:
                break
            cte_name = name_match.group(1).lower()
            pos += name_match.end()
            start_pos = pos
            depth = 1
            in_string = False
            string_char = None
            while pos < len(sql) and depth > 0:
                c = sql[pos]
                if in_string:
                    if c == string_char and (pos + 1 >= len(sql) or sql[pos + 1] != string_char):
                        in_string = False
                    elif c == string_char and pos + 1 < len(sql) and sql[pos + 1] == string_char:
                        pos += 1  # doubled quote is an escaped quote, not a terminator
                elif c in ("'", '"'):
                    in_string = True
                    string_char = c
                elif c == "(":
                    depth += 1
                elif c == ")":
                    depth -= 1
                pos += 1
            if depth == 0:
                cte_bodies[cte_name] = sql[start_pos : pos - 1].strip()
            rest = sql[pos:].strip().upper()
            if rest.startswith("SELECT") or rest.startswith("(SELECT"):
                break
        debug(
            f"[dialect.extract_cte_bodies] depth-tracking extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}"
        )
        return cte_bodies

    def explain_sql(self, engine: Any, sql: str, params: dict[str, Any] | None = None) -> tuple[bool, str]:
        """Test SQL executability via PostgreSQL EXPLAIN.

        The EXPLAIN runs inside an explicit transaction that is marked READ
        ONLY and always rolled back, so generated SQL cannot persist writes
        through this probe. PostgreSQL treats EXPLAIN without ANALYZE as
        read-only, so valid statements still explain successfully.

        NOTE(review): a multi-statement string that issues its own COMMIT could
        still leave this transaction. Callers should reject multiple statements
        first (``ast_validate`` does).

        Args:

            engine: SQLAlchemy Engine connected to the target database.
            sql: SQL query string to explain.
            params: Optional bind-parameter values keyed by placeholder name.

        Returns:

            ``(ok, error_message)``; ok is True when EXPLAIN succeeds.
        """
        from sqlalchemy import text

        try:
            with engine.connect() as conn:
                trans = conn.begin()
                try:
                    conn.execute(text("SET TRANSACTION READ ONLY"))
                    conn.execute(text(f"EXPLAIN {sql}"), params or {})
                finally:
                    trans.rollback()
            return True, ""
        except Exception as e:
            return False, str(e)

    def reflect_enums(self, engine: Any, schema_name: str) -> dict[str, list[str]]:
        """Discover PostgreSQL enum types in ``schema_name`` and their values.

        Values are returned in ``enumsortorder``. Errors are logged via
        ``debug`` and yield whatever was collected so far (usually ``{}``).
        """
        if EngineConfig.TYPE != "postgresql":
            return {}

        from sqlalchemy import text

        enum_values: dict[str, list[str]] = {}
        try:
            with engine.connect() as conn:
                enum_query = text("""
            SELECT t.typname, e.enumlabel
            FROM pg_type t
            JOIN pg_enum e ON t.oid = e.enumtypid
            JOIN pg_namespace n ON t.typnamespace = n.oid
            WHERE n.nspname = :schema
            ORDER BY t.typname, e.enumsortorder
            """)
                for row in conn.execute(enum_query, {"schema": schema_name}):
                    enum_values.setdefault(row[0], []).append(row[1])
        except Exception as exc:
            debug(f"[dialect.reflect_enums] failed to reflect enums: {exc}")
        return enum_values
836
+
837
+
838
+ class DatabricksDialect(Dialect):
839
+ """Databricks Spark dialect with Spark EXPLAIN validation."""
840
+
841
+ name: str = "databricks"
842
+
843
def __init__(self, config):
    """Set up either a databricks-sql-connector connection or a SparkSession.

    When ``config.has_native_connection()`` is true, a native connection is
    opened from ``SERVER_HOSTNAME``, ``HTTP_PATH`` and ``ACCESS_TOKEN`` and no
    SparkSession is created. Otherwise the active SparkSession is obtained (or
    created) through PySpark.

    Args:

        config: DatabricksRuntimeConfig instance providing catalog, schema and
            optional native connection settings.

    Raises:

        RuntimeError: If the chosen backend cannot be initialized.
    """
    super().__init__(config)

    # Exactly one of these ends up non-None.
    self.connection = None
    self.spark = None

    if not config.has_native_connection():
        try:
            from pyspark.sql import SparkSession

            self.spark = SparkSession.builder.getOrCreate()
        except Exception as exc:
            raise RuntimeError(f"Failed to get SparkSession: {exc}") from exc
        return

    try:
        from databricks import sql as dbsql

        self.connection = dbsql.connect(
            server_hostname=config.SERVER_HOSTNAME,
            http_path=config.HTTP_PATH,
            access_token=config.ACCESS_TOKEN,
        )
    except Exception as exc:
        raise RuntimeError(f"Failed to connect via databricks-sql-connector: {exc}") from exc
881
+
882
def ast_validate(self, sql: str) -> tuple[bool, str]:
    """Validate SQL syntax via Spark EXPLAIN on the native connection or PySpark.

    Only errors whose message mentions "syntax" or "parse" count as failures.
    Any other error (for example a missing table) is logged via ``debug`` and
    the SQL is treated as valid.

    Args:

        sql: SQL text; table references are qualified with catalog.schema first.

    Returns:

        ``(ok, error_message)``.
    """
    qualified = self._qualify_table_references(sql)

    if self.connection is not None:
        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(f"EXPLAIN {qualified}")
            cursor.fetchall()
        except Exception as e:
            err_str = str(e).lower()
            if "syntax" in err_str or "parse" in err_str:
                return False, f"Databricks syntax error: {e}"
            debug(f"[dialect.ast_validate] non-syntax Databricks error (treating as valid): {e}")
        finally:
            # Close the cursor even when EXPLAIN fails, so failed validations
            # do not leak cursors on the long-lived connection.
            if cursor is not None:
                try:
                    cursor.close()
                except Exception as close_exc:
                    debug(f"[dialect.ast_validate] failed to close Databricks cursor: {close_exc}")
        return True, ""

    try:
        explain_df = self.spark.sql(f"EXPLAIN {qualified}")
        explain_df.collect()
    except Exception as e:
        err_str = str(e).lower()
        if "syntax" in err_str or "parse" in err_str:
            return False, f"Spark syntax error: {e}"
        debug(f"[dialect.ast_validate] non-syntax Spark error (treating as valid): {e}")

    return True, ""
909
+
910
def extract_join_pairs(self, sql: str) -> tuple[bool, set[tuple[str, str]]]:
    """Return the JOIN ... ON equality pairs found in Spark SQL, via sqlglot.

    Any exception raised by the sqlglot helper (for instance a missing
    sqlglot install) is logged and reported as ``(False, set())``.
    """
    try:
        found = _extract_join_pairs_sqlglot(sql)
    except Exception as exc:
        debug(f"[dialect.extract_join_pairs] failed: {exc}")
        return False, set()
    return True, found
918
+
919
def extract_cte_bodies(self, sql: str) -> dict[str, str]:
    """Extract the SQL body for each CTE in a WITH clause.

    sqlglot is tried first. When it finds no CTEs, or is not installed, the
    depth-tracking character parser is used instead. This matches the
    PostgreSQL dialect, which also falls back when its parser is missing.

    Args:

        sql: Spark SQL text that may start with a WITH clause.

    Returns:

        Mapping of lower-cased CTE name to its body SQL.
    """
    try:
        cte_bodies = _extract_cte_bodies_sqlglot(sql)
    except ImportError as exc:
        debug(f"[dialect.extract_cte_bodies] sqlglot unavailable: {exc}, falling back to depth-tracking")
        cte_bodies = {}
    if cte_bodies:
        debug(f"[dialect.extract_cte_bodies] sqlglot extracted {len(cte_bodies)} CTEs: {list(cte_bodies.keys())}")
        return cte_bodies
    return self._extract_cte_bodies_depth_tracking(sql)
928
+
929
def _extract_cte_bodies_depth_tracking(self, sql: str) -> dict[str, str]:
    """Pull CTE bodies out of ``sql`` by counting parentheses.

    Recognizes ``name AS ( ... )`` definitions after a leading WITH. Commas and
    whitespace between definitions are skipped. Parentheses inside single- or
    double-quoted strings are ignored, and a doubled quote counts as an escape.
    Scanning stops at the first text that is not a CTE header, or once the
    remainder starts with ``SELECT`` / ``(SELECT``.
    """
    bodies: dict[str, str] = {}
    if not sql.upper().strip().startswith("WITH"):
        return bodies
    opener = re.match(r"\s*WITH\s+", sql, re.IGNORECASE)
    if opener is None:
        return bodies

    length = len(sql)
    idx = opener.end()
    while idx < length:
        # Skip separators between successive CTE definitions.
        while idx < length and sql[idx] in " \t\n\r,":
            idx += 1
        if idx >= length:
            break
        header = re.match(r"(\w+)\s+AS\s*\(", sql[idx:], re.IGNORECASE)
        if header is None:
            break
        name = header.group(1).lower()
        idx += header.end()
        body_start = idx

        depth = 1
        quote = None  # active quote character while inside a string literal
        while idx < length and depth > 0:
            ch = sql[idx]
            if quote is not None:
                if ch == quote:
                    if idx + 1 < length and sql[idx + 1] == quote:
                        idx += 1  # doubled quote: escaped, stay inside the string
                    else:
                        quote = None
            elif ch in ("'", '"'):
                quote = ch
            elif ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            idx += 1

        if depth == 0:
            # idx is one past the closing parenthesis; exclude it from the body.
            bodies[name] = sql[body_start : idx - 1].strip()
        if sql[idx:].strip().upper().startswith(("SELECT", "(SELECT")):
            break

    debug(f"[dialect.extract_cte_bodies] depth-tracking extracted {len(bodies)} CTEs: {list(bodies.keys())}")
    return bodies
979
+
980
def explain_sql(self, engine: Any, sql: str, params: dict[str, Any] | None = None) -> tuple[bool, str]:
    """Test SQL executability via EXPLAIN on Databricks.

    Table references are qualified with ``_qualify_table_references`` first.
    The EXPLAIN runs through a cursor from ``self.connection`` when one is
    set (the cursor is always closed afterwards), otherwise through
    ``self.spark.sql(...).collect()``. This matches the backend choice made by
    ``execute_sql_spark``. ``engine`` and ``params`` are accepted for
    interface compatibility and are not used.

    Spark may report analysis/planning failures (e.g. unknown tables or
    columns) as text inside the returned plan rather than raising. The plan
    text is therefore also checked for those error markers.

    Returns:
        ``(True, "")`` when EXPLAIN succeeds with a clean plan. Otherwise
        ``(False, message)``, where message is the exception text or the
        error reported in the plan.
    """
    try:
        qualified = self._qualify_table_references(sql)
        statement = f"EXPLAIN {qualified}"
        if self.connection is not None:
            cursor = self.connection.cursor()
            try:
                cursor.execute(statement)
                rows = cursor.fetchall()
            finally:
                cursor.close()
        else:
            rows = self.spark.sql(statement).collect()
    except Exception as e:
        return False, str(e)

    plan_text = "\n".join(str(value) for row in rows for value in row)
    error_markers = ("AnalysisException", "Error occurred during query planning")
    if any(marker in plan_text for marker in error_markers):
        return False, plan_text.strip()
    return True, ""
989
+
990
def reflect_enums(self, engine: Any, schema_name: str) -> dict[str, list[str]]:
    """Report enum types for Databricks, which has none natively.

    ``engine`` and ``schema_name`` exist only for interface compatibility
    with other dialects and are ignored; a fresh empty mapping is returned.
    """
    no_enums: dict[str, list[str]] = {}
    return no_enums
993
+
994
def prepare_for_execution(self, sql: str) -> str:
    """Return *sql* made ready for Spark execution.

    The only transformation is qualifying physical table references with
    the configured catalog and schema via ``_qualify_table_references``.
    """
    qualified_sql = self._qualify_table_references(sql)
    return qualified_sql
997
+
998
def _qualify_table_references(self, sql: str) -> str:
    """Add catalog.schema prefix to physical table references only.

    Every ``FROM <name>`` / ``JOIN <name>`` whose name is a single bare,
    backtick-quoted or double-quoted identifier is rewritten to
    ``<keyword> <q>catalog<q>.<q>schema<q>.<q>name<q>``, using
    ``self.config.CATALOG`` and ``self.config.SCHEMA``. The quote character
    of the original name is reused; bare names get backticks. Whitespace
    between the keyword and the name collapses to a single space.

    A reference is left unchanged when:

    * its lower-cased name is in ``DATABRICKS_TABLE_QUALIFY_SKIP_IDENTIFIERS``;
    * it names a CTE defined in *sql* (``WITH name AS (`` / ``, name AS (``);
    * it is already qualified (the name is followed by ``.``), which also
      keeps the rewrite idempotent;
    * the ``FROM`` introduces an expression rather than a table. This covers
      arguments of ``EXTRACT``/``TRIM``/``SUBSTRING``/``SUBSTR``/``OVERLAY``
      and ``IS [NOT] DISTINCT FROM``.

    This is a regex heuristic. It does not skip string literals or comments,
    and only the first table of a comma-separated ``FROM a, b`` is qualified.
    """
    catalog = self.config.CATALOG
    schema_name = self.config.SCHEMA

    # CTE names are query-local; qualifying them would point at a missing table.
    cte_names = {
        name.lower()
        for name in re.findall(
            r"(?:\bWITH(?:\s+RECURSIVE)?\s+|,\s*)(\w+)(?:\s*\([^()]*\)\s*|\s+)AS\s*\(",
            sql,
            flags=re.IGNORECASE,
        )
    }

    # ``(?![\w.])`` (instead of ``\b``) allows the match to end right after a
    # closing quote and rejects names that are already qualified (``a.b``).
    pattern = r'\b(FROM|JOIN)\s+(["`]?)(\w+)\2(?![\w.])'

    def from_starts_expression(start: int) -> bool:
        """Return True when the FROM keyword at *start* is not a table clause."""
        prefix = sql[:start]
        if re.search(r"\bDISTINCT\s+$", prefix, re.IGNORECASE):
            return True  # ``IS [NOT] DISTINCT FROM <expr>``
        # Locate the innermost parenthesis still open before the keyword.
        depth = 0
        for idx in range(len(prefix) - 1, -1, -1):
            ch = prefix[idx]
            if ch == ")":
                depth += 1
            elif ch == "(":
                if depth == 0:
                    return (
                        re.search(
                            r"\b(?:EXTRACT|TRIM|SUBSTRING|SUBSTR|OVERLAY)\s*$",
                            prefix[:idx],
                            re.IGNORECASE,
                        )
                        is not None
                    )
                depth -= 1
        return False

    def replace_table(match: re.Match[str]) -> str:
        keyword = match.group(1)
        quote = match.group(2) or "`"
        table = match.group(3)
        lowered = table.lower()
        if lowered in DATABRICKS_TABLE_QUALIFY_SKIP_IDENTIFIERS or lowered in cte_names:
            return match.group(0)
        if keyword.upper() == "FROM" and from_starts_expression(match.start()):
            return match.group(0)
        return (
            f"{keyword} {quote}{catalog}{quote}.{quote}{schema_name}{quote}."
            f"{quote}{table}{quote}"
        )

    return re.sub(pattern, replace_table, sql, flags=re.IGNORECASE)
1017
+
1018
def execute_sql_spark(self, sql: str, sql_already_spark: bool = False) -> list[tuple]:
    """Execute SQL on Databricks via native connector or Spark.

    Table references are first qualified with ``_qualify_table_references``,
    and the resulting SQL is printed when ``self.config.DEBUG`` is truthy.
    When ``self.connection`` is set, the statement runs on a cursor from that
    connection. The cursor is closed even if ``execute``/``fetchall`` raises.
    Otherwise the statement runs through ``self.spark.sql(...).collect()``.

    Args:
        sql: SQL statement to run.
        sql_already_spark: Not read by this method; kept for backward
            compatibility with existing callers.

    Returns:
        All result rows, each converted to a plain tuple.
    """
    spark_sql = self._qualify_table_references(sql)

    if self.config.DEBUG:
        print(f"[DEBUG] Executing on Databricks:\n{spark_sql}")

    if self.connection is not None:
        cursor = self.connection.cursor()
        try:
            cursor.execute(spark_sql)
            rows = cursor.fetchall()
        finally:
            # Release the cursor even when execution or fetching fails.
            cursor.close()
    else:
        rows = self.spark.sql(spark_sql).collect()

    return [tuple(row) for row in rows]
1036
+
1037
+
1038
def render_date_diff_expr(
    dialect_type: str, left_expr: str, op: str, unit: str, amount: int
) -> str:
    """Build a predicate comparing a date difference against a threshold.

    For 'postgresql' the threshold is an ``INTERVAL`` literal whose unit is
    pluralised with a trailing 's' unless ``amount`` equals 1. For any other
    dialect (e.g. 'databricks') the threshold is a plain day count. Weeks,
    months and years are converted with fixed factors of 7, 30 and 365.
    'day' and unrecognised units use ``amount`` unchanged.

    Args:
        dialect_type: One of 'postgresql' or 'databricks'.
        left_expr: SQL expression for date subtraction (e.g. col1 - col2).
        op: Comparison operator.
        unit: 'day', 'week', 'month', or 'year'.
        amount: Numeric amount for the interval.

    Returns:
        SQL predicate string, e.g. ``(a - b) > INTERVAL '3 days'``.
    """
    lhs = f"({left_expr}) {op}"
    if dialect_type == "postgresql":
        plural = "" if amount == 1 else "s"
        return f"{lhs} INTERVAL '{amount} {unit}{plural}'"
    days_per_unit = {"week": 7, "month": 30, "year": 365}.get(unit)
    threshold = amount * days_per_unit if days_per_unit else amount
    return f"{lhs} {threshold}"
1067
+
1068
+
1069
def render_date_window_expr(dialect_type: str, column: str, op: str, unit: str, offset: int) -> str:
    """Render a dialect-specific date window filter expression.

    Sub-day units ('hour', 'minute', 'second') are anchored to the current
    timestamp; all other units are anchored to the current date. This keeps
    "last N hours" and "this hour" windows from collapsing to midnight.

    Args:
        dialect_type: One of 'postgresql' or 'databricks'.
        column: Column reference string to compare against the date boundary.
        op: Comparison operator, for example '>=' or '<'.
        unit: Date unit: 'day', 'week', 'month', 'year', 'hour', 'minute' or
            'second'. On Databricks, unrecognised units are treated as days.
        offset: Number of units in the past. 0 means the start of the current unit.

    Returns:
        SQL expression string suitable for use in a WHERE clause.
    """
    sub_day = unit in ("hour", "minute", "second")
    if dialect_type == "postgresql":
        now = "CURRENT_TIMESTAMP" if sub_day else "CURRENT_DATE"
        if offset == 0:
            return f"{column} {op} DATE_TRUNC('{unit}', {now})"
        suffix = "s" if offset != 1 else ""
        return f"{column} {op} {now} - INTERVAL '{offset} {unit}{suffix}'"
    if offset == 0:
        now = "current_timestamp()" if sub_day else "current_date()"
        return f"{column} {op} date_trunc('{unit}', {now})"
    if unit == "week":
        return f"{column} {op} date_sub(current_date(), {offset * 7})"
    # Negate the number itself rather than prefixing '-': a negative offset
    # would otherwise render '--N', which SQL parses as a comment.
    if unit == "month":
        return f"{column} {op} add_months(current_date(), {-offset})"
    if unit == "year":
        return f"{column} {op} add_months(current_date(), {-(offset * 12)})"
    if sub_day:
        return f"{column} {op} (current_timestamp() - INTERVAL '{offset} {unit.upper()}S')"
    # 'day' and unrecognised units fall back to whole days.
    return f"{column} {op} date_sub(current_date(), {offset})"
1106
+
1107
+
1108
def get_dialect(engine_type: str = None, config=None) -> Dialect:
    """Build the Dialect implementation for *engine_type*.

    Args:
        engine_type: 'postgresql' or 'databricks'. ``None`` means use
            ``EngineConfig.TYPE``.
        config: Runtime config handed to the dialect constructor. ``None``
            means use ``EngineConfig.RUNTIME``.

    Returns:
        A ``PostgresDialect`` or ``DatabricksDialect`` constructed with the
        resolved config.

    Raises:
        ValueError: If the resolved engine type names no supported dialect.
    """
    resolved_type = EngineConfig.TYPE if engine_type is None else engine_type
    resolved_config = EngineConfig.RUNTIME if config is None else config

    if resolved_type == "postgresql":
        dialect_cls = PostgresDialect
    elif resolved_type == "databricks":
        dialect_cls = DatabricksDialect
    else:
        raise ValueError(f"Unsupported dialect: {resolved_type}")
    return dialect_cls(resolved_config)