sql-glider 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1183 @@
1
+ """Core lineage analysis using SQLGlot."""
2
+
3
+ from enum import Enum
4
+ from typing import Callable, Iterator, List, Optional, Set, Tuple
5
+
6
+ from pydantic import BaseModel, Field
7
+ from sqlglot import exp, parse
8
+ from sqlglot.errors import ParseError
9
+ from sqlglot.lineage import Node, lineage
10
+
11
+ from sqlglider.global_models import AnalysisLevel
12
+
13
+
14
class TableUsage(str, Enum):
    """How a table is used in a query."""

    INPUT = "INPUT"    # table is only read from
    OUTPUT = "OUTPUT"  # table is only written to
    BOTH = "BOTH"      # table is both read and written within the same query
20
+
21
+
22
class ObjectType(str, Enum):
    """Type of database object."""

    TABLE = "TABLE"      # physical table (e.g. from CREATE/DROP TABLE)
    VIEW = "VIEW"        # view (e.g. from CREATE/DROP VIEW)
    CTE = "CTE"          # common table expression defined in a WITH clause
    UNKNOWN = "UNKNOWN"  # type could not be determined from the statement
29
+
30
+
31
class TableInfo(BaseModel):
    """Information about a table referenced in a query.

    Produced by LineageAnalyzer.analyze_tables() via
    LineageAnalyzer._extract_tables_from_query().
    """

    # Qualified name as built by LineageAnalyzer._get_qualified_table_name,
    # e.g. "table", "db.table" or "catalog.db.table".
    name: str = Field(..., description="Fully qualified table name")
    usage: TableUsage = Field(
        ..., description="How the table is used (INPUT, OUTPUT, BOTH)"
    )
    object_type: ObjectType = Field(
        ..., description="Type of object (TABLE, VIEW, CTE, UNKNOWN)"
    )
41
+
42
+
43
class QueryTablesResult(BaseModel):
    """Result of table analysis for a single query.

    One instance is produced per analyzed statement by
    LineageAnalyzer.analyze_tables().
    """

    # String annotation: QueryMetadata is declared later in this module.
    metadata: "QueryMetadata"
    tables: List[TableInfo] = Field(default_factory=list)
48
+
49
+
50
class LineageItem(BaseModel):
    """Represents a single lineage relationship (output -> source).

    Note: for reverse (impact) lineage the roles are deliberately swapped —
    output_name holds the column under inspection and source_name holds an
    output it affects (see LineageAnalyzer._analyze_reverse_lineage_internal).
    """

    output_name: str = Field(..., description="Output column/table name")
    source_name: str = Field(..., description="Source column/table name")
55
+
56
+
57
class QueryMetadata(BaseModel):
    """Query execution context.

    Identifies which statement of a (possibly multi-statement) SQL string a
    result belongs to.
    """

    query_index: int = Field(..., description="0-based query index")
    query_preview: str = Field(..., description="First 100 chars of query")
62
+
63
+
64
class QueryLineageResult(BaseModel):
    """Complete lineage result for a single query."""

    metadata: QueryMetadata
    # Flat list of output -> source relationships found for this query.
    lineage_items: List[LineageItem] = Field(default_factory=list)
    # Whether lineage_items describe column-level or table-level lineage.
    level: AnalysisLevel
70
+
71
+
72
class SkippedQuery(BaseModel):
    """Information about a query that was skipped during analysis.

    Collected by LineageAnalyzer.analyze_queries() when a statement does not
    support lineage analysis, and exposed via LineageAnalyzer.skipped_queries.
    """

    query_index: int = Field(..., description="0-based query index")
    statement_type: str = Field(..., description="Type of SQL statement (e.g., CREATE)")
    reason: str = Field(..., description="Reason for skipping")
    query_preview: str = Field(..., description="First 100 chars of query")
79
+
80
+
81
# Type alias for warning callback function.
# Receives a single human-readable warning message; returns nothing.
# NOTE(review): no caller in this view uses it — presumably consumed elsewhere.
WarningCallback = Callable[[str], None]
83
+
84
+
85
+ class LineageAnalyzer:
86
+ """Analyze column and table lineage for SQL queries."""
87
+
88
+ def __init__(self, sql: str, dialect: str = "spark"):
89
+ """
90
+ Initialize the lineage analyzer.
91
+
92
+ Args:
93
+ sql: SQL query string to analyze (can contain multiple statements)
94
+ dialect: SQL dialect (default: spark)
95
+
96
+ Raises:
97
+ ParseError: If the SQL cannot be parsed
98
+ """
99
+ self.sql = sql
100
+ self.dialect = dialect
101
+ self._skipped_queries: List[SkippedQuery] = []
102
+
103
+ try:
104
+ # Parse all statements in the SQL string
105
+ parsed = parse(sql, dialect=dialect)
106
+
107
+ # Filter out None values (can happen with empty statements or comments)
108
+ self.expressions: List[exp.Expression] = [
109
+ expr for expr in parsed if expr is not None
110
+ ]
111
+
112
+ if not self.expressions:
113
+ raise ParseError("No valid SQL statements found")
114
+
115
+ # For backward compatibility, store first expression as self.expr
116
+ self.expr = self.expressions[0]
117
+
118
+ except ParseError as e:
119
+ raise ParseError(f"Invalid SQL syntax: {e}") from e
120
+
121
+ @property
122
+ def skipped_queries(self) -> List[SkippedQuery]:
123
+ """Get list of queries that were skipped during analysis."""
124
+ return self._skipped_queries.copy()
125
+
126
    def get_output_columns(self) -> List[str]:
        """
        Extract all output column names from the query with full qualification.

        For DML/DDL statements (INSERT, UPDATE, MERGE, CREATE TABLE AS, etc.),
        returns the target table columns. For DQL (SELECT), returns the selected columns.

        Side effect: rebuilds ``self._column_mapping`` (qualified name ->
        the bare name sqlglot's ``lineage()`` expects), which
        ``_analyze_column_lineage_internal`` reads afterwards.

        Returns:
            List of fully qualified output column names (table.column or database.table.column)

        Raises:
            ValueError: If the statement type is not supported for lineage analysis
        """
        columns = []

        # Build mapping for qualified names
        self._column_mapping = {}  # Maps qualified name -> lineage column name

        # Check if this is a DML/DDL statement
        result = self._get_target_and_select()
        if result is None:
            # Unsupported statement type (no SELECT anywhere in it)
            stmt_type = self._get_statement_type()
            raise ValueError(
                f"Statement type '{stmt_type}' does not support lineage analysis"
            )

        target_table, select_node = result

        if target_table:
            # DML/DDL: Use target table for output column qualification
            # The columns are from the SELECT, but qualified with the target table
            for projection in select_node.expressions:
                # Get the underlying expression (unwrap alias if present)
                if isinstance(projection, exp.Alias):
                    # For aliased columns, use the alias as the column name
                    column_name = projection.alias
                    lineage_name = column_name  # SQLGlot lineage uses the alias
                else:
                    source_expr = projection
                    if isinstance(source_expr, exp.Column):
                        column_name = source_expr.name
                        lineage_name = column_name
                    else:
                        # For expressions, use the SQL representation
                        column_name = source_expr.sql(dialect=self.dialect)
                        lineage_name = column_name

                # Qualify with target table
                qualified_name = f"{target_table}.{column_name}"
                columns.append(qualified_name)
                self._column_mapping[qualified_name] = lineage_name

        else:
            # DQL (pure SELECT): Use the SELECT columns as output
            for projection in select_node.expressions:
                # Get the underlying expression (unwrap alias if present)
                if isinstance(projection, exp.Alias):
                    source_expr = projection.this
                    column_name = projection.alias
                    lineage_name = column_name  # SQLGlot lineage uses the alias
                else:
                    source_expr = projection
                    # No alias: names are derived from the expression below.
                    column_name = None
                    lineage_name = None

                # Try to extract fully qualified name
                if isinstance(source_expr, exp.Column):
                    # Get table and column parts
                    table_name = source_expr.table
                    col_name = column_name or source_expr.name

                    if table_name:
                        # Resolve table reference (could be table, CTE, or subquery alias)
                        # This works at any nesting level because we're only looking at the immediate context
                        resolved_table = self._resolve_table_reference(
                            table_name, select_node
                        )
                        qualified_name = f"{resolved_table}.{col_name}"
                        columns.append(qualified_name)
                        # Map qualified name to what lineage expects
                        self._column_mapping[qualified_name] = lineage_name or col_name
                    else:
                        # No table qualifier - try to infer from FROM clause
                        # This handles "SELECT col FROM single_source" cases
                        inferred_table = self._infer_single_table_source(select_node)
                        if inferred_table:
                            qualified_name = f"{inferred_table}.{col_name}"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = (
                                lineage_name or col_name
                            )
                        else:
                            # Can't infer table, just use column name
                            columns.append(col_name)
                            self._column_mapping[col_name] = lineage_name or col_name
                else:
                    # For other expressions (literals, functions, etc.)
                    # Use the alias if available, otherwise the SQL representation
                    if column_name:
                        columns.append(column_name)
                        self._column_mapping[column_name] = column_name
                    else:
                        expr_str = source_expr.sql(dialect=self.dialect)
                        columns.append(expr_str)
                        self._column_mapping[expr_str] = expr_str

        return columns
234
+
235
    def analyze_queries(
        self,
        level: AnalysisLevel = AnalysisLevel.COLUMN,
        column: Optional[str] = None,
        source_column: Optional[str] = None,
        table_filter: Optional[str] = None,
    ) -> List[QueryLineageResult]:
        """
        Unified lineage analysis for single or multi-query files.

        This method replaces all previous analysis methods (analyze_column_lineage,
        analyze_reverse_lineage, analyze_table_lineage, analyze_all_queries, etc.)
        with a single unified interface.

        Args:
            level: Analysis level ("column" or "table")
            column: Target output column for forward lineage
            source_column: Source column for reverse lineage (impact analysis)
            table_filter: Filter queries to those referencing this table

        Returns:
            List of QueryLineageResult objects (one per query that matches filters)

        Raises:
            ValueError: If column or source_column is specified but not found

        Examples:
            # Forward lineage for all columns
            results = analyzer.analyze_queries(level="column")

            # Forward lineage for specific column
            results = analyzer.analyze_queries(level="column", column="customers.id")

            # Reverse lineage (impact analysis)
            results = analyzer.analyze_queries(level="column", source_column="orders.customer_id")

            # Table-level lineage
            results = analyzer.analyze_queries(level="table")

            # Filter by table (multi-query files)
            results = analyzer.analyze_queries(table_filter="customers")
        """
        results = []
        self._skipped_queries = []  # Reset skipped queries for this analysis

        for query_index, expr, preview in self._iterate_queries(table_filter):
            # Temporarily swap self.expr to analyze this query; the internal
            # helpers all read self.expr rather than taking an argument.
            original_expr = self.expr
            self.expr = expr

            try:
                lineage_items: List[LineageItem] = []

                if level == AnalysisLevel.COLUMN:
                    if source_column:
                        # Reverse lineage (impact analysis)
                        lineage_items = self._analyze_reverse_lineage_internal(
                            source_column
                        )
                        if not lineage_items:
                            # Source column not found in this query - skip it
                            # (the finally block still restores self.expr).
                            continue
                    else:
                        # Forward lineage
                        lineage_items = self._analyze_column_lineage_internal(column)
                        if not lineage_items:
                            # Column not found in this query (if column was specified) - skip it
                            if column:
                                continue
                else:  # table
                    lineage_items = self._analyze_table_lineage_internal()

                # Create query result
                results.append(
                    QueryLineageResult(
                        metadata=QueryMetadata(
                            query_index=query_index,
                            query_preview=preview,
                        ),
                        lineage_items=lineage_items,
                        level=level,
                    )
                )
            except ValueError as e:
                # Unsupported statement type - track it and continue.
                # Pass the raw expr since self.expr is restored in finally.
                stmt_type = self._get_statement_type(expr)
                self._skipped_queries.append(
                    SkippedQuery(
                        query_index=query_index,
                        statement_type=stmt_type,
                        reason=str(e),
                        query_preview=preview,
                    )
                )
            finally:
                # Restore original expression
                self.expr = original_expr

        # Validate: if a specific column or source_column was specified and we got no results,
        # raise ValueError to preserve backward compatibility
        if not results:
            if column:
                raise ValueError(
                    f"Column '{column}' not found in any query. "
                    "Please check the column name and try again."
                )
            elif source_column:
                raise ValueError(
                    f"Source column '{source_column}' not found in any query. "
                    "Please check the column name and try again."
                )

        return results
348
+
349
+ def analyze_tables(
350
+ self,
351
+ table_filter: Optional[str] = None,
352
+ ) -> List[QueryTablesResult]:
353
+ """
354
+ Analyze all tables involved in SQL queries.
355
+
356
+ This method extracts information about all tables referenced in the SQL,
357
+ including their usage (INPUT, OUTPUT, or BOTH) and object type (TABLE, VIEW,
358
+ CTE, or UNKNOWN).
359
+
360
+ Args:
361
+ table_filter: Filter queries to those referencing this table
362
+
363
+ Returns:
364
+ List of QueryTablesResult objects (one per query that matches filters)
365
+
366
+ Examples:
367
+ # Get all tables from SQL
368
+ results = analyzer.analyze_tables()
369
+
370
+ # Filter by table (multi-query files)
371
+ results = analyzer.analyze_tables(table_filter="customers")
372
+ """
373
+ results = []
374
+
375
+ for query_index, expr, preview in self._iterate_queries(table_filter):
376
+ # Temporarily swap self.expr to analyze this query
377
+ original_expr = self.expr
378
+ self.expr = expr
379
+
380
+ try:
381
+ tables = self._extract_tables_from_query()
382
+
383
+ # Create query result
384
+ results.append(
385
+ QueryTablesResult(
386
+ metadata=QueryMetadata(
387
+ query_index=query_index,
388
+ query_preview=preview,
389
+ ),
390
+ tables=tables,
391
+ )
392
+ )
393
+ finally:
394
+ # Restore original expression
395
+ self.expr = original_expr
396
+
397
+ return results
398
+
399
+ def _extract_tables_from_query(self) -> List[TableInfo]:
400
+ """
401
+ Extract all tables from the current query with usage and type information.
402
+
403
+ Returns:
404
+ List of TableInfo objects for all tables in the query.
405
+ """
406
+ # Track tables by name to consolidate INPUT/OUTPUT into BOTH
407
+ tables_dict: dict[str, TableInfo] = {}
408
+
409
+ # Extract CTEs first (they're INPUT only)
410
+ cte_names = self._extract_cte_names()
411
+ for cte_name in cte_names:
412
+ tables_dict[cte_name] = TableInfo(
413
+ name=cte_name,
414
+ usage=TableUsage.INPUT,
415
+ object_type=ObjectType.CTE,
416
+ )
417
+
418
+ # Determine target table and its type based on statement type
419
+ target_table, target_type = self._get_target_table_info()
420
+
421
+ # Get all table references in the query (except CTEs)
422
+ input_tables = self._get_all_input_tables(cte_names)
423
+
424
+ # Add target table as OUTPUT
425
+ if target_table:
426
+ if target_table in tables_dict:
427
+ # Table is both input and output (e.g., UPDATE with self-reference)
428
+ tables_dict[target_table] = TableInfo(
429
+ name=target_table,
430
+ usage=TableUsage.BOTH,
431
+ object_type=target_type,
432
+ )
433
+ else:
434
+ tables_dict[target_table] = TableInfo(
435
+ name=target_table,
436
+ usage=TableUsage.OUTPUT,
437
+ object_type=target_type,
438
+ )
439
+
440
+ # Add input tables
441
+ for table_name in input_tables:
442
+ if table_name in tables_dict:
443
+ # Already exists - might need to upgrade to BOTH
444
+ existing = tables_dict[table_name]
445
+ if existing.usage == TableUsage.OUTPUT:
446
+ tables_dict[table_name] = TableInfo(
447
+ name=table_name,
448
+ usage=TableUsage.BOTH,
449
+ object_type=existing.object_type,
450
+ )
451
+ # If INPUT or BOTH, keep as-is
452
+ else:
453
+ tables_dict[table_name] = TableInfo(
454
+ name=table_name,
455
+ usage=TableUsage.INPUT,
456
+ object_type=ObjectType.UNKNOWN,
457
+ )
458
+
459
+ # Return sorted list by name for consistent output
460
+ return sorted(tables_dict.values(), key=lambda t: t.name.lower())
461
+
462
+ def _extract_cte_names(self) -> Set[str]:
463
+ """
464
+ Extract all CTE (Common Table Expression) names from the query.
465
+
466
+ Returns:
467
+ Set of CTE names defined in the WITH clause.
468
+ """
469
+ cte_names: Set[str] = set()
470
+
471
+ # Look for WITH clause
472
+ if hasattr(self.expr, "args") and self.expr.args.get("with"):
473
+ with_clause = self.expr.args["with"]
474
+ for cte in with_clause.expressions:
475
+ if isinstance(cte, exp.CTE) and cte.alias:
476
+ cte_names.add(cte.alias)
477
+
478
+ return cte_names
479
+
480
+ def _get_target_table_info(self) -> Tuple[Optional[str], ObjectType]:
481
+ """
482
+ Get the target table name and its object type for DML/DDL statements.
483
+
484
+ Returns:
485
+ Tuple of (target_table_name, object_type) or (None, UNKNOWN) for SELECT.
486
+ """
487
+ # INSERT INTO table
488
+ if isinstance(self.expr, exp.Insert):
489
+ target = self.expr.this
490
+ if isinstance(target, exp.Table):
491
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
492
+
493
+ # CREATE TABLE / CREATE VIEW
494
+ elif isinstance(self.expr, exp.Create):
495
+ kind = getattr(self.expr, "kind", "").upper()
496
+ target = self.expr.this
497
+
498
+ # Handle Schema wrapper (CREATE TABLE with columns)
499
+ if isinstance(target, exp.Schema):
500
+ target = target.this
501
+
502
+ if isinstance(target, exp.Table):
503
+ table_name = self._get_qualified_table_name(target)
504
+ if kind == "VIEW":
505
+ return (table_name, ObjectType.VIEW)
506
+ elif kind == "TABLE":
507
+ return (table_name, ObjectType.TABLE)
508
+ else:
509
+ return (table_name, ObjectType.UNKNOWN)
510
+
511
+ # UPDATE table
512
+ elif isinstance(self.expr, exp.Update):
513
+ target = self.expr.this
514
+ if isinstance(target, exp.Table):
515
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
516
+
517
+ # MERGE INTO table
518
+ elif isinstance(self.expr, exp.Merge):
519
+ target = self.expr.this
520
+ if isinstance(target, exp.Table):
521
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
522
+
523
+ # DELETE FROM table
524
+ elif isinstance(self.expr, exp.Delete):
525
+ target = self.expr.this
526
+ if isinstance(target, exp.Table):
527
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
528
+
529
+ # DROP TABLE / DROP VIEW
530
+ elif isinstance(self.expr, exp.Drop):
531
+ kind = getattr(self.expr, "kind", "").upper()
532
+ target = self.expr.this
533
+ if isinstance(target, exp.Table):
534
+ table_name = self._get_qualified_table_name(target)
535
+ if kind == "VIEW":
536
+ return (table_name, ObjectType.VIEW)
537
+ elif kind == "TABLE":
538
+ return (table_name, ObjectType.TABLE)
539
+ else:
540
+ return (table_name, ObjectType.UNKNOWN)
541
+
542
+ # SELECT (no target table)
543
+ return (None, ObjectType.UNKNOWN)
544
+
545
+ def _get_all_input_tables(self, exclude_ctes: Set[str]) -> Set[str]:
546
+ """
547
+ Get all tables used as input (FROM, JOIN, subqueries, etc.).
548
+
549
+ Args:
550
+ exclude_ctes: Set of CTE names to exclude from results.
551
+
552
+ Returns:
553
+ Set of fully qualified table names that are used as input.
554
+ """
555
+ input_tables: Set[str] = set()
556
+
557
+ # Find all Table nodes in the expression tree
558
+ for table_node in self.expr.find_all(exp.Table):
559
+ table_name = self._get_qualified_table_name(table_node)
560
+
561
+ # Skip CTEs (they're tracked separately)
562
+ if table_name in exclude_ctes:
563
+ continue
564
+
565
+ # Skip the target table for certain statement types
566
+ # (it will be added separately as OUTPUT)
567
+ if self._is_target_table(table_node):
568
+ continue
569
+
570
+ input_tables.add(table_name)
571
+
572
+ return input_tables
573
+
574
+ def _is_target_table(self, table_node: exp.Table) -> bool:
575
+ """
576
+ Check if a table node is the target of a DML/DDL statement.
577
+
578
+ This helps distinguish the target table (OUTPUT) from source tables (INPUT)
579
+ in statements like INSERT, UPDATE, MERGE, DELETE.
580
+
581
+ Args:
582
+ table_node: The table node to check.
583
+
584
+ Returns:
585
+ True if this is the target table, False otherwise.
586
+ """
587
+ # For INSERT, the target is self.expr.this
588
+ if isinstance(self.expr, exp.Insert):
589
+ return table_node is self.expr.this
590
+
591
+ # For UPDATE, the target is self.expr.this
592
+ elif isinstance(self.expr, exp.Update):
593
+ return table_node is self.expr.this
594
+
595
+ # For MERGE, the target is self.expr.this
596
+ elif isinstance(self.expr, exp.Merge):
597
+ return table_node is self.expr.this
598
+
599
+ # For DELETE, the target is self.expr.this
600
+ elif isinstance(self.expr, exp.Delete):
601
+ return table_node is self.expr.this
602
+
603
+ # For CREATE TABLE/VIEW, check if it's in the schema
604
+ elif isinstance(self.expr, exp.Create):
605
+ target = self.expr.this
606
+ if isinstance(target, exp.Schema):
607
+ return table_node is target.this
608
+ return table_node is target
609
+
610
+ # For DROP, the target is self.expr.this
611
+ elif isinstance(self.expr, exp.Drop):
612
+ return table_node is self.expr.this
613
+
614
+ return False
615
+
616
    def _analyze_column_lineage_internal(
        self, column: Optional[str] = None
    ) -> List[LineageItem]:
        """
        Internal method for analyzing column lineage. Returns flat list of LineageItem.

        Relies on get_output_columns() having populated self._column_mapping
        (qualified name -> the bare name sqlglot's lineage() expects).

        Args:
            column: Optional specific column to analyze. If None, analyzes all columns.

        Returns:
            List of LineageItem objects (one per output-source relationship)
        """
        # Side effect: this call also rebuilds self._column_mapping.
        output_columns = self.get_output_columns()

        if column:
            # Analyze only the specified column (case-insensitive matching)
            matched_column = None
            column_lower = column.lower()
            for output_col in output_columns:
                if output_col.lower() == column_lower:
                    matched_column = output_col
                    break

            if matched_column is None:
                # Column not found - return empty list (caller will skip this query)
                return []
            columns_to_analyze = [matched_column]
        else:
            # Analyze all columns
            columns_to_analyze = output_columns

        lineage_items = []
        # Get SQL for current expression only (not full multi-query SQL)
        current_query_sql = self.expr.sql(dialect=self.dialect)

        for col in columns_to_analyze:
            try:
                # Get the column name that lineage expects
                lineage_col = self._column_mapping.get(col, col)

                # Get lineage tree for this column using current query SQL only
                node = lineage(lineage_col, current_query_sql, dialect=self.dialect)

                # Collect all source columns
                sources: Set[str] = set()
                self._collect_source_columns(node, sources)

                # Convert to flat LineageItem list (one item per source)
                for source in sorted(sources):
                    lineage_items.append(
                        LineageItem(output_name=col, source_name=source)
                    )

                # If no sources found, add single item with empty source
                if not sources:
                    lineage_items.append(LineageItem(output_name=col, source_name=""))
            except Exception:
                # Deliberate best-effort: if lineage fails for one column,
                # emit it with an empty source instead of aborting the query.
                lineage_items.append(LineageItem(output_name=col, source_name=""))

        return lineage_items
677
+
678
+ def _analyze_table_lineage_internal(self) -> List[LineageItem]:
679
+ """
680
+ Internal method for analyzing table lineage. Returns flat list of LineageItem.
681
+
682
+ Returns:
683
+ List of LineageItem objects (one per output-source table relationship)
684
+ """
685
+ source_tables: Set[str] = set()
686
+
687
+ # Find all Table nodes in the AST
688
+ for table_node in self.expr.find_all(exp.Table):
689
+ # Get fully qualified table name
690
+ table_name = table_node.sql(dialect=self.dialect)
691
+ source_tables.add(table_name)
692
+
693
+ # The output table would typically be defined in INSERT/CREATE statements
694
+ # For SELECT statements, we use a placeholder
695
+ output_table = "query_result"
696
+
697
+ # Convert to flat LineageItem list (one item per source table)
698
+ lineage_items = []
699
+ for source in sorted(source_tables):
700
+ lineage_items.append(
701
+ LineageItem(output_name=output_table, source_name=source)
702
+ )
703
+
704
+ return lineage_items
705
+
706
+ def _analyze_reverse_lineage_internal(
707
+ self, source_column: str
708
+ ) -> List[LineageItem]:
709
+ """
710
+ Internal method for analyzing reverse lineage. Returns flat list of LineageItem.
711
+
712
+ Args:
713
+ source_column: Source column to analyze (e.g., "orders.customer_id")
714
+
715
+ Returns:
716
+ List of LineageItem objects (source column -> affected outputs)
717
+ """
718
+ # Step 1: Run forward lineage on all output columns
719
+ forward_items = self._analyze_column_lineage_internal(column=None)
720
+
721
+ # Step 2: Build reverse mapping (source -> [affected outputs])
722
+ reverse_map: dict[str, set[str]] = {}
723
+ all_outputs = set()
724
+
725
+ for item in forward_items:
726
+ all_outputs.add(item.output_name)
727
+ if item.source_name: # Skip empty sources
728
+ if item.source_name not in reverse_map:
729
+ reverse_map[item.source_name] = set()
730
+ reverse_map[item.source_name].add(item.output_name)
731
+
732
+ # Step 3: Find matching source (case-insensitive)
733
+ matched_source = None
734
+ affected_outputs = set()
735
+ source_column_lower = source_column.lower()
736
+
737
+ # First check if it's in reverse_map (derived columns)
738
+ for source in reverse_map.keys():
739
+ if source.lower() == source_column_lower:
740
+ matched_source = source
741
+ affected_outputs = reverse_map[matched_source]
742
+ break
743
+
744
+ # If not found, check if it's an output column (base table column)
745
+ if matched_source is None:
746
+ for output in all_outputs:
747
+ if output.lower() == source_column_lower:
748
+ matched_source = output
749
+ affected_outputs = {output} # It affects itself
750
+ break
751
+
752
+ if matched_source is None:
753
+ # Source column not found - return empty list (caller will skip this query)
754
+ return []
755
+
756
+ # Step 4: Return with semantic swap (source as output, affected as sources)
757
+ # This maintains the LineageItem structure where output_name is what we're looking at
758
+ # and source_name is what it affects
759
+ lineage_items = []
760
+ for affected in sorted(affected_outputs):
761
+ lineage_items.append(
762
+ LineageItem(output_name=matched_source, source_name=affected)
763
+ )
764
+
765
+ return lineage_items
766
+
767
+ def _get_statement_type(self, expr: Optional[exp.Expression] = None) -> str:
768
+ """
769
+ Get a human-readable name for the SQL statement type.
770
+
771
+ Args:
772
+ expr: Expression to check (uses self.expr if not provided)
773
+
774
+ Returns:
775
+ Statement type name (e.g., "CREATE FUNCTION", "SELECT", "DELETE")
776
+ """
777
+ target_expr = expr if expr is not None else self.expr
778
+ expr_type = type(target_expr).__name__
779
+
780
+ # Map common expression types to more readable names
781
+ type_map = {
782
+ "Select": "SELECT",
783
+ "Insert": "INSERT",
784
+ "Update": "UPDATE",
785
+ "Delete": "DELETE",
786
+ "Merge": "MERGE",
787
+ "Create": f"CREATE {getattr(target_expr, 'kind', '')}".strip(),
788
+ "Drop": f"DROP {getattr(target_expr, 'kind', '')}".strip(),
789
+ "Alter": "ALTER",
790
+ "Truncate": "TRUNCATE",
791
+ "Command": "COMMAND",
792
+ }
793
+
794
+ return type_map.get(expr_type, expr_type.upper())
795
+
796
+ def _get_target_and_select(
797
+ self,
798
+ ) -> Optional[tuple[Optional[str], exp.Select]]:
799
+ """
800
+ Detect if this is a DML/DDL statement and extract the target table and SELECT node.
801
+
802
+ Returns:
803
+ Tuple of (target_table_name, select_node) where:
804
+ - target_table_name is the fully qualified target table for DML/DDL, or None for pure SELECT
805
+ - select_node is the SELECT statement that provides the data
806
+ - Returns None if the statement type doesn't contain a SELECT (e.g., CREATE FUNCTION)
807
+
808
+ Handles:
809
+ - INSERT INTO table SELECT ...
810
+ - CREATE TABLE table AS SELECT ...
811
+ - MERGE INTO table ...
812
+ - UPDATE table SET ... FROM (SELECT ...)
813
+ - Pure SELECT (returns None as target)
814
+ """
815
+ # Check for INSERT statement
816
+ if isinstance(self.expr, exp.Insert):
817
+ target = self.expr.this
818
+ if isinstance(target, exp.Table):
819
+ target_name = self._get_qualified_table_name(target)
820
+ # Find the SELECT within the INSERT
821
+ select_node = self.expr.expression
822
+ if isinstance(select_node, exp.Select):
823
+ return (target_name, select_node)
824
+
825
+ # Check for CREATE TABLE AS SELECT (CTAS) or CREATE VIEW AS SELECT
826
+ elif isinstance(self.expr, exp.Create):
827
+ if self.expr.kind in ("TABLE", "VIEW"):
828
+ target = self.expr.this
829
+ if isinstance(target, exp.Schema):
830
+ # Get the table from schema
831
+ target = target.this
832
+ if isinstance(target, exp.Table):
833
+ target_name = self._get_qualified_table_name(target)
834
+ # Find the SELECT in the expression
835
+ select_node = self.expr.expression
836
+ if isinstance(select_node, exp.Select):
837
+ return (target_name, select_node)
838
+
839
+ # Check for MERGE statement
840
+ elif isinstance(self.expr, exp.Merge):
841
+ target = self.expr.this
842
+ if isinstance(target, exp.Table):
843
+ target_name = self._get_qualified_table_name(target)
844
+ # For MERGE, we need to find the SELECT in the USING clause
845
+ # This is more complex, for now treat it as a SELECT
846
+ select_nodes = list(self.expr.find_all(exp.Select))
847
+ if select_nodes:
848
+ return (target_name, select_nodes[0])
849
+
850
+ # Check for UPDATE with subquery
851
+ elif isinstance(self.expr, exp.Update):
852
+ target = self.expr.this
853
+ if isinstance(target, exp.Table):
854
+ target_name = self._get_qualified_table_name(target)
855
+ # For UPDATE, find the SELECT if there is one
856
+ select_nodes = list(self.expr.find_all(exp.Select))
857
+ if select_nodes:
858
+ return (target_name, select_nodes[0])
859
+
860
+ # Default: Pure SELECT (DQL)
861
+ select_nodes = list(self.expr.find_all(exp.Select))
862
+ if select_nodes:
863
+ return (None, select_nodes[0])
864
+
865
+ # Fallback: return the expression as-is if it's a SELECT
866
+ if isinstance(self.expr, exp.Select):
867
+ return (None, self.expr)
868
+
869
+ # No SELECT found - return None to indicate unsupported statement
870
+ return None
871
+
872
+ def _get_qualified_table_name(self, table: exp.Table) -> str:
873
+ """
874
+ Get the fully qualified name for a table.
875
+
876
+ Args:
877
+ table: SQLGlot Table expression
878
+
879
+ Returns:
880
+ Fully qualified table name (database.table or catalog.database.table)
881
+ """
882
+ parts = []
883
+ if table.catalog:
884
+ parts.append(table.catalog)
885
+ if table.db:
886
+ parts.append(table.db)
887
+ parts.append(table.name)
888
+ return ".".join(parts)
889
+
890
+ def _resolve_table_reference(self, ref: str, select_node: exp.Select) -> str:
891
+ """
892
+ Resolve a table reference (alias, CTE name, or actual table) to its canonical name.
893
+
894
+ This works at any nesting level by only looking at the immediate SELECT context.
895
+ For CTEs and subqueries, returns their alias name (which is the "table name" in that context).
896
+ For actual tables with aliases, returns the actual table name.
897
+
898
+ Args:
899
+ ref: The table reference to resolve (could be alias, CTE name, or table name)
900
+ select_node: The SELECT node containing the FROM/JOIN clauses
901
+
902
+ Returns:
903
+ The canonical table name (actual table for real tables, alias for CTEs/subqueries)
904
+ """
905
+ # Check if this is a CTE reference first
906
+ # CTEs are defined in the WITH clause and referenced by their alias
907
+ parent = select_node
908
+ while parent:
909
+ if isinstance(parent, (exp.Select, exp.Union)) and parent.args.get("with"):
910
+ cte_node = parent.args["with"]
911
+ for cte in cte_node.expressions:
912
+ if isinstance(cte, exp.CTE) and cte.alias == ref:
913
+ # This is a CTE - return the CTE alias as the "table name"
914
+ return ref
915
+ parent = parent.parent if hasattr(parent, "parent") else None
916
+
917
+ # Look for table references in FROM and JOIN clauses
918
+ for table_ref in select_node.find_all(exp.Table):
919
+ # Check if this table has the matching alias
920
+ if table_ref.alias == ref:
921
+ # Return the qualified table name
922
+ parts = []
923
+ if table_ref.db:
924
+ parts.append(table_ref.db)
925
+ if table_ref.catalog:
926
+ parts.insert(0, table_ref.catalog)
927
+ parts.append(table_ref.name)
928
+ return ".".join(parts)
929
+ # Also check if ref matches the table name directly (no alias case)
930
+ elif table_ref.name == ref and not table_ref.alias:
931
+ parts = []
932
+ if table_ref.db:
933
+ parts.append(table_ref.db)
934
+ if table_ref.catalog:
935
+ parts.insert(0, table_ref.catalog)
936
+ parts.append(table_ref.name)
937
+ return ".".join(parts)
938
+
939
+ # Check for subquery aliases in FROM clause
940
+ if select_node.args.get("from"):
941
+ from_clause = select_node.args["from"]
942
+ if isinstance(from_clause, exp.From):
943
+ source = from_clause.this
944
+ # Check if it's a subquery with matching alias
945
+ if isinstance(source, exp.Subquery) and source.alias == ref:
946
+ # Return the subquery alias as the "table name"
947
+ return ref
948
+ # Check if it's a table with matching alias
949
+ elif isinstance(source, exp.Table) and source.alias == ref:
950
+ parts = []
951
+ if source.db:
952
+ parts.append(source.db)
953
+ if source.catalog:
954
+ parts.insert(0, source.catalog)
955
+ parts.append(source.name)
956
+ return ".".join(parts)
957
+
958
+ # Check JOIN clauses for subqueries
959
+ for join in select_node.find_all(exp.Join):
960
+ if isinstance(join.this, exp.Subquery) and join.this.alias == ref:
961
+ return ref
962
+ elif isinstance(join.this, exp.Table) and join.this.alias == ref:
963
+ parts = []
964
+ if join.this.db:
965
+ parts.append(join.this.db)
966
+ if join.this.catalog:
967
+ parts.insert(0, join.this.catalog)
968
+ parts.append(join.this.name)
969
+ return ".".join(parts)
970
+
971
+ # If we can't resolve, return the reference as-is
972
+ return ref
973
+
974
+ def _infer_single_table_source(self, select_node: exp.Select) -> Optional[str]:
975
+ """
976
+ Infer the table name when there's only one table in FROM clause.
977
+
978
+ This handles cases like "SELECT col FROM table" where col has no table prefix.
979
+
980
+ Args:
981
+ select_node: The SELECT node
982
+
983
+ Returns:
984
+ The table name if there's exactly one source, None otherwise
985
+ """
986
+ if not select_node.args.get("from"):
987
+ return None
988
+
989
+ from_clause = select_node.args["from"]
990
+ if not isinstance(from_clause, exp.From):
991
+ return None
992
+
993
+ source = from_clause.this
994
+
995
+ # Check for JOINs - if there are joins, we can't infer
996
+ if list(select_node.find_all(exp.Join)):
997
+ return None
998
+
999
+ # Single table or CTE/subquery
1000
+ if isinstance(source, exp.Table):
1001
+ parts = []
1002
+ if source.db:
1003
+ parts.append(source.db)
1004
+ if source.catalog:
1005
+ parts.insert(0, source.catalog)
1006
+ if source.alias:
1007
+ # If the table has an alias, use the alias
1008
+ return source.alias
1009
+ parts.append(source.name)
1010
+ return ".".join(parts)
1011
+ elif isinstance(source, (exp.Subquery, exp.CTE)):
1012
+ # Return the subquery/CTE alias
1013
+ return source.alias if source.alias else None
1014
+
1015
+ return None
1016
+
1017
+ def _collect_source_columns(self, node: Node, sources: Set[str]) -> None:
1018
+ """
1019
+ Recursively collect all source columns from a lineage tree.
1020
+
1021
+ This traverses the lineage tree depth-first, collecting leaf nodes
1022
+ which represent the actual source columns.
1023
+
1024
+ Args:
1025
+ node: The current lineage node
1026
+ sources: Set to accumulate source column names
1027
+ """
1028
+ if not node.downstream:
1029
+ # Leaf node - this is a source column
1030
+ # Check if this is a literal value (SQLGlot uses position numbers for literals)
1031
+ if node.name.isdigit():
1032
+ # This is a literal - extract the actual value from the expression
1033
+ literal_repr = self._extract_literal_representation(node)
1034
+ sources.add(literal_repr)
1035
+ else:
1036
+ # SQLGlot's lineage provides qualified names, but may use aliases
1037
+ # Need to resolve aliases to actual table names
1038
+ qualified_name = self._resolve_source_column_alias(node.name)
1039
+ sources.add(qualified_name)
1040
+ else:
1041
+ # Traverse deeper into the tree
1042
+ for child in node.downstream:
1043
+ self._collect_source_columns(child, sources)
1044
+
1045
+ def _extract_literal_representation(self, node: Node) -> str:
1046
+ """
1047
+ Extract a human-readable representation of a literal value from a lineage node.
1048
+
1049
+ When SQLGlot encounters a literal value in a UNION branch, it returns the
1050
+ column position as the node name. This method extracts the actual literal
1051
+ value from the node's expression.
1052
+
1053
+ Args:
1054
+ node: A lineage node where node.name is a digit (position number)
1055
+
1056
+ Returns:
1057
+ A string like "<literal: NULL>" or "<literal: 'value'>" or "<literal: 0>"
1058
+ """
1059
+ try:
1060
+ expr = node.expression
1061
+ # The expression is typically an Alias wrapping the actual value
1062
+ if isinstance(expr, exp.Alias):
1063
+ literal_expr = expr.this
1064
+ literal_sql = literal_expr.sql(dialect=self.dialect)
1065
+ return f"<literal: {literal_sql}>"
1066
+ else:
1067
+ # Fallback: use the expression's SQL representation
1068
+ return f"<literal: {expr.sql(dialect=self.dialect)}>"
1069
+ except Exception:
1070
+ # If extraction fails, return a generic literal marker
1071
+ return "<literal>"
1072
+
1073
+ def _get_query_tables(self) -> List[str]:
1074
+ """
1075
+ Get all table names referenced in the current query.
1076
+
1077
+ Returns:
1078
+ List of fully qualified table names used in the query
1079
+ """
1080
+ tables = []
1081
+ for table_node in self.expr.find_all(exp.Table):
1082
+ table_name = self._get_qualified_table_name(table_node)
1083
+ tables.append(table_name)
1084
+ return tables
1085
+
1086
+ def _resolve_source_column_alias(self, column_name: str) -> str:
1087
+ """
1088
+ Resolve table aliases in source column names.
1089
+
1090
+ This searches through ALL SELECT nodes in the query (including nested ones)
1091
+ to find and resolve table aliases, CTEs, and subqueries.
1092
+
1093
+ Args:
1094
+ column_name: Column name like "alias.column" or "table.column"
1095
+
1096
+ Returns:
1097
+ Fully qualified column name with actual table name
1098
+ """
1099
+ # Parse the column name (format: table.column or db.table.column)
1100
+ parts = column_name.split(".")
1101
+
1102
+ if len(parts) < 2:
1103
+ # No table qualifier, return as-is
1104
+ return column_name
1105
+
1106
+ # The table part might be an alias, CTE name, or actual table
1107
+ table_part = parts[0] if len(parts) == 2 else parts[-2]
1108
+ column_part = parts[-1]
1109
+
1110
+ # Try to resolve by searching through ALL SELECT nodes (including nested)
1111
+ # This handles cases where the alias is defined deep in a subquery/CTE
1112
+ for select_node in self.expr.find_all(exp.Select):
1113
+ resolved = self._resolve_table_reference(table_part, select_node)
1114
+ # If resolution changed the name, we found it
1115
+ if resolved != table_part:
1116
+ # Reconstruct with resolved table name
1117
+ if len(parts) == 2:
1118
+ return f"{resolved}.{column_part}"
1119
+ else:
1120
+ # Has database part
1121
+ return f"{parts[0]}.{resolved}.{column_part}"
1122
+
1123
+ # If we couldn't resolve in any SELECT, return as-is
1124
+ return column_name
1125
+
1126
+ def _generate_query_preview(self, expr: exp.Expression) -> str:
1127
+ """
1128
+ Generate a preview string for a query (first 100 chars, normalized).
1129
+
1130
+ Args:
1131
+ expr: The SQL expression to generate a preview for
1132
+
1133
+ Returns:
1134
+ Preview string (first 100 chars with "..." if truncated)
1135
+ """
1136
+ query_text = expr.sql(dialect=self.dialect)
1137
+ preview = " ".join(query_text.split())[:100]
1138
+ if len(" ".join(query_text.split())) > 100:
1139
+ preview += "..."
1140
+ return preview
1141
+
1142
+ def _filter_by_table(self, expr: exp.Expression, table_filter: str) -> bool:
1143
+ """
1144
+ Check if a query references a specific table.
1145
+
1146
+ Args:
1147
+ expr: The SQL expression to check
1148
+ table_filter: Table name to filter by (case-insensitive partial match)
1149
+
1150
+ Returns:
1151
+ True if the query references the table, False otherwise
1152
+ """
1153
+ # Temporarily swap self.expr to analyze this expression
1154
+ original_expr = self.expr
1155
+ self.expr = expr
1156
+ try:
1157
+ query_tables = self._get_query_tables()
1158
+ table_filter_lower = table_filter.lower()
1159
+ return any(table_filter_lower in table.lower() for table in query_tables)
1160
+ finally:
1161
+ self.expr = original_expr
1162
+
1163
+ def _iterate_queries(
1164
+ self, table_filter: Optional[str] = None
1165
+ ) -> Iterator[Tuple[int, exp.Expression, str]]:
1166
+ """
1167
+ Iterate over queries with filtering and preview generation.
1168
+
1169
+ Args:
1170
+ table_filter: Optional table name to filter queries by
1171
+
1172
+ Yields:
1173
+ Tuple of (query_index, expression, query_preview)
1174
+ """
1175
+ for idx, expr in enumerate(self.expressions):
1176
+ # Apply table filter
1177
+ if table_filter and not self._filter_by_table(expr, table_filter):
1178
+ continue
1179
+
1180
+ # Generate preview
1181
+ preview = self._generate_query_preview(expr)
1182
+
1183
+ yield idx, expr, preview