sql-glider 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1631 @@
1
+ """Core lineage analysis using SQLGlot."""
2
+
3
+ from enum import Enum
4
+ from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
5
+
6
+ from pydantic import BaseModel, Field
7
+ from sqlglot import exp, parse
8
+ from sqlglot.errors import ParseError
9
+ from sqlglot.lineage import Node, lineage
10
+
11
+ from sqlglider.global_models import AnalysisLevel
12
+
13
+
14
class TableUsage(str, Enum):
    """How a table is used in a query."""

    INPUT = "INPUT"    # read from (FROM/JOIN/CTE/subquery source)
    OUTPUT = "OUTPUT"  # written to (INSERT/CREATE/UPDATE/MERGE/DELETE target)
    BOTH = "BOTH"      # read and written within the same statement
20
+
21
+
22
class ObjectType(str, Enum):
    """Type of database object."""

    TABLE = "TABLE"
    VIEW = "VIEW"
    CTE = "CTE"
    # Used when the statement does not reveal the object kind
    # (e.g. INSERT/UPDATE/MERGE/DELETE targets, plain input tables).
    UNKNOWN = "UNKNOWN"
29
+
30
+
31
class TableInfo(BaseModel):
    """Information about a table referenced in a query."""

    # e.g. "db.table" — presumably produced by _get_qualified_table_name; verify.
    name: str = Field(..., description="Fully qualified table name")
    usage: TableUsage = Field(
        ..., description="How the table is used (INPUT, OUTPUT, BOTH)"
    )
    object_type: ObjectType = Field(
        ..., description="Type of object (TABLE, VIEW, CTE, UNKNOWN)"
    )
41
+
42
+
43
class QueryTablesResult(BaseModel):
    """Result of table analysis for a single query."""

    # Forward reference: QueryMetadata is declared later in this module.
    metadata: "QueryMetadata"
    tables: List[TableInfo] = Field(default_factory=list)
48
+
49
+
50
class LineageItem(BaseModel):
    """Represents a single lineage relationship (output -> source)."""

    # NOTE: in reverse (impact) analysis the semantics are deliberately swapped:
    # output_name holds the inspected source column and source_name each
    # affected output (see _analyze_reverse_lineage_internal).
    output_name: str = Field(..., description="Output column/table name")
    source_name: str = Field(..., description="Source column/table name")
55
+
56
+
57
class QueryMetadata(BaseModel):
    """Query execution context."""

    # Position of the statement within the (possibly multi-statement) input.
    query_index: int = Field(..., description="0-based query index")
    query_preview: str = Field(..., description="First 100 chars of query")
62
+
63
+
64
class QueryLineageResult(BaseModel):
    """Complete lineage result for a single query."""

    metadata: QueryMetadata
    # One entry per (output, source) relationship; an empty source_name marks
    # an output whose lineage could not be resolved.
    lineage_items: List[LineageItem] = Field(default_factory=list)
    # Whether lineage_items are column-level or table-level relationships.
    level: AnalysisLevel
70
+
71
+
72
class SkippedQuery(BaseModel):
    """Information about a query that was skipped during analysis."""

    query_index: int = Field(..., description="0-based query index")
    statement_type: str = Field(..., description="Type of SQL statement (e.g., CREATE)")
    # Typically the ValueError message raised for unsupported statement types.
    reason: str = Field(..., description="Reason for skipping")
    query_preview: str = Field(..., description="First 100 chars of query")
79
+
80
+
81
# Type alias for a warning callback function: takes a single message string
# and returns nothing.
WarningCallback = Callable[[str], None]
83
+
84
+
85
+ class LineageAnalyzer:
86
+ """Analyze column and table lineage for SQL queries."""
87
+
88
+ def __init__(self, sql: str, dialect: str = "spark"):
89
+ """
90
+ Initialize the lineage analyzer.
91
+
92
+ Args:
93
+ sql: SQL query string to analyze (can contain multiple statements)
94
+ dialect: SQL dialect (default: spark)
95
+
96
+ Raises:
97
+ ParseError: If the SQL cannot be parsed
98
+ """
99
+ self.sql = sql
100
+ self.dialect = dialect
101
+ self._skipped_queries: List[SkippedQuery] = []
102
+ # File-scoped schema context for cross-statement lineage
103
+ # Maps table/view names to their column definitions
104
+ self._file_schema: Dict[str, Dict[str, str]] = {}
105
+
106
+ try:
107
+ # Parse all statements in the SQL string
108
+ parsed = parse(sql, dialect=dialect)
109
+
110
+ # Filter out None values (can happen with empty statements or comments)
111
+ self.expressions: List[exp.Expression] = [
112
+ expr for expr in parsed if expr is not None
113
+ ]
114
+
115
+ if not self.expressions:
116
+ raise ParseError("No valid SQL statements found")
117
+
118
+ # For backward compatibility, store first expression as self.expr
119
+ self.expr = self.expressions[0]
120
+
121
+ except ParseError as e:
122
+ raise ParseError(f"Invalid SQL syntax: {e}") from e
123
+
124
+ @property
125
+ def skipped_queries(self) -> List[SkippedQuery]:
126
+ """Get list of queries that were skipped during analysis."""
127
+ return self._skipped_queries.copy()
128
+
129
    def get_output_columns(self) -> List[str]:
        """
        Extract all output column names from the query with full qualification.

        For DML/DDL statements (INSERT, UPDATE, MERGE, CREATE TABLE AS, etc.),
        returns the target table columns. For DQL (SELECT), returns the selected columns.

        Side effect: rebuilds ``self._column_mapping`` (qualified name ->
        the name sqlglot's lineage() expects), used later by
        _analyze_column_lineage_internal.

        Returns:
            List of fully qualified output column names (table.column or database.table.column)

        Raises:
            ValueError: If the statement type is not supported for lineage analysis
        """
        columns: List[str] = []

        # Build mapping for qualified names
        self._column_mapping: Dict[str, str] = {}  # Maps qualified name -> lineage column name

        # Check if this is a DML/DDL statement
        result = self._get_target_and_select()
        if result is None:
            # Unsupported statement type
            stmt_type = self._get_statement_type()
            raise ValueError(
                f"Statement type '{stmt_type}' does not support lineage analysis"
            )

        target_table, select_node = result

        if target_table:
            # DML/DDL: Use target table for output column qualification
            # The columns are from the SELECT, but qualified with the target table
            projections = self._get_select_projections(select_node)
            first_select = self._get_first_select(select_node)

            for projection in projections:
                # Handle SELECT * by resolving from file schema
                if isinstance(projection, exp.Star):
                    if first_select:
                        star_columns = self._resolve_star_columns(first_select)
                        for star_col in star_columns:
                            qualified_name = f"{target_table}.{star_col}"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = star_col
                    # NOTE(review): this tests the whole accumulated list, so a
                    # bare * that follows other projections never reaches the
                    # fallback — confirm that is intended.
                    if not columns:
                        # Fallback: can't resolve *, use * as column name
                        qualified_name = f"{target_table}.*"
                        columns.append(qualified_name)
                        self._column_mapping[qualified_name] = "*"
                    continue

                # Get the underlying expression (unwrap alias if present)
                if isinstance(projection, exp.Alias):
                    # For aliased columns, use the alias as the column name
                    column_name = projection.alias
                    lineage_name = column_name  # SQLGlot lineage uses the alias
                    # Qualify with target table
                    qualified_name = f"{target_table}.{column_name}"
                    columns.append(qualified_name)
                    self._column_mapping[qualified_name] = lineage_name
                elif isinstance(projection, exp.Column):
                    # Check if this is a table-qualified star (e.g., t.*)
                    if isinstance(projection.this, exp.Star):
                        source_table = projection.table
                        qualified_star_cols: List[str] = []
                        if source_table and first_select:
                            qualified_star_cols = self._resolve_qualified_star(
                                source_table, first_select
                            )
                        for col in qualified_star_cols:
                            qualified_name = f"{target_table}.{col}"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = col
                        if not qualified_star_cols:
                            # Fallback: can't resolve t.*, use * as column name
                            qualified_name = f"{target_table}.*"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = "*"
                    else:
                        column_name = projection.name
                        lineage_name = column_name
                        # Qualify with target table
                        qualified_name = f"{target_table}.{column_name}"
                        columns.append(qualified_name)
                        self._column_mapping[qualified_name] = lineage_name
                else:
                    # For expressions, use the SQL representation
                    column_name = projection.sql(dialect=self.dialect)
                    lineage_name = column_name
                    # Qualify with target table
                    qualified_name = f"{target_table}.{column_name}"
                    columns.append(qualified_name)
                    self._column_mapping[qualified_name] = lineage_name

        else:
            # DQL (pure SELECT): Use the SELECT columns as output
            projections = self._get_select_projections(select_node)
            # Get the first SELECT for table resolution (handles UNION case)
            first_select = self._get_first_select(select_node)
            for projection in projections:
                # Get the underlying expression (unwrap alias if present)
                if isinstance(projection, exp.Alias):
                    source_expr = projection.this
                    column_name = projection.alias
                    lineage_name = column_name  # SQLGlot lineage uses the alias
                else:
                    source_expr = projection
                    column_name = None
                    lineage_name = None

                # Try to extract fully qualified name
                if isinstance(source_expr, exp.Column):
                    # Get table and column parts
                    table_name = source_expr.table
                    # Prefer the alias over the underlying column name.
                    col_name = column_name or source_expr.name

                    if table_name and first_select:
                        # Resolve table reference (could be table, CTE, or subquery alias)
                        # This works at any nesting level because we're only looking at the immediate context
                        resolved_table = self._resolve_table_reference(
                            table_name, first_select
                        )
                        qualified_name = f"{resolved_table}.{col_name}"
                        columns.append(qualified_name)
                        # Map qualified name to what lineage expects
                        self._column_mapping[qualified_name] = lineage_name or col_name
                    elif first_select:
                        # No table qualifier - try to infer from FROM clause
                        # This handles "SELECT col FROM single_source" cases
                        inferred_table = self._infer_single_table_source(first_select)
                        if inferred_table:
                            qualified_name = f"{inferred_table}.{col_name}"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = (
                                lineage_name or col_name
                            )
                        else:
                            # Can't infer table, just use column name
                            columns.append(col_name)
                            self._column_mapping[col_name] = lineage_name or col_name
                    else:
                        # No SELECT found, just use column name
                        columns.append(col_name)
                        self._column_mapping[col_name] = lineage_name or col_name
                else:
                    # For other expressions (literals, functions, etc.)
                    # Use the alias if available, otherwise the SQL representation
                    if column_name:
                        columns.append(column_name)
                        self._column_mapping[column_name] = column_name
                    else:
                        expr_str = source_expr.sql(dialect=self.dialect)
                        columns.append(expr_str)
                        self._column_mapping[expr_str] = expr_str

        return columns
285
+
286
+ def _get_select_projections(self, node: exp.Expression) -> List[exp.Expression]:
287
+ """
288
+ Get the SELECT projections from a SELECT or set operation node.
289
+
290
+ For set operations (UNION, INTERSECT, EXCEPT), returns projections from
291
+ the first branch since all branches must have the same number of columns
292
+ with compatible types.
293
+
294
+ Args:
295
+ node: A SELECT or set operation (UNION/INTERSECT/EXCEPT) expression
296
+
297
+ Returns:
298
+ List of projection expressions from the SELECT clause
299
+ """
300
+ if isinstance(node, exp.Select):
301
+ return list(node.expressions)
302
+ elif isinstance(node, (exp.Union, exp.Intersect, exp.Except)):
303
+ # Recursively get from the left branch (could be nested set operations)
304
+ return self._get_select_projections(node.left)
305
+ return []
306
+
307
+ def _get_first_select(self, node: exp.Expression) -> Optional[exp.Select]:
308
+ """
309
+ Get the first SELECT node from a SELECT or set operation expression.
310
+
311
+ For set operations (UNION, INTERSECT, EXCEPT), returns the leftmost
312
+ SELECT branch.
313
+
314
+ Args:
315
+ node: A SELECT or set operation (UNION/INTERSECT/EXCEPT) expression
316
+
317
+ Returns:
318
+ The first SELECT node, or None if not found
319
+ """
320
+ if isinstance(node, exp.Select):
321
+ return node
322
+ elif isinstance(node, (exp.Union, exp.Intersect, exp.Except)):
323
+ return self._get_first_select(node.left)
324
+ return None
325
+
326
    def analyze_queries(
        self,
        level: AnalysisLevel = AnalysisLevel.COLUMN,
        column: Optional[str] = None,
        source_column: Optional[str] = None,
        table_filter: Optional[str] = None,
    ) -> List[QueryLineageResult]:
        """
        Unified lineage analysis for single or multi-query files.

        This method replaces all previous analysis methods (analyze_column_lineage,
        analyze_reverse_lineage, analyze_table_lineage, analyze_all_queries, etc.)
        with a single unified interface.

        Args:
            level: Analysis level ("column" or "table")
            column: Target output column for forward lineage
            source_column: Source column for reverse lineage (impact analysis)
            table_filter: Filter queries to those referencing this table

        Returns:
            List of QueryLineageResult objects (one per query that matches filters)

        Raises:
            ValueError: If column or source_column is specified but not found

        Examples:
            # Forward lineage for all columns
            results = analyzer.analyze_queries(level="column")

            # Forward lineage for specific column
            results = analyzer.analyze_queries(level="column", column="customers.id")

            # Reverse lineage (impact analysis)
            results = analyzer.analyze_queries(level="column", source_column="orders.customer_id")

            # Table-level lineage
            results = analyzer.analyze_queries(level="table")

            # Filter by table (multi-query files)
            results = analyzer.analyze_queries(table_filter="customers")
        """
        results = []
        self._skipped_queries = []  # Reset skipped queries for this analysis
        self._file_schema = {}  # Reset file schema for this analysis run

        for query_index, expr, preview in self._iterate_queries(table_filter):
            # Temporarily swap self.expr to analyze this query
            # (the per-query helpers all read self.expr).
            original_expr = self.expr
            self.expr = expr

            try:
                lineage_items: List[LineageItem] = []

                if level == AnalysisLevel.COLUMN:
                    if source_column:
                        # Reverse lineage (impact analysis)
                        lineage_items = self._analyze_reverse_lineage_internal(
                            source_column
                        )
                        if not lineage_items:
                            # Source column not found in this query - skip it
                            continue
                    else:
                        # Forward lineage
                        lineage_items = self._analyze_column_lineage_internal(column)
                        if not lineage_items:
                            # Column not found in this query (if column was specified) - skip it
                            if column:
                                continue
                else:  # table
                    lineage_items = self._analyze_table_lineage_internal()

                # Create query result
                results.append(
                    QueryLineageResult(
                        metadata=QueryMetadata(
                            query_index=query_index,
                            query_preview=preview,
                        ),
                        lineage_items=lineage_items,
                        level=level,
                    )
                )
            except ValueError as e:
                # Unsupported statement type - track it and continue
                stmt_type = self._get_statement_type(expr)
                self._skipped_queries.append(
                    SkippedQuery(
                        query_index=query_index,
                        statement_type=stmt_type,
                        reason=str(e),
                        query_preview=preview,
                    )
                )
            finally:
                # Extract schema from this statement AFTER analysis
                # This builds up context for subsequent statements to use.
                # Runs even on `continue`/skip, so every statement contributes
                # schema regardless of whether it produced a result.
                self._extract_schema_from_statement(expr)
                # Restore original expression
                self.expr = original_expr

        # Validate: if a specific column or source_column was specified and we got no results,
        # raise ValueError to preserve backward compatibility
        if not results:
            if column:
                raise ValueError(
                    f"Column '{column}' not found in any query. "
                    "Please check the column name and try again."
                )
            elif source_column:
                raise ValueError(
                    f"Source column '{source_column}' not found in any query. "
                    "Please check the column name and try again."
                )

        return results
443
+
444
+ def analyze_tables(
445
+ self,
446
+ table_filter: Optional[str] = None,
447
+ ) -> List[QueryTablesResult]:
448
+ """
449
+ Analyze all tables involved in SQL queries.
450
+
451
+ This method extracts information about all tables referenced in the SQL,
452
+ including their usage (INPUT, OUTPUT, or BOTH) and object type (TABLE, VIEW,
453
+ CTE, or UNKNOWN).
454
+
455
+ Args:
456
+ table_filter: Filter queries to those referencing this table
457
+
458
+ Returns:
459
+ List of QueryTablesResult objects (one per query that matches filters)
460
+
461
+ Examples:
462
+ # Get all tables from SQL
463
+ results = analyzer.analyze_tables()
464
+
465
+ # Filter by table (multi-query files)
466
+ results = analyzer.analyze_tables(table_filter="customers")
467
+ """
468
+ results = []
469
+
470
+ for query_index, expr, preview in self._iterate_queries(table_filter):
471
+ # Temporarily swap self.expr to analyze this query
472
+ original_expr = self.expr
473
+ self.expr = expr
474
+
475
+ try:
476
+ tables = self._extract_tables_from_query()
477
+
478
+ # Create query result
479
+ results.append(
480
+ QueryTablesResult(
481
+ metadata=QueryMetadata(
482
+ query_index=query_index,
483
+ query_preview=preview,
484
+ ),
485
+ tables=tables,
486
+ )
487
+ )
488
+ finally:
489
+ # Restore original expression
490
+ self.expr = original_expr
491
+
492
+ return results
493
+
494
+ def _extract_tables_from_query(self) -> List[TableInfo]:
495
+ """
496
+ Extract all tables from the current query with usage and type information.
497
+
498
+ Returns:
499
+ List of TableInfo objects for all tables in the query.
500
+ """
501
+ # Track tables by name to consolidate INPUT/OUTPUT into BOTH
502
+ tables_dict: dict[str, TableInfo] = {}
503
+
504
+ # Extract CTEs first (they're INPUT only)
505
+ cte_names = self._extract_cte_names()
506
+ for cte_name in cte_names:
507
+ tables_dict[cte_name] = TableInfo(
508
+ name=cte_name,
509
+ usage=TableUsage.INPUT,
510
+ object_type=ObjectType.CTE,
511
+ )
512
+
513
+ # Determine target table and its type based on statement type
514
+ target_table, target_type = self._get_target_table_info()
515
+
516
+ # Get all table references in the query (except CTEs)
517
+ input_tables = self._get_all_input_tables(cte_names)
518
+
519
+ # Add target table as OUTPUT
520
+ if target_table:
521
+ if target_table in tables_dict:
522
+ # Table is both input and output (e.g., UPDATE with self-reference)
523
+ tables_dict[target_table] = TableInfo(
524
+ name=target_table,
525
+ usage=TableUsage.BOTH,
526
+ object_type=target_type,
527
+ )
528
+ else:
529
+ tables_dict[target_table] = TableInfo(
530
+ name=target_table,
531
+ usage=TableUsage.OUTPUT,
532
+ object_type=target_type,
533
+ )
534
+
535
+ # Add input tables
536
+ for table_name in input_tables:
537
+ if table_name in tables_dict:
538
+ # Already exists - might need to upgrade to BOTH
539
+ existing = tables_dict[table_name]
540
+ if existing.usage == TableUsage.OUTPUT:
541
+ tables_dict[table_name] = TableInfo(
542
+ name=table_name,
543
+ usage=TableUsage.BOTH,
544
+ object_type=existing.object_type,
545
+ )
546
+ # If INPUT or BOTH, keep as-is
547
+ else:
548
+ tables_dict[table_name] = TableInfo(
549
+ name=table_name,
550
+ usage=TableUsage.INPUT,
551
+ object_type=ObjectType.UNKNOWN,
552
+ )
553
+
554
+ # Return sorted list by name for consistent output
555
+ return sorted(tables_dict.values(), key=lambda t: t.name.lower())
556
+
557
+ def _extract_cte_names(self) -> Set[str]:
558
+ """
559
+ Extract all CTE (Common Table Expression) names from the query.
560
+
561
+ Returns:
562
+ Set of CTE names defined in the WITH clause.
563
+ """
564
+ cte_names: Set[str] = set()
565
+
566
+ # Look for WITH clause
567
+ if hasattr(self.expr, "args") and self.expr.args.get("with"):
568
+ with_clause = self.expr.args["with"]
569
+ for cte in with_clause.expressions:
570
+ if isinstance(cte, exp.CTE) and cte.alias:
571
+ cte_names.add(cte.alias)
572
+
573
+ return cte_names
574
+
575
+ def _get_target_table_info(self) -> Tuple[Optional[str], ObjectType]:
576
+ """
577
+ Get the target table name and its object type for DML/DDL statements.
578
+
579
+ Returns:
580
+ Tuple of (target_table_name, object_type) or (None, UNKNOWN) for SELECT.
581
+ """
582
+ # INSERT INTO table
583
+ if isinstance(self.expr, exp.Insert):
584
+ target = self.expr.this
585
+ if isinstance(target, exp.Table):
586
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
587
+
588
+ # CREATE TABLE / CREATE VIEW
589
+ elif isinstance(self.expr, exp.Create):
590
+ kind = getattr(self.expr, "kind", "").upper()
591
+ target = self.expr.this
592
+
593
+ # Handle Schema wrapper (CREATE TABLE with columns)
594
+ if isinstance(target, exp.Schema):
595
+ target = target.this
596
+
597
+ if isinstance(target, exp.Table):
598
+ table_name = self._get_qualified_table_name(target)
599
+ if kind == "VIEW":
600
+ return (table_name, ObjectType.VIEW)
601
+ elif kind == "TABLE":
602
+ return (table_name, ObjectType.TABLE)
603
+ else:
604
+ return (table_name, ObjectType.UNKNOWN)
605
+
606
+ # UPDATE table
607
+ elif isinstance(self.expr, exp.Update):
608
+ target = self.expr.this
609
+ if isinstance(target, exp.Table):
610
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
611
+
612
+ # MERGE INTO table
613
+ elif isinstance(self.expr, exp.Merge):
614
+ target = self.expr.this
615
+ if isinstance(target, exp.Table):
616
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
617
+
618
+ # DELETE FROM table
619
+ elif isinstance(self.expr, exp.Delete):
620
+ target = self.expr.this
621
+ if isinstance(target, exp.Table):
622
+ return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
623
+
624
+ # DROP TABLE / DROP VIEW
625
+ elif isinstance(self.expr, exp.Drop):
626
+ kind = getattr(self.expr, "kind", "").upper()
627
+ target = self.expr.this
628
+ if isinstance(target, exp.Table):
629
+ table_name = self._get_qualified_table_name(target)
630
+ if kind == "VIEW":
631
+ return (table_name, ObjectType.VIEW)
632
+ elif kind == "TABLE":
633
+ return (table_name, ObjectType.TABLE)
634
+ else:
635
+ return (table_name, ObjectType.UNKNOWN)
636
+
637
+ # SELECT (no target table)
638
+ return (None, ObjectType.UNKNOWN)
639
+
640
+ def _get_all_input_tables(self, exclude_ctes: Set[str]) -> Set[str]:
641
+ """
642
+ Get all tables used as input (FROM, JOIN, subqueries, etc.).
643
+
644
+ Args:
645
+ exclude_ctes: Set of CTE names to exclude from results.
646
+
647
+ Returns:
648
+ Set of fully qualified table names that are used as input.
649
+ """
650
+ input_tables: Set[str] = set()
651
+
652
+ # Find all Table nodes in the expression tree
653
+ for table_node in self.expr.find_all(exp.Table):
654
+ table_name = self._get_qualified_table_name(table_node)
655
+
656
+ # Skip CTEs (they're tracked separately)
657
+ if table_name in exclude_ctes:
658
+ continue
659
+
660
+ # Skip the target table for certain statement types
661
+ # (it will be added separately as OUTPUT)
662
+ if self._is_target_table(table_node):
663
+ continue
664
+
665
+ input_tables.add(table_name)
666
+
667
+ return input_tables
668
+
669
+ def _is_target_table(self, table_node: exp.Table) -> bool:
670
+ """
671
+ Check if a table node is the target of a DML/DDL statement.
672
+
673
+ This helps distinguish the target table (OUTPUT) from source tables (INPUT)
674
+ in statements like INSERT, UPDATE, MERGE, DELETE.
675
+
676
+ Args:
677
+ table_node: The table node to check.
678
+
679
+ Returns:
680
+ True if this is the target table, False otherwise.
681
+ """
682
+ # For INSERT, the target is self.expr.this
683
+ if isinstance(self.expr, exp.Insert):
684
+ return table_node is self.expr.this
685
+
686
+ # For UPDATE, the target is self.expr.this
687
+ elif isinstance(self.expr, exp.Update):
688
+ return table_node is self.expr.this
689
+
690
+ # For MERGE, the target is self.expr.this
691
+ elif isinstance(self.expr, exp.Merge):
692
+ return table_node is self.expr.this
693
+
694
+ # For DELETE, the target is self.expr.this
695
+ elif isinstance(self.expr, exp.Delete):
696
+ return table_node is self.expr.this
697
+
698
+ # For CREATE TABLE/VIEW, check if it's in the schema
699
+ elif isinstance(self.expr, exp.Create):
700
+ target = self.expr.this
701
+ if isinstance(target, exp.Schema):
702
+ return table_node is target.this
703
+ return table_node is target
704
+
705
+ # For DROP, the target is self.expr.this
706
+ elif isinstance(self.expr, exp.Drop):
707
+ return table_node is self.expr.this
708
+
709
+ return False
710
+
711
+ def _analyze_column_lineage_internal(
712
+ self, column: Optional[str] = None
713
+ ) -> List[LineageItem]:
714
+ """
715
+ Internal method for analyzing column lineage. Returns flat list of LineageItem.
716
+
717
+ Args:
718
+ column: Optional specific column to analyze. If None, analyzes all columns.
719
+
720
+ Returns:
721
+ List of LineageItem objects (one per output-source relationship)
722
+ """
723
+ output_columns = self.get_output_columns()
724
+
725
+ if column:
726
+ # Analyze only the specified column (case-insensitive matching)
727
+ matched_column = None
728
+ column_lower = column.lower()
729
+ for output_col in output_columns:
730
+ if output_col.lower() == column_lower:
731
+ matched_column = output_col
732
+ break
733
+
734
+ if matched_column is None:
735
+ # Column not found - return empty list (caller will skip this query)
736
+ return []
737
+ columns_to_analyze = [matched_column]
738
+ else:
739
+ # Analyze all columns
740
+ columns_to_analyze = output_columns
741
+
742
+ lineage_items = []
743
+ # Get SQL for current expression only (not full multi-query SQL)
744
+ current_query_sql = self.expr.sql(dialect=self.dialect)
745
+
746
+ for col in columns_to_analyze:
747
+ try:
748
+ # Get the column name that lineage expects
749
+ lineage_col = self._column_mapping.get(col, col)
750
+
751
+ # Get lineage tree for this column using current query SQL only
752
+ # Pass file schema to enable SELECT * expansion for known tables/views
753
+ node = lineage(
754
+ lineage_col,
755
+ current_query_sql,
756
+ dialect=self.dialect,
757
+ schema=self._file_schema if self._file_schema else None,
758
+ )
759
+
760
+ # Collect all source columns
761
+ sources: Set[str] = set()
762
+ self._collect_source_columns(node, sources)
763
+
764
+ # Convert to flat LineageItem list (one item per source)
765
+ for source in sorted(sources):
766
+ lineage_items.append(
767
+ LineageItem(output_name=col, source_name=source)
768
+ )
769
+
770
+ # If no sources found, add single item with empty source
771
+ if not sources:
772
+ lineage_items.append(LineageItem(output_name=col, source_name=""))
773
+ except Exception:
774
+ # If lineage fails for a column, add item with empty source
775
+ lineage_items.append(LineageItem(output_name=col, source_name=""))
776
+
777
+ return lineage_items
778
+
779
+ def _analyze_table_lineage_internal(self) -> List[LineageItem]:
780
+ """
781
+ Internal method for analyzing table lineage. Returns flat list of LineageItem.
782
+
783
+ Returns:
784
+ List of LineageItem objects (one per output-source table relationship)
785
+ """
786
+ source_tables: Set[str] = set()
787
+
788
+ # Find all Table nodes in the AST
789
+ for table_node in self.expr.find_all(exp.Table):
790
+ # Get fully qualified table name
791
+ table_name = table_node.sql(dialect=self.dialect)
792
+ source_tables.add(table_name)
793
+
794
+ # The output table would typically be defined in INSERT/CREATE statements
795
+ # For SELECT statements, we use a placeholder
796
+ output_table = "query_result"
797
+
798
+ # Convert to flat LineageItem list (one item per source table)
799
+ lineage_items = []
800
+ for source in sorted(source_tables):
801
+ lineage_items.append(
802
+ LineageItem(output_name=output_table, source_name=source)
803
+ )
804
+
805
+ return lineage_items
806
+
807
+ def _analyze_reverse_lineage_internal(
808
+ self, source_column: str
809
+ ) -> List[LineageItem]:
810
+ """
811
+ Internal method for analyzing reverse lineage. Returns flat list of LineageItem.
812
+
813
+ Args:
814
+ source_column: Source column to analyze (e.g., "orders.customer_id")
815
+
816
+ Returns:
817
+ List of LineageItem objects (source column -> affected outputs)
818
+ """
819
+ # Step 1: Run forward lineage on all output columns
820
+ forward_items = self._analyze_column_lineage_internal(column=None)
821
+
822
+ # Step 2: Build reverse mapping (source -> [affected outputs])
823
+ reverse_map: dict[str, set[str]] = {}
824
+ all_outputs = set()
825
+
826
+ for item in forward_items:
827
+ all_outputs.add(item.output_name)
828
+ if item.source_name: # Skip empty sources
829
+ if item.source_name not in reverse_map:
830
+ reverse_map[item.source_name] = set()
831
+ reverse_map[item.source_name].add(item.output_name)
832
+
833
+ # Step 3: Find matching source (case-insensitive)
834
+ matched_source = None
835
+ affected_outputs = set()
836
+ source_column_lower = source_column.lower()
837
+
838
+ # First check if it's in reverse_map (derived columns)
839
+ for source in reverse_map.keys():
840
+ if source.lower() == source_column_lower:
841
+ matched_source = source
842
+ affected_outputs = reverse_map[matched_source]
843
+ break
844
+
845
+ # If not found, check if it's an output column (base table column)
846
+ if matched_source is None:
847
+ for output in all_outputs:
848
+ if output.lower() == source_column_lower:
849
+ matched_source = output
850
+ affected_outputs = {output} # It affects itself
851
+ break
852
+
853
+ if matched_source is None:
854
+ # Source column not found - return empty list (caller will skip this query)
855
+ return []
856
+
857
+ # Step 4: Return with semantic swap (source as output, affected as sources)
858
+ # This maintains the LineageItem structure where output_name is what we're looking at
859
+ # and source_name is what it affects
860
+ lineage_items = []
861
+ for affected in sorted(affected_outputs):
862
+ lineage_items.append(
863
+ LineageItem(output_name=matched_source, source_name=affected)
864
+ )
865
+
866
+ return lineage_items
867
+
868
+ def _get_statement_type(self, expr: Optional[exp.Expression] = None) -> str:
869
+ """
870
+ Get a human-readable name for the SQL statement type.
871
+
872
+ Args:
873
+ expr: Expression to check (uses self.expr if not provided)
874
+
875
+ Returns:
876
+ Statement type name (e.g., "CREATE FUNCTION", "SELECT", "DELETE")
877
+ """
878
+ target_expr = expr if expr is not None else self.expr
879
+ expr_type = type(target_expr).__name__
880
+
881
+ # Map common expression types to more readable names
882
+ type_map = {
883
+ "Select": "SELECT",
884
+ "Insert": "INSERT",
885
+ "Update": "UPDATE",
886
+ "Delete": "DELETE",
887
+ "Merge": "MERGE",
888
+ "Create": f"CREATE {getattr(target_expr, 'kind', '')}".strip(),
889
+ "Drop": f"DROP {getattr(target_expr, 'kind', '')}".strip(),
890
+ "Alter": "ALTER",
891
+ "Truncate": "TRUNCATE",
892
+ "Command": "COMMAND",
893
+ }
894
+
895
+ return type_map.get(expr_type, expr_type.upper())
896
+
897
    def _get_target_and_select(
        self,
    ) -> Optional[
        tuple[Optional[str], Union[exp.Select, exp.Union, exp.Intersect, exp.Except]]
    ]:
        """
        Detect if this is a DML/DDL statement and extract the target table and SELECT node.

        Returns:
            Tuple of (target_table_name, select_node) where:
            - target_table_name is the fully qualified target table for DML/DDL, or None for pure SELECT
            - select_node is the SELECT statement that provides the data
            - Returns None if the statement type doesn't contain a SELECT (e.g., CREATE FUNCTION)

        Handles:
            - INSERT INTO table SELECT ...
            - CREATE TABLE table AS SELECT ...
            - MERGE INTO table ...
            - UPDATE table SET ... FROM (SELECT ...)
            - Pure SELECT (returns None as target)
        """
        # Check for INSERT statement
        if isinstance(self.expr, exp.Insert):
            target = self.expr.this
            if isinstance(target, exp.Table):
                target_name = self._get_qualified_table_name(target)
                # Find the SELECT within the INSERT (may be a set operation)
                select_node = self.expr.expression
                if isinstance(
                    select_node, (exp.Select, exp.Union, exp.Intersect, exp.Except)
                ):
                    return (target_name, select_node)
                # Non-SELECT INSERT (e.g. VALUES): fall through to the default
                # handling below.

        # Check for CREATE TABLE AS SELECT (CTAS) or CREATE VIEW AS SELECT
        elif isinstance(self.expr, exp.Create):
            if self.expr.kind in ("TABLE", "VIEW"):
                target = self.expr.this
                if isinstance(target, exp.Schema):
                    # A column list (CREATE TABLE t (a, b) AS ...) wraps the
                    # table in a Schema node; get the table from schema.
                    target = target.this
                if isinstance(target, exp.Table):
                    target_name = self._get_qualified_table_name(target)
                    # Find the SELECT in the expression (may be a set operation)
                    select_node = self.expr.expression
                    if isinstance(
                        select_node, (exp.Select, exp.Union, exp.Intersect, exp.Except)
                    ):
                        return (target_name, select_node)

        # Check for MERGE statement
        elif isinstance(self.expr, exp.Merge):
            target = self.expr.this
            if isinstance(target, exp.Table):
                target_name = self._get_qualified_table_name(target)
                # For MERGE, we need to find the SELECT in the USING clause
                # This is more complex, for now treat it as a SELECT and take
                # the first one found anywhere in the statement.
                select_nodes = list(self.expr.find_all(exp.Select))
                if select_nodes:
                    return (target_name, select_nodes[0])

        # Check for UPDATE with subquery
        elif isinstance(self.expr, exp.Update):
            target = self.expr.this
            if isinstance(target, exp.Table):
                target_name = self._get_qualified_table_name(target)
                # For UPDATE, find the SELECT if there is one
                select_nodes = list(self.expr.find_all(exp.Select))
                if select_nodes:
                    return (target_name, select_nodes[0])

        # Default: Pure SELECT (DQL).  Any statement that fell through the
        # branches above is scanned for its first SELECT node.
        select_nodes = list(self.expr.find_all(exp.Select))
        if select_nodes:
            return (None, select_nodes[0])

        # Fallback: return the expression as-is if it's a SELECT
        # NOTE(review): find_all appears to visit the root node too, which
        # would make this branch unreachable — confirm against sqlglot docs.
        if isinstance(self.expr, exp.Select):
            return (None, self.expr)

        # No SELECT found - return None to indicate unsupported statement
        return None
978
+
979
+ def _get_qualified_table_name(self, table: exp.Table) -> str:
980
+ """
981
+ Get the fully qualified name for a table.
982
+
983
+ Args:
984
+ table: SQLGlot Table expression
985
+
986
+ Returns:
987
+ Fully qualified table name (database.table or catalog.database.table)
988
+ """
989
+ parts = []
990
+ if table.catalog:
991
+ parts.append(table.catalog)
992
+ if table.db:
993
+ parts.append(table.db)
994
+ parts.append(table.name)
995
+ return ".".join(parts)
996
+
997
+ def _resolve_table_reference(self, ref: str, select_node: exp.Select) -> str:
998
+ """
999
+ Resolve a table reference (alias, CTE name, or actual table) to its canonical name.
1000
+
1001
+ This works at any nesting level by only looking at the immediate SELECT context.
1002
+ For CTEs and subqueries, returns their alias name (which is the "table name" in that context).
1003
+ For actual tables with aliases, returns the actual table name.
1004
+
1005
+ Args:
1006
+ ref: The table reference to resolve (could be alias, CTE name, or table name)
1007
+ select_node: The SELECT node containing the FROM/JOIN clauses
1008
+
1009
+ Returns:
1010
+ The canonical table name (actual table for real tables, alias for CTEs/subqueries)
1011
+ """
1012
+ # Check if this is a CTE reference first
1013
+ # CTEs are defined in the WITH clause and referenced by their alias
1014
+ parent = select_node
1015
+ while parent:
1016
+ if isinstance(parent, (exp.Select, exp.Union)) and parent.args.get("with"):
1017
+ cte_node = parent.args["with"]
1018
+ for cte in cte_node.expressions:
1019
+ if isinstance(cte, exp.CTE) and cte.alias == ref:
1020
+ # This is a CTE - return the CTE alias as the "table name"
1021
+ return ref
1022
+ parent = parent.parent if hasattr(parent, "parent") else None
1023
+
1024
+ # Look for table references in FROM and JOIN clauses
1025
+ for table_ref in select_node.find_all(exp.Table):
1026
+ # Check if this table has the matching alias
1027
+ if table_ref.alias == ref:
1028
+ # Return the qualified table name
1029
+ parts = []
1030
+ if table_ref.db:
1031
+ parts.append(table_ref.db)
1032
+ if table_ref.catalog:
1033
+ parts.insert(0, table_ref.catalog)
1034
+ parts.append(table_ref.name)
1035
+ return ".".join(parts)
1036
+ # Also check if ref matches the table name directly (no alias case)
1037
+ elif table_ref.name == ref and not table_ref.alias:
1038
+ parts = []
1039
+ if table_ref.db:
1040
+ parts.append(table_ref.db)
1041
+ if table_ref.catalog:
1042
+ parts.insert(0, table_ref.catalog)
1043
+ parts.append(table_ref.name)
1044
+ return ".".join(parts)
1045
+
1046
+ # Check for subquery aliases in FROM clause
1047
+ if select_node.args.get("from"):
1048
+ from_clause = select_node.args["from"]
1049
+ if isinstance(from_clause, exp.From):
1050
+ source = from_clause.this
1051
+ # Check if it's a subquery with matching alias
1052
+ if isinstance(source, exp.Subquery) and source.alias == ref:
1053
+ # Return the subquery alias as the "table name"
1054
+ return ref
1055
+ # Check if it's a table with matching alias
1056
+ elif isinstance(source, exp.Table) and source.alias == ref:
1057
+ parts = []
1058
+ if source.db:
1059
+ parts.append(source.db)
1060
+ if source.catalog:
1061
+ parts.insert(0, source.catalog)
1062
+ parts.append(source.name)
1063
+ return ".".join(parts)
1064
+
1065
+ # Check JOIN clauses for subqueries
1066
+ for join in select_node.find_all(exp.Join):
1067
+ if isinstance(join.this, exp.Subquery) and join.this.alias == ref:
1068
+ return ref
1069
+ elif isinstance(join.this, exp.Table) and join.this.alias == ref:
1070
+ parts = []
1071
+ if join.this.db:
1072
+ parts.append(join.this.db)
1073
+ if join.this.catalog:
1074
+ parts.insert(0, join.this.catalog)
1075
+ parts.append(join.this.name)
1076
+ return ".".join(parts)
1077
+
1078
+ # If we can't resolve, return the reference as-is
1079
+ return ref
1080
+
1081
+ def _infer_single_table_source(self, select_node: exp.Select) -> Optional[str]:
1082
+ """
1083
+ Infer the table name when there's only one table in FROM clause.
1084
+
1085
+ This handles cases like "SELECT col FROM table" where col has no table prefix.
1086
+
1087
+ Args:
1088
+ select_node: The SELECT node
1089
+
1090
+ Returns:
1091
+ The table name if there's exactly one source, None otherwise
1092
+ """
1093
+ if not select_node.args.get("from"):
1094
+ return None
1095
+
1096
+ from_clause = select_node.args["from"]
1097
+ if not isinstance(from_clause, exp.From):
1098
+ return None
1099
+
1100
+ source = from_clause.this
1101
+
1102
+ # Check for JOINs - if there are joins, we can't infer
1103
+ if list(select_node.find_all(exp.Join)):
1104
+ return None
1105
+
1106
+ # Single table or CTE/subquery
1107
+ if isinstance(source, exp.Table):
1108
+ parts = []
1109
+ if source.db:
1110
+ parts.append(source.db)
1111
+ if source.catalog:
1112
+ parts.insert(0, source.catalog)
1113
+ if source.alias:
1114
+ # If the table has an alias, use the alias
1115
+ return source.alias
1116
+ parts.append(source.name)
1117
+ return ".".join(parts)
1118
+ elif isinstance(source, (exp.Subquery, exp.CTE)):
1119
+ # Return the subquery/CTE alias
1120
+ return source.alias if source.alias else None
1121
+
1122
+ return None
1123
+
1124
+ def _collect_source_columns(self, node: Node, sources: Set[str]) -> None:
1125
+ """
1126
+ Recursively collect all source columns from a lineage tree.
1127
+
1128
+ This traverses the lineage tree depth-first, collecting leaf nodes
1129
+ which represent the actual source columns.
1130
+
1131
+ Args:
1132
+ node: The current lineage node
1133
+ sources: Set to accumulate source column names
1134
+ """
1135
+ if not node.downstream:
1136
+ # Leaf node - this is a source column
1137
+ # Check if this is a literal value (SQLGlot uses position numbers for literals)
1138
+ if node.name.isdigit():
1139
+ # This is a literal - extract the actual value from the expression
1140
+ literal_repr = self._extract_literal_representation(node)
1141
+ sources.add(literal_repr)
1142
+ else:
1143
+ # SQLGlot's lineage provides qualified names, but may use aliases
1144
+ # Need to resolve aliases to actual table names
1145
+ qualified_name = self._resolve_source_column_alias(node.name)
1146
+ sources.add(qualified_name)
1147
+ else:
1148
+ # Traverse deeper into the tree
1149
+ for child in node.downstream:
1150
+ self._collect_source_columns(child, sources)
1151
+
1152
+ def _extract_literal_representation(self, node: Node) -> str:
1153
+ """
1154
+ Extract a human-readable representation of a literal value from a lineage node.
1155
+
1156
+ When SQLGlot encounters a literal value in a UNION branch, it returns the
1157
+ column position as the node name. This method extracts the actual literal
1158
+ value from the node's expression.
1159
+
1160
+ Args:
1161
+ node: A lineage node where node.name is a digit (position number)
1162
+
1163
+ Returns:
1164
+ A string like "<literal: NULL>" or "<literal: 'value'>" or "<literal: 0>"
1165
+ """
1166
+ try:
1167
+ expr = node.expression
1168
+ # The expression is typically an Alias wrapping the actual value
1169
+ if isinstance(expr, exp.Alias):
1170
+ literal_expr = expr.this
1171
+ literal_sql = literal_expr.sql(dialect=self.dialect)
1172
+ return f"<literal: {literal_sql}>"
1173
+ else:
1174
+ # Fallback: use the expression's SQL representation
1175
+ return f"<literal: {expr.sql(dialect=self.dialect)}>"
1176
+ except Exception:
1177
+ # If extraction fails, return a generic literal marker
1178
+ return "<literal>"
1179
+
1180
+ def _get_query_tables(self) -> List[str]:
1181
+ """
1182
+ Get all table names referenced in the current query.
1183
+
1184
+ Returns:
1185
+ List of fully qualified table names used in the query
1186
+ """
1187
+ tables = []
1188
+ for table_node in self.expr.find_all(exp.Table):
1189
+ table_name = self._get_qualified_table_name(table_node)
1190
+ tables.append(table_name)
1191
+ return tables
1192
+
1193
+ def _resolve_source_column_alias(self, column_name: str) -> str:
1194
+ """
1195
+ Resolve table aliases in source column names.
1196
+
1197
+ This searches through ALL SELECT nodes in the query (including nested ones)
1198
+ to find and resolve table aliases, CTEs, and subqueries.
1199
+
1200
+ Args:
1201
+ column_name: Column name like "alias.column" or "table.column"
1202
+
1203
+ Returns:
1204
+ Fully qualified column name with actual table name
1205
+ """
1206
+ # Parse the column name (format: table.column or db.table.column)
1207
+ parts = column_name.split(".")
1208
+
1209
+ if len(parts) < 2:
1210
+ # No table qualifier, return as-is
1211
+ return column_name
1212
+
1213
+ # The table part might be an alias, CTE name, or actual table
1214
+ table_part = parts[0] if len(parts) == 2 else parts[-2]
1215
+ column_part = parts[-1]
1216
+
1217
+ # Try to resolve by searching through ALL SELECT nodes (including nested)
1218
+ # This handles cases where the alias is defined deep in a subquery/CTE
1219
+ for select_node in self.expr.find_all(exp.Select):
1220
+ resolved = self._resolve_table_reference(table_part, select_node)
1221
+ # If resolution changed the name, we found it
1222
+ if resolved != table_part:
1223
+ # Reconstruct with resolved table name
1224
+ if len(parts) == 2:
1225
+ return f"{resolved}.{column_part}"
1226
+ else:
1227
+ # Has database part
1228
+ return f"{parts[0]}.{resolved}.{column_part}"
1229
+
1230
+ # If we couldn't resolve in any SELECT, return as-is
1231
+ return column_name
1232
+
1233
+ def _generate_query_preview(self, expr: exp.Expression) -> str:
1234
+ """
1235
+ Generate a preview string for a query (first 100 chars, normalized).
1236
+
1237
+ Args:
1238
+ expr: The SQL expression to generate a preview for
1239
+
1240
+ Returns:
1241
+ Preview string (first 100 chars with "..." if truncated)
1242
+ """
1243
+ query_text = expr.sql(dialect=self.dialect)
1244
+ preview = " ".join(query_text.split())[:100]
1245
+ if len(" ".join(query_text.split())) > 100:
1246
+ preview += "..."
1247
+ return preview
1248
+
1249
+ def _filter_by_table(self, expr: exp.Expression, table_filter: str) -> bool:
1250
+ """
1251
+ Check if a query references a specific table.
1252
+
1253
+ Args:
1254
+ expr: The SQL expression to check
1255
+ table_filter: Table name to filter by (case-insensitive partial match)
1256
+
1257
+ Returns:
1258
+ True if the query references the table, False otherwise
1259
+ """
1260
+ # Temporarily swap self.expr to analyze this expression
1261
+ original_expr = self.expr
1262
+ self.expr = expr
1263
+ try:
1264
+ query_tables = self._get_query_tables()
1265
+ table_filter_lower = table_filter.lower()
1266
+ return any(table_filter_lower in table.lower() for table in query_tables)
1267
+ finally:
1268
+ self.expr = original_expr
1269
+
1270
+ def _iterate_queries(
1271
+ self, table_filter: Optional[str] = None
1272
+ ) -> Iterator[Tuple[int, exp.Expression, str]]:
1273
+ """
1274
+ Iterate over queries with filtering and preview generation.
1275
+
1276
+ Args:
1277
+ table_filter: Optional table name to filter queries by
1278
+
1279
+ Yields:
1280
+ Tuple of (query_index, expression, query_preview)
1281
+ """
1282
+ for idx, expr in enumerate(self.expressions):
1283
+ # Apply table filter
1284
+ if table_filter and not self._filter_by_table(expr, table_filter):
1285
+ continue
1286
+
1287
+ # Generate preview
1288
+ preview = self._generate_query_preview(expr)
1289
+
1290
+ yield idx, expr, preview
1291
+
1292
+ # -------------------------------------------------------------------------
1293
+ # File-scoped schema context methods
1294
+ # -------------------------------------------------------------------------
1295
+
1296
+ def _extract_schema_from_statement(self, expr: exp.Expression) -> None:
1297
+ """
1298
+ Extract column definitions from CREATE VIEW/TABLE AS SELECT statements.
1299
+
1300
+ This method builds up file-scoped schema context as statements are processed,
1301
+ enabling SQLGlot to correctly expand SELECT * and trace cross-statement references.
1302
+
1303
+ Args:
1304
+ expr: The SQL expression to extract schema from
1305
+ """
1306
+ # Only handle CREATE VIEW or CREATE TABLE (AS SELECT)
1307
+ if not isinstance(expr, exp.Create):
1308
+ return
1309
+ if expr.kind not in ("VIEW", "TABLE"):
1310
+ return
1311
+
1312
+ # Get target table/view name
1313
+ target = expr.this
1314
+ if isinstance(target, exp.Schema):
1315
+ target = target.this
1316
+ if not isinstance(target, exp.Table):
1317
+ return
1318
+
1319
+ target_name = self._get_qualified_table_name(target)
1320
+
1321
+ # Get the SELECT node from the CREATE statement
1322
+ select_node = expr.expression
1323
+ if select_node is None:
1324
+ return
1325
+
1326
+ # Handle Subquery wrapper (e.g., CREATE VIEW AS (SELECT ...))
1327
+ if isinstance(select_node, exp.Subquery):
1328
+ select_node = select_node.this
1329
+
1330
+ if not isinstance(
1331
+ select_node, (exp.Select, exp.Union, exp.Intersect, exp.Except)
1332
+ ):
1333
+ return
1334
+
1335
+ # Extract column names from the SELECT
1336
+ columns = self._extract_columns_from_select(select_node)
1337
+
1338
+ if columns:
1339
+ # Store with UNKNOWN type - SQLGlot only needs column names for expansion
1340
+ self._file_schema[target_name] = {col: "UNKNOWN" for col in columns}
1341
+
1342
+ def _extract_columns_from_select(
1343
+ self, select_node: Union[exp.Select, exp.Union, exp.Intersect, exp.Except]
1344
+ ) -> List[str]:
1345
+ """
1346
+ Extract column names from a SELECT statement.
1347
+
1348
+ Handles aliases, direct column references, and SELECT * by resolving
1349
+ against the known file schema.
1350
+
1351
+ Args:
1352
+ select_node: The SELECT or set operation expression
1353
+
1354
+ Returns:
1355
+ List of column names
1356
+ """
1357
+ columns: List[str] = []
1358
+
1359
+ # Get projections (for UNION, use first branch)
1360
+ projections = self._get_select_projections(select_node)
1361
+ first_select = self._get_first_select(select_node)
1362
+
1363
+ for projection in projections:
1364
+ if isinstance(projection, exp.Alias):
1365
+ # Use the alias name as the column name
1366
+ columns.append(projection.alias)
1367
+ elif isinstance(projection, exp.Column):
1368
+ # Check if this is a table-qualified star (e.g., t.*)
1369
+ if isinstance(projection.this, exp.Star):
1370
+ # Resolve table-qualified star from known schema
1371
+ table_name = projection.table
1372
+ if table_name and first_select:
1373
+ qualified_star_cols = self._resolve_qualified_star(
1374
+ table_name, first_select
1375
+ )
1376
+ columns.extend(qualified_star_cols)
1377
+ else:
1378
+ # Use the column name
1379
+ columns.append(projection.name)
1380
+ elif isinstance(projection, exp.Star):
1381
+ # Resolve SELECT * from known schema
1382
+ if first_select:
1383
+ star_columns = self._resolve_star_columns(first_select)
1384
+ columns.extend(star_columns)
1385
+ else:
1386
+ # For expressions without alias, use SQL representation
1387
+ col_sql = projection.sql(dialect=self.dialect)
1388
+ columns.append(col_sql)
1389
+
1390
+ return columns
1391
+
1392
+ def _resolve_star_columns(self, select_node: exp.Select) -> List[str]:
1393
+ """
1394
+ Resolve SELECT * to actual column names from known file schema or CTEs.
1395
+
1396
+ Args:
1397
+ select_node: The SELECT node containing the * reference
1398
+
1399
+ Returns:
1400
+ List of column names if source is known, empty list otherwise
1401
+ """
1402
+ columns: List[str] = []
1403
+
1404
+ # Get the source table(s) from FROM clause
1405
+ from_clause = select_node.args.get("from")
1406
+ if not from_clause or not isinstance(from_clause, exp.From):
1407
+ return columns
1408
+
1409
+ source = from_clause.this
1410
+
1411
+ # Handle table reference from FROM clause
1412
+ columns.extend(self._resolve_source_columns(source, select_node))
1413
+
1414
+ # Handle JOIN clauses - collect columns from all joined tables
1415
+ # EXCEPT for SEMI and ANTI joins which only return left table columns
1416
+ joins = select_node.args.get("joins")
1417
+ if joins:
1418
+ for join in joins:
1419
+ if isinstance(join, exp.Join):
1420
+ # SEMI and ANTI joins don't include right table columns in SELECT *
1421
+ join_kind = join.kind
1422
+ if join_kind in ("SEMI", "ANTI"):
1423
+ # Skip right table columns for SEMI/ANTI joins
1424
+ continue
1425
+ join_source = join.this
1426
+ columns.extend(
1427
+ self._resolve_source_columns(join_source, select_node)
1428
+ )
1429
+
1430
+ # Handle LATERAL VIEW clauses - collect generated columns
1431
+ laterals = select_node.args.get("laterals")
1432
+ if laterals:
1433
+ for lateral in laterals:
1434
+ if isinstance(lateral, exp.Lateral):
1435
+ lateral_cols = self._resolve_lateral_columns(lateral)
1436
+ columns.extend(lateral_cols)
1437
+
1438
+ return columns
1439
+
1440
+ def _resolve_lateral_columns(self, lateral: exp.Lateral) -> List[str]:
1441
+ """
1442
+ Extract generated column names from a LATERAL VIEW clause.
1443
+
1444
+ Args:
1445
+ lateral: The Lateral expression node
1446
+
1447
+ Returns:
1448
+ List of generated column names (e.g., ['elem'] for explode,
1449
+ ['pos', 'elem'] for posexplode)
1450
+ """
1451
+ # Use SQLGlot's built-in property to get alias column names
1452
+ return lateral.alias_column_names or []
1453
+
1454
+ def _resolve_source_columns(
1455
+ self, source: exp.Expression, select_node: exp.Select
1456
+ ) -> List[str]:
1457
+ """
1458
+ Resolve columns from a single source (table, subquery, etc.).
1459
+
1460
+ Args:
1461
+ source: The source expression (Table, Subquery, etc.)
1462
+ select_node: The containing SELECT node for CTE resolution
1463
+
1464
+ Returns:
1465
+ List of column names from the source
1466
+ """
1467
+ columns: List[str] = []
1468
+
1469
+ # Handle table reference
1470
+ if isinstance(source, exp.Table):
1471
+ source_name = self._get_qualified_table_name(source)
1472
+
1473
+ # First check file schema (views/tables from previous statements)
1474
+ if source_name in self._file_schema:
1475
+ columns.extend(self._file_schema[source_name].keys())
1476
+ else:
1477
+ # Check if this is a CTE reference within the same statement
1478
+ cte_columns = self._resolve_cte_columns(source_name, select_node)
1479
+ columns.extend(cte_columns)
1480
+
1481
+ # Handle subquery with alias
1482
+ elif isinstance(source, exp.Subquery):
1483
+ # First check if this subquery alias is in file schema
1484
+ if source.alias and source.alias in self._file_schema:
1485
+ columns.extend(self._file_schema[source.alias].keys())
1486
+ else:
1487
+ # Extract columns from the subquery's SELECT
1488
+ inner_select = source.this
1489
+ if isinstance(inner_select, exp.Select):
1490
+ subquery_cols = self._extract_subquery_columns(inner_select)
1491
+ columns.extend(subquery_cols)
1492
+
1493
+ return columns
1494
+
1495
+ def _resolve_qualified_star(
1496
+ self, table_name: str, select_node: exp.Select
1497
+ ) -> List[str]:
1498
+ """
1499
+ Resolve a table-qualified star (e.g., t.*) to actual column names.
1500
+
1501
+ Args:
1502
+ table_name: The table/alias name qualifying the star
1503
+ select_node: The SELECT node for context
1504
+
1505
+ Returns:
1506
+ List of column names from the specified table
1507
+ """
1508
+ # First check file schema
1509
+ if table_name in self._file_schema:
1510
+ return list(self._file_schema[table_name].keys())
1511
+
1512
+ # Check if it's a CTE reference
1513
+ cte_columns = self._resolve_cte_columns(table_name, select_node)
1514
+ if cte_columns:
1515
+ return cte_columns
1516
+
1517
+ # Check if the table name is an alias - need to resolve the actual table
1518
+ from_clause = select_node.args.get("from")
1519
+ if from_clause and isinstance(from_clause, exp.From):
1520
+ source = from_clause.this
1521
+ if isinstance(source, exp.Table) and source.alias == table_name:
1522
+ actual_name = self._get_qualified_table_name(source)
1523
+ if actual_name in self._file_schema:
1524
+ return list(self._file_schema[actual_name].keys())
1525
+
1526
+ # Check JOIN clauses for aliased tables
1527
+ joins = select_node.args.get("joins")
1528
+ if joins:
1529
+ for join in joins:
1530
+ if isinstance(join, exp.Join):
1531
+ join_source = join.this
1532
+ if (
1533
+ isinstance(join_source, exp.Table)
1534
+ and join_source.alias == table_name
1535
+ ):
1536
+ actual_name = self._get_qualified_table_name(join_source)
1537
+ if actual_name in self._file_schema:
1538
+ return list(self._file_schema[actual_name].keys())
1539
+
1540
+ return []
1541
+
1542
+ def _extract_subquery_columns(self, subquery_select: exp.Select) -> List[str]:
1543
+ """
1544
+ Extract column names from a subquery's SELECT statement.
1545
+
1546
+ Args:
1547
+ subquery_select: The SELECT expression within the subquery
1548
+
1549
+ Returns:
1550
+ List of column names
1551
+ """
1552
+ columns: List[str] = []
1553
+
1554
+ for projection in subquery_select.expressions:
1555
+ if isinstance(projection, exp.Alias):
1556
+ columns.append(projection.alias)
1557
+ elif isinstance(projection, exp.Column):
1558
+ # Check for table-qualified star (t.*)
1559
+ if isinstance(projection.this, exp.Star):
1560
+ table_name = projection.table
1561
+ if table_name:
1562
+ qualified_cols = self._resolve_qualified_star(
1563
+ table_name, subquery_select
1564
+ )
1565
+ columns.extend(qualified_cols)
1566
+ else:
1567
+ columns.append(projection.name)
1568
+ elif isinstance(projection, exp.Star):
1569
+ # Resolve SELECT * in subquery
1570
+ star_columns = self._resolve_star_columns(subquery_select)
1571
+ columns.extend(star_columns)
1572
+ else:
1573
+ col_sql = projection.sql(dialect=self.dialect)
1574
+ columns.append(col_sql)
1575
+
1576
+ return columns
1577
+
1578
+ def _resolve_cte_columns(self, cte_name: str, select_node: exp.Select) -> List[str]:
1579
+ """
1580
+ Resolve columns from a CTE definition within the same statement.
1581
+
1582
+ Args:
1583
+ cte_name: Name of the CTE to resolve
1584
+ select_node: The SELECT node that references the CTE
1585
+
1586
+ Returns:
1587
+ List of column names from the CTE, empty if CTE not found
1588
+ """
1589
+ # Walk up the tree to find the WITH clause containing this CTE
1590
+ parent = select_node
1591
+ while parent:
1592
+ if hasattr(parent, "args") and parent.args.get("with"):
1593
+ with_clause = parent.args["with"]
1594
+ for cte in with_clause.expressions:
1595
+ if isinstance(cte, exp.CTE) and cte.alias == cte_name:
1596
+ # Found the CTE - extract its columns
1597
+ cte_select = cte.this
1598
+ if isinstance(cte_select, exp.Select):
1599
+ return self._extract_cte_select_columns(cte_select)
1600
+ parent = parent.parent if hasattr(parent, "parent") else None
1601
+
1602
+ return []
1603
+
1604
+ def _extract_cte_select_columns(self, cte_select: exp.Select) -> List[str]:
1605
+ """
1606
+ Extract column names from a CTE's SELECT statement.
1607
+
1608
+ This handles SELECT * within the CTE by resolving against file schema.
1609
+
1610
+ Args:
1611
+ cte_select: The SELECT expression within the CTE
1612
+
1613
+ Returns:
1614
+ List of column names
1615
+ """
1616
+ columns: List[str] = []
1617
+
1618
+ for projection in cte_select.expressions:
1619
+ if isinstance(projection, exp.Alias):
1620
+ columns.append(projection.alias)
1621
+ elif isinstance(projection, exp.Column):
1622
+ columns.append(projection.name)
1623
+ elif isinstance(projection, exp.Star):
1624
+ # Resolve SELECT * in CTE from file schema
1625
+ star_columns = self._resolve_star_columns(cte_select)
1626
+ columns.extend(star_columns)
1627
+ else:
1628
+ col_sql = projection.sql(dialect=self.dialect)
1629
+ columns.append(col_sql)
1630
+
1631
+ return columns