sqlprism 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1031 @@
1
+ """SQL parser using sqlglot.
2
+
3
+ This is the richest parser in the system. sqlglot provides semantic analysis
4
+ beyond what tree-sitter can offer for SQL: CTE scope tracking, column-level
5
+ lineage via the Scope module, multi-dialect awareness, and proper resolution
6
+ of aliased references.
7
+
8
+ CTEs are tracked as first-class nodes, not flattened into the parent query.
9
+ """
10
+
11
+ from pathlib import Path
12
+
13
+ import sqlglot
14
+ from sqlglot import exp
15
+ from sqlglot.lineage import lineage as sqlglot_lineage
16
+ from sqlglot.optimizer.qualify_columns import qualify_columns
17
+ from sqlglot.optimizer.scope import build_scope
18
+
19
+ from sqlprism.types import (
20
+ ColumnLineageResult,
21
+ ColumnUsageResult,
22
+ EdgeResult,
23
+ LineageHop,
24
+ NodeResult,
25
+ ParseResult,
26
+ )
27
+
28
+
29
+ class SqlParser:
30
+ """Parses SQL files into nodes, edges, column usage, and column lineage using sqlglot.
31
+
32
+ Handles multi-statement files, CTE extraction, column-level scope analysis,
33
+ transform detection, and end-to-end column lineage tracing. Dialect-aware
34
+ identifier normalisation ensures consistent casing across Postgres, Snowflake,
35
+ DuckDB, and other engines.
36
+ """
37
+
38
+ # Dialects that fold unquoted identifiers to lowercase
39
+ _LOWERCASE_DIALECTS = frozenset({"postgres", "postgresql", "redshift", "duckdb"})
40
+ # Dialects that fold unquoted identifiers to uppercase
41
+ _UPPERCASE_DIALECTS = frozenset({"snowflake", "oracle", "db2"})
42
+
43
+ def __init__(self, dialect: str | None = None):
44
+ """Initialise with an optional SQL dialect.
45
+
46
+ Args:
47
+ dialect: sqlglot dialect string (e.g., 'postgres', 'mysql', 'duckdb').
48
+ None for auto-detection.
49
+ """
50
+ self.dialect = dialect
51
+
52
    def parse(self, file_path: str, file_content: str, schema: dict | None = None) -> ParseResult:
        """Parse a SQL file into nodes, edges, column usage, and column lineage.

        Handles multiple statements per file. Each statement is parsed
        independently: a failure in one statement is recorded as a non-fatal
        error and the remaining statements are still processed. Only a
        file-level tokenize/parse failure aborts the whole file.

        Args:
            file_path: Path to the SQL file (only its stem is used to name
                the query node).
            file_content: Raw SQL content.
            schema: Optional schema catalog ``{table: {col: type}}`` for
                expanding ``SELECT *`` in lineage tracing via
                ``qualify_columns``.

        Returns:
            A ``ParseResult`` containing all extracted nodes, edges,
            column usage records, column lineage chains, and any
            non-fatal parse errors.
        """
        nodes: list[NodeResult] = []
        edges: list[EdgeResult] = []
        column_usage: list[ColumnUsageResult] = []
        column_lineage: list[ColumnLineageResult] = []
        errors: list[str] = []

        file_stem = Path(file_path).stem

        try:
            statements = sqlglot.parse(file_content, dialect=self.dialect)
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            # Whole-file failure: nothing parsed, return the error only.
            return ParseResult(language="sql", errors=[f"Parse error: {e}"])

        # Persistent dedup sets shared across ALL statements in this file so a
        # table or CTE referenced by several statements yields a single node.
        seen_nodes: set[tuple[str, str, str | None]] = set()
        seen_ctes: set[str] = set()

        for stmt_idx, stmt in enumerate(statements):
            # sqlglot.parse can yield None entries (e.g. trailing semicolons).
            if stmt is None:
                continue

            try:
                self._process_statement(
                    stmt,
                    file_stem,
                    file_path,
                    nodes,
                    edges,
                    column_usage,
                    seen_nodes=seen_nodes,
                    seen_ctes=seen_ctes,
                )
            except Exception as e:
                # Record and move on; one bad statement must not sink the file.
                errors.append(f"Statement {stmt_idx}: {type(e).__name__}: {e}")
                continue

            # Column lineage via sqlglot.lineage — separate pass so a lineage
            # failure does not discard the node/edge/usage results above.
            try:
                self._extract_column_lineage(stmt, file_stem, file_content, column_lineage, schema=schema)
            except Exception as e:
                errors.append(f"Lineage stmt {stmt_idx}: {type(e).__name__}: {e}")
                continue

        return ParseResult(
            language="sql",
            nodes=nodes,
            edges=edges,
            column_usage=column_usage,
            column_lineage=column_lineage,
            errors=errors,
        )
121
+
122
    def _process_statement(
        self,
        stmt: exp.Expression,
        file_stem: str,
        file_path: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
        column_usage: list[ColumnUsageResult],
        seen_nodes: set[tuple[str, str, str | None]] | None = None,
        seen_ctes: set[str] | None = None,
    ) -> None:
        """Process a single SQL statement, appending to the shared result lists.

        Runs the per-statement extraction passes in order: CREATE handling,
        table references, CTEs, column usage, and INSERT...SELECT mapping.

        Args:
            stmt: Parsed statement AST.
            file_stem: Name used for the file-level query node.
            file_path: Full file path. NOTE(review): accepted for interface
                symmetry but not read anywhere in this body.
            nodes/edges/column_usage: Output accumulators, mutated in place.
            seen_nodes: Cross-statement node dedup keys; rebuilt from ``nodes``
                when the caller does not supply one.
            seen_ctes: Cross-statement CTE-name dedup set.
        """
        # Use persistent dedup sets across statements, or create fresh ones
        # (rebuilt from the nodes accumulated so far for standalone calls).
        if seen_nodes is None:
            seen_nodes = {(n.name, n.kind, (n.metadata or {}).get("schema")) for n in nodes}
        # Edge dedup is deliberately per-statement, not file-wide.
        seen_edges: set[tuple[str, str, str]] = set()

        # CREATE TABLE / CREATE VIEW
        if isinstance(stmt, exp.Create):
            self._process_create(stmt, file_stem, nodes, edges)

        # Extract table references from any statement type
        self._extract_table_references(stmt, file_stem, nodes, edges, seen_nodes, seen_edges)

        # Extract CTEs as first-class nodes
        self._extract_ctes(stmt, file_stem, nodes, edges, seen_ctes=seen_ctes)

        # Column-level lineage via sqlglot's scope analysis
        self._extract_column_usage(stmt, file_stem, nodes, column_usage)

        # INSERT...SELECT column mapping
        if isinstance(stmt, exp.Insert):
            self._extract_insert_select_mapping(stmt, file_stem, column_usage)
155
+
156
+ def _process_create(
157
+ self,
158
+ stmt: exp.Create,
159
+ file_stem: str,
160
+ nodes: list[NodeResult],
161
+ edges: list[EdgeResult],
162
+ ) -> None:
163
+ """Handle CREATE TABLE / CREATE VIEW statements."""
164
+ kind_expr = stmt.args.get("kind")
165
+ if not kind_expr:
166
+ return
167
+
168
+ kind_str = kind_expr.upper() if isinstance(kind_expr, str) else str(kind_expr).upper()
169
+
170
+ table_expr = stmt.this
171
+ if not isinstance(table_expr, exp.Table):
172
+ # Could be a Schema wrapping a Table
173
+ if isinstance(table_expr, exp.Schema):
174
+ table_expr = table_expr.this
175
+ if not isinstance(table_expr, exp.Table):
176
+ return
177
+
178
+ name = self._normalize_identifier(table_expr.name, self._is_quoted_identifier(table_expr))
179
+ if not name:
180
+ return
181
+
182
+ node_kind = "view" if "VIEW" in kind_str else "table"
183
+ metadata = self._build_table_metadata(table_expr)
184
+ metadata["dialect"] = self.dialect
185
+ metadata["create_type"] = kind_str
186
+
187
+ nodes.append(
188
+ NodeResult(
189
+ kind=node_kind,
190
+ name=name,
191
+ line_start=None, # sqlglot doesn't track line numbers reliably
192
+ metadata=metadata,
193
+ )
194
+ )
195
+ edges.append(
196
+ EdgeResult(
197
+ source_name=file_stem,
198
+ source_kind="query",
199
+ target_name=name,
200
+ target_kind=node_kind,
201
+ relationship="defines",
202
+ context="CREATE statement",
203
+ )
204
+ )
205
+
206
    def _extract_table_references(
        self,
        stmt: exp.Expression,
        file_stem: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
        seen_nodes: set[tuple[str, str, str | None]] | None = None,
        seen_edges: set[tuple[str, str, str]] | None = None,
    ) -> None:
        """Extract all table references from a statement.

        Every ``exp.Table`` found in the AST becomes a 'table' node (deduped
        on name+kind+schema) plus an edge from the file's query node. The
        CREATE target itself is skipped because ``_process_create`` already
        emitted it.

        Args:
            seen_nodes: Cross-statement node dedup keys; rebuilt from
                ``nodes`` when not supplied.
            seen_edges: Per-statement edge dedup keys (source, target, context).
        """
        if seen_nodes is None:
            seen_nodes = {(n.name, n.kind, (n.metadata or {}).get("schema")) for n in nodes}
        if seen_edges is None:
            seen_edges = set()

        # Identify the CREATE target so we don't double-count it as a reference
        create_target: str | None = None
        if isinstance(stmt, exp.Create):
            target_expr = stmt.this
            if isinstance(target_expr, exp.Schema):
                target_expr = target_expr.this
            if isinstance(target_expr, exp.Table) and target_expr.name:
                create_target = self._normalize_identifier(
                    target_expr.name,
                    self._is_quoted_identifier(target_expr),
                )

        for table in stmt.find_all(exp.Table):
            name = self._normalize_identifier(table.name, self._is_quoted_identifier(table))
            if not name:
                continue

            # Skip the CREATE target — it's already handled by _process_create.
            # Only skip the occurrence that is the actual target (direct child
            # of Create/Schema); a same-named table in the SELECT body still
            # counts as a reference.
            if name == create_target:
                parent = table.parent
                if isinstance(parent, (exp.Create, exp.Schema)):
                    continue

            # Avoid duplicating nodes for the same table+schema within one file (O(1) check)
            metadata = self._build_table_metadata(table)
            table_schema = metadata.get("schema")
            node_key = (name, "table", table_schema)
            if node_key not in seen_nodes:
                seen_nodes.add(node_key)
                nodes.append(NodeResult(kind="table", name=name, metadata=metadata or None))

            # Determine context from parent expression
            context = self._get_table_context(table)

            relationship = "inserts_into" if isinstance(stmt, exp.Insert) else "references"

            # Skip duplicate edges with the same (source, target, context)
            edge_key = (file_stem, name, context)
            if edge_key in seen_edges:
                continue
            seen_edges.add(edge_key)

            edges.append(
                EdgeResult(
                    source_name=file_stem,
                    source_kind="query",
                    target_name=name,
                    target_kind="table",
                    relationship=relationship,
                    context=context,
                )
            )
274
+
275
    def _extract_ctes(
        self,
        stmt: exp.Expression,
        file_stem: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
        seen_ctes: set[str] | None = None,
    ) -> None:
        """Extract CTEs as first-class nodes with their own edges.

        When a CTE references another CTE from the same statement, the edge
        uses target_kind='cte' so trace queries follow CTE chains correctly.

        Args:
            seen_ctes: Set of CTE names already added across statements in this file.
                Used to deduplicate CTEs with the same name across statements.
                NOTE: a duplicate name skips the node AND its edges.
        """
        if seen_ctes is None:
            seen_ctes = set()

        # First pass: collect all CTE names in this statement so references
        # between sibling CTEs can be typed as 'cte' instead of 'table'.
        cte_names: set[str] = set()
        for cte in stmt.find_all(exp.CTE):
            if cte.alias:
                alias_node = cte.args.get("alias")
                quoted = self._is_quoted_identifier(alias_node) if alias_node else False
                cte_names.add(self._normalize_identifier(cte.alias, quoted))

        # Second pass: emit one node per CTE plus its outgoing reference edges.
        for cte in stmt.find_all(exp.CTE):
            alias_node = cte.args.get("alias")
            cte_quoted = self._is_quoted_identifier(alias_node) if alias_node else False
            cte_name = self._normalize_identifier(cte.alias, cte_quoted) if cte.alias else None
            if not cte_name:
                continue

            # Deduplicate CTEs across statements in the same file
            if cte_name in seen_ctes:
                continue
            seen_ctes.add(cte_name)

            nodes.append(
                NodeResult(
                    kind="cte",
                    name=cte_name,
                    metadata={"parent_query": file_stem},
                )
            )

            # Find tables referenced within this CTE body. Self-references
            # (recursive CTEs) are skipped.
            for table in cte.find_all(exp.Table):
                table_name = self._normalize_identifier(
                    table.name,
                    self._is_quoted_identifier(table),
                )
                if not table_name or table_name == cte_name:
                    continue

                # If the reference is to another CTE, use target_kind='cte'
                target_kind = "cte" if table_name in cte_names else "table"

                edges.append(
                    EdgeResult(
                        source_name=cte_name,
                        source_kind="cte",
                        target_name=table_name,
                        target_kind=target_kind,
                        relationship="cte_references",
                        context=self._get_table_context(table),
                    )
                )
345
+
346
    def _extract_column_usage(
        self,
        stmt: exp.Expression,
        file_stem: str,
        nodes: list[NodeResult],
        column_usage: list[ColumnUsageResult],
    ) -> None:
        """Extract column-level usage via sqlglot's scope analysis.

        This is where sqlglot's investment pays off — scope-aware column
        resolution that understands aliases, CTEs, and subqueries.
        Also captures wrapping transforms (CAST, COALESCE, etc.) and
        extracts WHERE clause filters as node metadata.

        For each scope (root query, CTE, derived table) it records one
        ``ColumnUsageResult`` per column reference, resolving table aliases
        back to real table names, plus synthetic '*' rows for star selects.
        """
        # Only works on SELECT-like statements; for DML/DDL dig out the
        # embedded SELECT, bailing when there is none.
        select = stmt
        if not isinstance(stmt, (exp.Select, exp.Union)):
            select = stmt.find(exp.Select)
            if select is None:
                return

        # build_scope can raise on exotic SQL — treat failure as "no usage".
        try:
            root_scope = build_scope(select)
        except Exception:
            return

        if root_scope is None:
            return

        # traverse() can revisit scopes; dedupe by object identity.
        seen_scopes = set()
        for scope in [root_scope] + list(root_scope.traverse()):
            scope_id = id(scope)
            if scope_id in seen_scopes:
                continue
            seen_scopes.add(scope_id)

            # Determine scope name: defaults to the file-level query, refined
            # below for CTEs, derived tables, and CREATE targets.
            scope_name = file_stem
            scope_kind = "query"
            parent_expr = scope.expression.parent
            if scope.is_cte:
                # Extract CTE name from the expression's parent
                if isinstance(parent_expr, exp.CTE) and parent_expr.alias:
                    alias_node = parent_expr.args.get("alias")
                    quoted = self._is_quoted_identifier(alias_node) if alias_node else False
                    scope_name = self._normalize_identifier(parent_expr.alias, quoted)
                    scope_kind = "cte"
            elif isinstance(parent_expr, exp.Subquery) and parent_expr.alias:
                # Derived table (subquery in FROM/JOIN)
                alias_node = parent_expr.args.get("alias")
                quoted = self._is_quoted_identifier(alias_node) if alias_node else False
                scope_name = self._normalize_identifier(parent_expr.alias, quoted)
                scope_kind = "subquery"
                # Create a node for the subquery alias so column_usage can resolve
                nodes.append(
                    NodeResult(
                        kind="subquery",
                        name=scope_name,
                        metadata={"parent_query": file_stem},
                    )
                )
            elif isinstance(parent_expr, exp.Create):
                # Root scope inside CREATE TABLE/VIEW — use the table name
                table_expr = parent_expr.this
                if isinstance(table_expr, exp.Schema):
                    table_expr = table_expr.this
                if isinstance(table_expr, exp.Table) and table_expr.name:
                    scope_name = self._normalize_identifier(
                        table_expr.name,
                        self._is_quoted_identifier(table_expr),
                    )
            elif scope_kind == "query" and (scope_name, "query", None) not in {
                (n.name, n.kind, (n.metadata or {}).get("schema")) for n in nodes
            }:
                # Bare SELECT root scope — create a query node so column_usage resolves
                nodes.append(NodeResult(kind="query", name=scope_name, metadata={"bare_query": True}))

            # Build alias → real table name mapping (non-Table sources such as
            # CTE scopes are intentionally left out).
            alias_map: dict[str, str] = {}
            for source_name, source in scope.sources.items():
                if isinstance(source, exp.Table):
                    alias_map[source_name] = self._normalize_identifier(
                        source.name,
                        self._is_quoted_identifier(source),
                    )

            # When there's exactly one source and no table qualifier, infer the table
            single_table = ""
            if len(alias_map) == 1:
                single_table = next(iter(alias_map.values()))

            for col in scope.columns:
                if not isinstance(col, exp.Column):
                    continue
                col_name = self._normalize_identifier(col.name, self._is_quoted_identifier(col))
                if not col_name:
                    continue

                # Resolve alias to real table name; fall back to the single
                # source when the column is unqualified.
                table_alias = col.table or ""
                table_name = alias_map.get(table_alias, table_alias)
                if not table_name and single_table:
                    table_name = single_table

                usage_type = self._classify_column_context(col)
                transform = self._extract_transform(col)
                alias = self._extract_alias(col)

                column_usage.append(
                    ColumnUsageResult(
                        node_name=scope_name,
                        node_kind=scope_kind,
                        table_name=table_name,
                        column_name=col_name,
                        usage_type=usage_type,
                        alias=alias,
                        transform=transform,
                    )
                )

            # Handle SELECT * — emit usage for each source table
            select_expr = scope.expression
            if isinstance(select_expr, exp.Select):
                for expr in select_expr.expressions:
                    if isinstance(expr, exp.Star):
                        # Unqualified * — emit for each source table
                        for source_name, source in scope.sources.items():
                            table_name = source.name if isinstance(source, exp.Table) else source_name
                            column_usage.append(
                                ColumnUsageResult(
                                    node_name=scope_name,
                                    node_kind=scope_kind,
                                    table_name=table_name,
                                    column_name="*",
                                    usage_type="select",
                                )
                            )
                    elif isinstance(expr, exp.Column) and isinstance(expr.this, exp.Star):
                        # Qualified table.* — emit for that specific table
                        table_alias = expr.table or ""
                        table_name = alias_map.get(table_alias, table_alias)
                        column_usage.append(
                            ColumnUsageResult(
                                node_name=scope_name,
                                node_kind=scope_kind,
                                table_name=table_name,
                                column_name="*",
                                usage_type="select",
                            )
                        )

            # Extract WHERE filters as metadata on the scope's node
            self._extract_where_filters(scope, scope_name, scope_kind, nodes)
499
+
500
+ def _classify_column_context(self, col: exp.Column) -> str:
501
+ """Determine how a column is used based on its AST position.
502
+
503
+ Distinguishes window function sub-clauses (PARTITION BY, window ORDER BY)
504
+ from regular usage types.
505
+ """
506
+ parent = col.parent
507
+
508
+ while parent:
509
+ # Window function sub-clauses — check before general Order
510
+ if isinstance(parent, exp.Window):
511
+ # Determine if column is in PARTITION BY or ORDER BY within window
512
+ return self._classify_window_position(col, parent)
513
+ if isinstance(parent, exp.Where):
514
+ return "where"
515
+ if isinstance(parent, exp.Join):
516
+ return "join_on"
517
+ if isinstance(parent, exp.Group):
518
+ return "group_by"
519
+ if isinstance(parent, exp.Order):
520
+ # Check if this Order is inside a Window (window ORDER BY)
521
+ order_parent = parent.parent
522
+ if isinstance(order_parent, exp.Window):
523
+ return "window_order"
524
+ return "order_by"
525
+ if isinstance(parent, exp.Having):
526
+ return "having"
527
+ if isinstance(parent, exp.Qualify):
528
+ return "qualify"
529
+ if isinstance(parent, exp.Select):
530
+ return "select"
531
+ parent = parent.parent
532
+
533
+ return "unknown"
534
+
535
+ def _classify_window_position(self, col: exp.Column, window: exp.Window) -> str:
536
+ """Classify a column's position within a window function."""
537
+ # Walk from column up to the window, checking if we pass through
538
+ # partition_by or order clause
539
+ parent = col.parent
540
+ while parent and parent is not window:
541
+ if isinstance(parent, exp.Order):
542
+ return "window_order"
543
+ parent = parent.parent
544
+
545
+ # Check if column is in the partition_by list
546
+ partition_by = window.args.get("partition_by")
547
+ if partition_by:
548
+ for partition_col in partition_by:
549
+ if col in partition_col.walk():
550
+ return "partition_by"
551
+
552
+ return "select" # fallback — column is in the aggregate part of the window
553
+
554
    def _extract_transform(self, col: exp.Column) -> str | None:
        """Extract the wrapping transform expression around a column.

        Walks up from the Column node to find wrapping functions like
        CAST, COALESCE, IF, CASE, arithmetic, etc. Returns the SQL string
        of the outermost meaningful wrapper, or None if the column is bare.

        Traversal stops at AND/OR boundaries (so WHERE chains aren't
        captured wholesale) and at clause boundaries (SELECT, WHERE, ...).
        A comparison is captured as the outermost wrapper and then stops
        the walk.
        """
        # Wrapping expression types that constitute a "transform"
        transform_types = (
            exp.Cast,
            exp.TryCast,
            exp.Coalesce,
            exp.If,
            exp.Case,
            exp.Anonymous,  # function calls like NVL, IFNULL, etc.
            exp.Func,  # base class for all functions (UPPER, LOWER, etc.)
            exp.Add,
            exp.Sub,
            exp.Mul,
            exp.Div,
            exp.Mod,
            exp.Concat,
            exp.DPipe,  # || concat operator
            exp.Substring,
            exp.Trim,
            exp.Extract,  # EXTRACT(YEAR FROM ...)
            exp.DateAdd,
            exp.DateSub,
            exp.DateDiff,
            exp.Between,
            exp.In,
            exp.Like,
            exp.Neg,  # unary minus
        )

        # Comparison types — include as transforms but don't traverse past
        comparison_types = (
            exp.EQ,
            exp.NEQ,
            exp.GT,
            exp.GTE,
            exp.LT,
            exp.LTE,
            exp.Is,
            exp.Not,
        )

        parent = col.parent
        outermost = None

        # Keep widening 'outermost' while we remain inside transform wrappers;
        # the LAST transform seen before a boundary wins.
        while parent:
            if isinstance(parent, transform_types):
                outermost = parent
            elif isinstance(parent, comparison_types):
                outermost = parent
                break  # comparisons are the natural boundary for WHERE/JOIN
            elif isinstance(parent, (exp.And, exp.Or)):
                break  # don't capture the full AND/OR chain
            elif isinstance(
                parent,
                (
                    exp.Select,
                    exp.Where,
                    exp.Group,
                    exp.Order,
                    exp.Having,
                    exp.Join,
                    exp.From,
                    exp.Subquery,
                    exp.CTE,
                ),
            ):
                # Stop at clause boundaries
                break
            parent = parent.parent

        if outermost is None:
            return None

        try:
            sql = outermost.sql(dialect=self.dialect)
            # Skip if the transform is just the column itself
            col_sql = col.sql(dialect=self.dialect)
            if sql == col_sql:
                return None
            return sql
        except Exception:
            # Serialization failures are treated as "no transform".
            return None
642
+
643
+ def _extract_alias(self, col: exp.Column) -> str | None:
644
+ """Extract the output alias for a column (AS name)."""
645
+ parent = col.parent
646
+ while parent:
647
+ if isinstance(parent, exp.Alias):
648
+ return parent.alias
649
+ if isinstance(parent, (exp.Select, exp.Where, exp.Group, exp.Order, exp.Having)):
650
+ break
651
+ parent = parent.parent
652
+ return None
653
+
654
    def _extract_where_filters(
        self,
        scope,
        scope_name: str,
        scope_kind: str,
        nodes: list[NodeResult],
    ) -> None:
        """Extract WHERE clause conditions and attach as metadata to the scope's node.

        Finds the WHERE clause in the scope expression and extracts each
        top-level AND-separated condition as a SQL string. These are stored
        under the node's ``filters`` metadata key so they're searchable in
        the graph. The matched node is replaced (NodeResult is frozen).
        """
        try:
            # Use .args["where"] to get only the direct WHERE, not from subqueries
            where = scope.expression.args.get("where")
        except Exception:
            return

        if not where:
            return

        filters = []
        # Split AND conditions into individual filters
        conditions = self._split_conditions(where.this)
        for cond in conditions:
            try:
                sql = cond.sql(dialect=self.dialect)
                if sql and len(sql) < 500:  # skip absurdly long conditions
                    filters.append(sql)
            except Exception:
                continue

        if not filters:
            return

        # Find the matching node and update its metadata.
        # Try exact match first, then match by name only (handles query→table/view mapping).
        # Use enumerate to avoid O(N) nodes.index() and wrong-match-on-duplicates bug.
        # Only the FIRST matching node is updated.
        for idx, node in enumerate(nodes):
            if node.name == scope_name and (node.kind == scope_kind or node.kind in ("table", "view", "cte")):
                existing_meta = dict(node.metadata) if node.metadata else {}
                existing_meta["filters"] = filters
                # NodeResult is frozen, so we need to replace it
                nodes[idx] = NodeResult(
                    kind=node.kind,
                    name=node.name,
                    line_start=node.line_start,
                    line_end=node.line_end,
                    metadata=existing_meta,
                )
                return
706
+
707
+ def _split_conditions(self, expr: exp.Expression) -> list[exp.Expression]:
708
+ """Split an AND chain into individual conditions."""
709
+ if isinstance(expr, exp.And):
710
+ return self._split_conditions(expr.left) + self._split_conditions(expr.right)
711
+ return [expr]
712
+
713
    def _extract_column_lineage(
        self,
        stmt: exp.Expression,
        file_stem: str,
        file_content: str,
        column_lineage: list[ColumnLineageResult],
        schema: dict | None = None,
    ) -> None:
        """Extract end-to-end column lineage using sqlglot.lineage.lineage().

        Traces each output column through CTEs and subqueries back to source tables.
        If a schema catalog is provided, it's passed to sqlglot_lineage to help
        resolve SELECT * and improve lineage accuracy.

        NOTE(review): ``file_content`` is accepted but not read in this body.
        """
        # Find the output SELECT to get column names; the output node defaults
        # to the file stem unless this is a CREATE with a named target.
        select = stmt
        output_name = file_stem

        if isinstance(stmt, exp.Create):
            # Get the CREATE target name
            table_expr = stmt.this
            if isinstance(table_expr, exp.Schema):
                table_expr = table_expr.this
            if isinstance(table_expr, exp.Table) and table_expr.name:
                output_name = table_expr.name
            select = stmt.find(exp.Select)
        elif not isinstance(stmt, (exp.Select, exp.Union)):
            select = stmt.find(exp.Select)

        if select is None:
            return

        # If schema available, try qualify_columns to expand SELECT *.
        # Works on a copy so the original AST is untouched.
        qualified_stmt = stmt
        if schema:
            try:
                qualified_stmt = qualify_columns(stmt.copy(), schema=schema, dialect=self.dialect)
                # Re-find the select from the qualified version
                if isinstance(qualified_stmt, exp.Create):
                    select = qualified_stmt.find(exp.Select)
                elif isinstance(qualified_stmt, (exp.Select, exp.Union)):
                    select = qualified_stmt
                else:
                    select = qualified_stmt.find(exp.Select)
                if select is None:
                    return
            except Exception:
                pass  # fall back to unqualified

        # Get output column names from the SELECT.
        # For UNION, enumerate output columns from ALL branches (deduped).
        if isinstance(select, exp.Union):
            output_cols = []
            seen_cols: set[str] = set()
            for branch_select in select.find_all(exp.Select):
                for expr in branch_select.expressions:
                    col_name = None
                    if isinstance(expr, exp.Alias):
                        col_name = expr.alias
                    elif isinstance(expr, exp.Column):
                        col_name = expr.name
                    elif isinstance(expr, exp.Star):
                        col_name = "*"
                    if col_name and col_name not in seen_cols:
                        seen_cols.add(col_name)
                        output_cols.append(col_name)
        elif isinstance(select, exp.Select):
            output_cols = []
            for expr in select.expressions:
                if isinstance(expr, exp.Alias):
                    output_cols.append(expr.alias)
                elif isinstance(expr, exp.Column):
                    output_cols.append(expr.name)
                elif isinstance(expr, exp.Star):
                    # SELECT * — can't trace individual columns without schema
                    output_cols.append("*")
                else:
                    # Complex expression without alias — skip
                    continue
        else:
            return

        # Trace each output column — pass AST directly to avoid re-serializing
        for col_name in output_cols:
            if col_name == "*":
                # Can't trace SELECT * without schema catalog
                continue
            try:
                root = sqlglot_lineage(
                    col_name,
                    qualified_stmt,
                    dialect=self.dialect,
                    schema=schema,
                )
            except Exception:
                # Lineage is best-effort; an untraceable column is skipped.
                continue

            # Walk the lineage tree to build hop chains
            chains = self._walk_lineage_tree(root, [])
            for chain in chains:
                if chain:  # skip empty chains
                    column_lineage.append(
                        ColumnLineageResult(
                            output_column=col_name,
                            output_node=output_name,
                            chain=chain,
                        )
                    )
821
+
822
+ def _walk_lineage_tree(
823
+ self,
824
+ node,
825
+ current_chain: list[LineageHop],
826
+ max_depth: int = 50,
827
+ max_chains: int = 1000,
828
+ _chain_count: list | None = None,
829
+ ) -> list[list[LineageHop]]:
830
+ """Recursively walk a sqlglot lineage node tree into flat chains.
831
+
832
+ Each leaf produces one complete chain from output to source.
833
+
834
+ Args:
835
+ node: Current lineage node.
836
+ current_chain: Chain built so far.
837
+ max_depth: Maximum recursion depth before treating node as leaf.
838
+ max_chains: Maximum total chains to collect before stopping early.
839
+ _chain_count: Mutable counter shared across recursion to track total chains.
840
+ """
841
+ if _chain_count is None:
842
+ _chain_count = [0]
843
+
844
+ # Stop if depth or chain limit exceeded — treat current node as leaf
845
+ if len(current_chain) >= max_depth or _chain_count[0] >= max_chains:
846
+ return [current_chain] if current_chain else []
847
+
848
+ # Extract info from this node
849
+ name = node.name if hasattr(node, "name") else ""
850
+ source = node.source.sql() if hasattr(node, "source") and node.source else ""
851
+ expr_str = (
852
+ node.expression.sql(dialect=self.dialect) if hasattr(node, "expression") and node.expression else None
853
+ )
854
+
855
+ # Parse column and table from the node name (format: "table.column" or just "column")
856
+ parts = name.split(".") if name else []
857
+ hop_column = parts[-1] if parts else name
858
+ hop_table = parts[-2] if len(parts) >= 2 else ""
859
+
860
+ # If no table from name, try to extract from source
861
+ if not hop_table and source:
862
+ # Source often looks like "table AS alias" or just "table"
863
+ source_parts = source.strip().split()
864
+ if source_parts:
865
+ hop_table = source_parts[0].strip('"').strip("'")
866
+
867
+ hop = LineageHop(
868
+ column=hop_column,
869
+ table=hop_table,
870
+ expression=expr_str if expr_str and expr_str != hop_column else None,
871
+ )
872
+
873
+ new_chain = current_chain + [hop]
874
+
875
+ downstream = node.downstream if hasattr(node, "downstream") else []
876
+ if not downstream:
877
+ # Leaf node — return the completed chain
878
+ _chain_count[0] += 1
879
+ return [new_chain]
880
+
881
+ # Recurse into downstream nodes
882
+ all_chains = []
883
+ for child in downstream:
884
+ if _chain_count[0] >= max_chains:
885
+ break
886
+ all_chains.extend(self._walk_lineage_tree(child, new_chain, max_depth, max_chains, _chain_count))
887
+ return all_chains
888
+
889
+ def _extract_insert_select_mapping(
890
+ self,
891
+ stmt: exp.Insert,
892
+ file_stem: str,
893
+ column_usage: list[ColumnUsageResult],
894
+ ) -> None:
895
+ """Extract positional column mapping from INSERT...SELECT.
896
+
897
+ When INSERT INTO target (a, b) SELECT x, y FROM source,
898
+ maps source column x -> target column a, y -> b by position.
899
+ """
900
+ # Get the target table name
901
+ target_table = stmt.this
902
+ if isinstance(target_table, exp.Schema):
903
+ # INSERT INTO table (col1, col2) — columns are Identifier nodes
904
+ target_cols = [col.name for col in target_table.expressions if hasattr(col, "name")]
905
+ target_table = target_table.this
906
+ else:
907
+ target_cols = []
908
+
909
+ if not isinstance(target_table, exp.Table) or not target_table.name:
910
+ return
911
+
912
+ target_name = target_table.name
913
+
914
+ # Get the SELECT statement
915
+ select = stmt.expression
916
+ if not isinstance(select, exp.Select):
917
+ return
918
+
919
+ # Get SELECT expressions (output columns)
920
+ select_exprs = select.expressions
921
+ if not select_exprs:
922
+ return
923
+
924
+ # Build alias → real table name mapping from the SELECT's FROM/JOIN sources
925
+ alias_map: dict[str, str] = {}
926
+ for table_ref in select.find_all(exp.Table):
927
+ tbl_name = self._normalize_identifier(
928
+ table_ref.name,
929
+ self._is_quoted_identifier(table_ref),
930
+ )
931
+ if tbl_name:
932
+ alias_map[tbl_name] = tbl_name
933
+ if table_ref.alias:
934
+ alias_map[table_ref.alias] = tbl_name
935
+
936
+ # Map each SELECT expression to its target column by position
937
+ for i, select_expr in enumerate(select_exprs):
938
+ target_col = target_cols[i] if i < len(target_cols) else None
939
+
940
+ # Find the source column in this expression
941
+ source_cols = list(select_expr.find_all(exp.Column))
942
+ for src_col in source_cols:
943
+ if not src_col.name:
944
+ continue
945
+
946
+ # Resolve table alias to real table name
947
+ table_alias = src_col.table or ""
948
+ source_table = alias_map.get(table_alias, table_alias)
949
+
950
+ column_usage.append(
951
+ ColumnUsageResult(
952
+ node_name=file_stem,
953
+ node_kind="query",
954
+ table_name=source_table or target_name,
955
+ column_name=src_col.name,
956
+ usage_type="insert",
957
+ transform=self._extract_transform(src_col),
958
+ alias=target_col,
959
+ )
960
+ )
961
+
962
+ def _normalize_identifier(self, name: str, quoted: bool = False) -> str:
963
+ """Normalize an identifier based on the SQL dialect's case folding rules.
964
+
965
+ Unquoted identifiers are folded: lowercase for Postgres/Redshift/DuckDB,
966
+ uppercase for Snowflake/Oracle. Other dialects preserve case.
967
+
968
+ Quoted identifiers are never folded — they preserve the exact case
969
+ the user wrote.
970
+ """
971
+ if not name or not self.dialect or quoted:
972
+ return name
973
+ d = self.dialect.lower()
974
+ if d in self._LOWERCASE_DIALECTS:
975
+ return name.lower()
976
+ if d in self._UPPERCASE_DIALECTS:
977
+ return name.upper()
978
+ return name
979
+
980
+ @staticmethod
981
+ def _is_quoted_identifier(node: exp.Expression) -> bool:
982
+ """Check whether a sqlglot expression's name identifier is quoted.
983
+
984
+ Works for Table (node.this is Identifier), Column (node.this is Identifier),
985
+ CTE/Subquery aliases (via TableAlias wrapping an Identifier), etc.
986
+ """
987
+ ident = node.this if hasattr(node, "this") else None
988
+ if isinstance(ident, exp.Identifier):
989
+ return bool(ident.quoted)
990
+ return False
991
+
992
+ def _build_table_metadata(self, table: exp.Table) -> dict:
993
+ """Build metadata dict with catalog/schema from a qualified table reference.
994
+
995
+ Catalog and schema values are normalized using the same dialect-aware
996
+ case folding as table/column names. Quoted identifiers keep their
997
+ original case.
998
+ """
999
+ meta: dict = {}
1000
+ if table.catalog:
1001
+ catalog_node = table.args.get("catalog")
1002
+ quoted = isinstance(catalog_node, exp.Identifier) and bool(catalog_node.quoted)
1003
+ meta["catalog"] = self._normalize_identifier(table.catalog, quoted)
1004
+ if table.db:
1005
+ db_node = table.args.get("db")
1006
+ quoted = isinstance(db_node, exp.Identifier) and bool(db_node.quoted)
1007
+ meta["schema"] = self._normalize_identifier(table.db, quoted)
1008
+ return meta
1009
+
1010
+ def _get_table_context(self, table: exp.Table) -> str:
1011
+ """Determine context of a table reference from its AST position."""
1012
+ parent = table.parent
1013
+
1014
+ while parent:
1015
+ if isinstance(parent, exp.Join):
1016
+ return "JOIN clause"
1017
+ if isinstance(parent, exp.From):
1018
+ return "FROM clause"
1019
+ if isinstance(parent, exp.Subquery):
1020
+ return "subquery"
1021
+ if isinstance(parent, exp.Insert):
1022
+ return "INSERT INTO"
1023
+ if isinstance(parent, exp.Merge):
1024
+ return "MERGE target"
1025
+ if isinstance(parent, exp.Update):
1026
+ return "UPDATE target"
1027
+ if isinstance(parent, exp.Lateral):
1028
+ return "LATERAL subquery"
1029
+ parent = parent.parent
1030
+
1031
+ return "FROM clause"