sql-glider 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_glider-0.1.8.dist-info/METADATA +893 -0
- sql_glider-0.1.8.dist-info/RECORD +34 -0
- sql_glider-0.1.8.dist-info/WHEEL +4 -0
- sql_glider-0.1.8.dist-info/entry_points.txt +9 -0
- sql_glider-0.1.8.dist-info/licenses/LICENSE +201 -0
- sqlglider/__init__.py +3 -0
- sqlglider/_version.py +34 -0
- sqlglider/catalog/__init__.py +30 -0
- sqlglider/catalog/base.py +99 -0
- sqlglider/catalog/databricks.py +255 -0
- sqlglider/catalog/registry.py +121 -0
- sqlglider/cli.py +1589 -0
- sqlglider/dissection/__init__.py +17 -0
- sqlglider/dissection/analyzer.py +767 -0
- sqlglider/dissection/formatters.py +222 -0
- sqlglider/dissection/models.py +112 -0
- sqlglider/global_models.py +17 -0
- sqlglider/graph/__init__.py +42 -0
- sqlglider/graph/builder.py +349 -0
- sqlglider/graph/merge.py +136 -0
- sqlglider/graph/models.py +289 -0
- sqlglider/graph/query.py +287 -0
- sqlglider/graph/serialization.py +107 -0
- sqlglider/lineage/__init__.py +10 -0
- sqlglider/lineage/analyzer.py +1631 -0
- sqlglider/lineage/formatters.py +335 -0
- sqlglider/templating/__init__.py +51 -0
- sqlglider/templating/base.py +103 -0
- sqlglider/templating/jinja.py +163 -0
- sqlglider/templating/registry.py +124 -0
- sqlglider/templating/variables.py +295 -0
- sqlglider/utils/__init__.py +11 -0
- sqlglider/utils/config.py +155 -0
- sqlglider/utils/file_utils.py +38 -0
|
@@ -0,0 +1,1631 @@
|
|
|
1
|
+
"""Core lineage analysis using SQLGlot."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Callable, Dict, Iterator, List, Optional, Set, Tuple, Union
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
from sqlglot import exp, parse
|
|
8
|
+
from sqlglot.errors import ParseError
|
|
9
|
+
from sqlglot.lineage import Node, lineage
|
|
10
|
+
|
|
11
|
+
from sqlglider.global_models import AnalysisLevel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TableUsage(str, Enum):
    """Role a table plays within a single SQL statement.

    Inherits from ``str`` so members compare and serialize as their
    plain string values.
    """

    INPUT = "INPUT"    # read from (FROM/JOIN/subquery)
    OUTPUT = "OUTPUT"  # written to (INSERT/CREATE/UPDATE target)
    BOTH = "BOTH"      # read from and written to in the same statement
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ObjectType(str, Enum):
    """Kind of database object a table reference resolves to.

    String-valued so members serialize directly in JSON/pydantic output.
    """

    TABLE = "TABLE"      # physical table
    VIEW = "VIEW"        # view definition
    CTE = "CTE"          # common table expression from a WITH clause
    UNKNOWN = "UNKNOWN"  # could not be determined from the statement alone
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TableInfo(BaseModel):
    """Information about a table referenced in a query."""

    # Qualified name as rendered from the AST (e.g. "db.schema.table");
    # qualification depth depends on how the table was written in the SQL.
    name: str = Field(..., description="Fully qualified table name")
    # Whether the statement reads from, writes to, or does both with this table.
    usage: TableUsage = Field(
        ..., description="How the table is used (INPUT, OUTPUT, BOTH)"
    )
    # UNKNOWN when the statement does not reveal the object kind
    # (e.g. a plain FROM reference could be a table or a view).
    object_type: ObjectType = Field(
        ..., description="Type of object (TABLE, VIEW, CTE, UNKNOWN)"
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class QueryTablesResult(BaseModel):
    """Result of table analysis for a single query."""

    # Forward reference: QueryMetadata is declared later in this module.
    metadata: "QueryMetadata"
    # All tables referenced by the query, one TableInfo per distinct name.
    tables: List[TableInfo] = Field(default_factory=list)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class LineageItem(BaseModel):
    """Represents a single lineage relationship (output -> source)."""

    # For column-level lineage these are qualified column names; for
    # table-level lineage they are table names.
    output_name: str = Field(..., description="Output column/table name")
    # Empty string when no source could be resolved for the output.
    source_name: str = Field(..., description="Source column/table name")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class QueryMetadata(BaseModel):
    """Query execution context."""

    # Position of the statement within the original multi-statement SQL string.
    query_index: int = Field(..., description="0-based query index")
    query_preview: str = Field(..., description="First 100 chars of query")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class QueryLineageResult(BaseModel):
    """Complete lineage result for a single query."""

    metadata: QueryMetadata
    # Flat list: one item per (output, source) pair discovered for the query.
    lineage_items: List[LineageItem] = Field(default_factory=list)
    # Granularity the analysis was run at (column vs table).
    level: AnalysisLevel
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class SkippedQuery(BaseModel):
    """Information about a query that was skipped during analysis."""

    query_index: int = Field(..., description="0-based query index")
    # Statement kind as reported by the analyzer (e.g., CREATE, DROP).
    statement_type: str = Field(..., description="Type of SQL statement (e.g., CREATE)")
    # Human-readable explanation, typically the str() of the raised error.
    reason: str = Field(..., description="Reason for skipping")
    query_preview: str = Field(..., description="First 100 chars of query")
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# Type alias for warning callback function.
# Receives a single human-readable warning message; return value is ignored.
WarningCallback = Callable[[str], None]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class LineageAnalyzer:
|
|
86
|
+
"""Analyze column and table lineage for SQL queries."""
|
|
87
|
+
|
|
88
|
+
def __init__(self, sql: str, dialect: str = "spark"):
|
|
89
|
+
"""
|
|
90
|
+
Initialize the lineage analyzer.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
sql: SQL query string to analyze (can contain multiple statements)
|
|
94
|
+
dialect: SQL dialect (default: spark)
|
|
95
|
+
|
|
96
|
+
Raises:
|
|
97
|
+
ParseError: If the SQL cannot be parsed
|
|
98
|
+
"""
|
|
99
|
+
self.sql = sql
|
|
100
|
+
self.dialect = dialect
|
|
101
|
+
self._skipped_queries: List[SkippedQuery] = []
|
|
102
|
+
# File-scoped schema context for cross-statement lineage
|
|
103
|
+
# Maps table/view names to their column definitions
|
|
104
|
+
self._file_schema: Dict[str, Dict[str, str]] = {}
|
|
105
|
+
|
|
106
|
+
try:
|
|
107
|
+
# Parse all statements in the SQL string
|
|
108
|
+
parsed = parse(sql, dialect=dialect)
|
|
109
|
+
|
|
110
|
+
# Filter out None values (can happen with empty statements or comments)
|
|
111
|
+
self.expressions: List[exp.Expression] = [
|
|
112
|
+
expr for expr in parsed if expr is not None
|
|
113
|
+
]
|
|
114
|
+
|
|
115
|
+
if not self.expressions:
|
|
116
|
+
raise ParseError("No valid SQL statements found")
|
|
117
|
+
|
|
118
|
+
# For backward compatibility, store first expression as self.expr
|
|
119
|
+
self.expr = self.expressions[0]
|
|
120
|
+
|
|
121
|
+
except ParseError as e:
|
|
122
|
+
raise ParseError(f"Invalid SQL syntax: {e}") from e
|
|
123
|
+
|
|
124
|
+
@property
|
|
125
|
+
def skipped_queries(self) -> List[SkippedQuery]:
|
|
126
|
+
"""Get list of queries that were skipped during analysis."""
|
|
127
|
+
return self._skipped_queries.copy()
|
|
128
|
+
|
|
129
|
+
    def get_output_columns(self) -> List[str]:
        """
        Extract all output column names from the query with full qualification.

        For DML/DDL statements (INSERT, UPDATE, MERGE, CREATE TABLE AS, etc.),
        returns the target table columns. For DQL (SELECT), returns the selected columns.

        Side effect: rebuilds ``self._column_mapping``, which maps each returned
        qualified name to the bare column name that sqlglot's lineage() expects.
        ``_analyze_column_lineage_internal`` relies on this mapping, so it must
        call this method first.

        Returns:
            List of fully qualified output column names (table.column or database.table.column)

        Raises:
            ValueError: If the statement type is not supported for lineage analysis
        """
        columns: List[str] = []

        # Build mapping for qualified names
        self._column_mapping = {}  # Maps qualified name -> lineage column name

        # Check if this is a DML/DDL statement.
        # NOTE(review): _get_target_and_select is defined elsewhere in this
        # class; presumably returns (target_table_name_or_None, select_node)
        # or None for unsupported statements — verify against its definition.
        result = self._get_target_and_select()
        if result is None:
            # Unsupported statement type
            stmt_type = self._get_statement_type()
            raise ValueError(
                f"Statement type '{stmt_type}' does not support lineage analysis"
            )

        target_table, select_node = result

        if target_table:
            # DML/DDL: Use target table for output column qualification
            # The columns are from the SELECT, but qualified with the target table
            projections = self._get_select_projections(select_node)
            first_select = self._get_first_select(select_node)

            for projection in projections:
                # Handle SELECT * by resolving from file schema
                if isinstance(projection, exp.Star):
                    if first_select:
                        star_columns = self._resolve_star_columns(first_select)
                        for star_col in star_columns:
                            qualified_name = f"{target_table}.{star_col}"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = star_col
                    if not columns:
                        # Fallback: can't resolve *, use * as column name
                        qualified_name = f"{target_table}.*"
                        columns.append(qualified_name)
                        self._column_mapping[qualified_name] = "*"
                    continue

                # Get the underlying expression (unwrap alias if present)
                if isinstance(projection, exp.Alias):
                    # For aliased columns, use the alias as the column name
                    column_name = projection.alias
                    lineage_name = column_name  # SQLGlot lineage uses the alias
                    # Qualify with target table
                    qualified_name = f"{target_table}.{column_name}"
                    columns.append(qualified_name)
                    self._column_mapping[qualified_name] = lineage_name
                elif isinstance(projection, exp.Column):
                    # Check if this is a table-qualified star (e.g., t.*)
                    if isinstance(projection.this, exp.Star):
                        source_table = projection.table
                        qualified_star_cols: List[str] = []
                        if source_table and first_select:
                            qualified_star_cols = self._resolve_qualified_star(
                                source_table, first_select
                            )
                            for col in qualified_star_cols:
                                qualified_name = f"{target_table}.{col}"
                                columns.append(qualified_name)
                                self._column_mapping[qualified_name] = col
                        if not qualified_star_cols:
                            # Fallback: can't resolve t.*, use * as column name
                            qualified_name = f"{target_table}.*"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = "*"
                    else:
                        column_name = projection.name
                        lineage_name = column_name
                        # Qualify with target table
                        qualified_name = f"{target_table}.{column_name}"
                        columns.append(qualified_name)
                        self._column_mapping[qualified_name] = lineage_name
                else:
                    # For expressions, use the SQL representation
                    column_name = projection.sql(dialect=self.dialect)
                    lineage_name = column_name
                    # Qualify with target table
                    qualified_name = f"{target_table}.{column_name}"
                    columns.append(qualified_name)
                    self._column_mapping[qualified_name] = lineage_name

        else:
            # DQL (pure SELECT): Use the SELECT columns as output
            projections = self._get_select_projections(select_node)
            # Get the first SELECT for table resolution (handles UNION case)
            first_select = self._get_first_select(select_node)
            for projection in projections:
                # Get the underlying expression (unwrap alias if present)
                if isinstance(projection, exp.Alias):
                    source_expr = projection.this
                    column_name = projection.alias
                    lineage_name = column_name  # SQLGlot lineage uses the alias
                else:
                    source_expr = projection
                    column_name = None
                    lineage_name = None

                # Try to extract fully qualified name
                if isinstance(source_expr, exp.Column):
                    # Get table and column parts
                    table_name = source_expr.table
                    col_name = column_name or source_expr.name

                    if table_name and first_select:
                        # Resolve table reference (could be table, CTE, or subquery alias)
                        # This works at any nesting level because we're only looking at the immediate context
                        resolved_table = self._resolve_table_reference(
                            table_name, first_select
                        )
                        qualified_name = f"{resolved_table}.{col_name}"
                        columns.append(qualified_name)
                        # Map qualified name to what lineage expects
                        self._column_mapping[qualified_name] = lineage_name or col_name
                    elif first_select:
                        # No table qualifier - try to infer from FROM clause
                        # This handles "SELECT col FROM single_source" cases
                        inferred_table = self._infer_single_table_source(first_select)
                        if inferred_table:
                            qualified_name = f"{inferred_table}.{col_name}"
                            columns.append(qualified_name)
                            self._column_mapping[qualified_name] = (
                                lineage_name or col_name
                            )
                        else:
                            # Can't infer table, just use column name
                            columns.append(col_name)
                            self._column_mapping[col_name] = lineage_name or col_name
                    else:
                        # No SELECT found, just use column name
                        columns.append(col_name)
                        self._column_mapping[col_name] = lineage_name or col_name
                else:
                    # For other expressions (literals, functions, etc.)
                    # Use the alias if available, otherwise the SQL representation
                    if column_name:
                        columns.append(column_name)
                        self._column_mapping[column_name] = column_name
                    else:
                        expr_str = source_expr.sql(dialect=self.dialect)
                        columns.append(expr_str)
                        self._column_mapping[expr_str] = expr_str

        return columns
|
|
285
|
+
|
|
286
|
+
def _get_select_projections(self, node: exp.Expression) -> List[exp.Expression]:
|
|
287
|
+
"""
|
|
288
|
+
Get the SELECT projections from a SELECT or set operation node.
|
|
289
|
+
|
|
290
|
+
For set operations (UNION, INTERSECT, EXCEPT), returns projections from
|
|
291
|
+
the first branch since all branches must have the same number of columns
|
|
292
|
+
with compatible types.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
node: A SELECT or set operation (UNION/INTERSECT/EXCEPT) expression
|
|
296
|
+
|
|
297
|
+
Returns:
|
|
298
|
+
List of projection expressions from the SELECT clause
|
|
299
|
+
"""
|
|
300
|
+
if isinstance(node, exp.Select):
|
|
301
|
+
return list(node.expressions)
|
|
302
|
+
elif isinstance(node, (exp.Union, exp.Intersect, exp.Except)):
|
|
303
|
+
# Recursively get from the left branch (could be nested set operations)
|
|
304
|
+
return self._get_select_projections(node.left)
|
|
305
|
+
return []
|
|
306
|
+
|
|
307
|
+
def _get_first_select(self, node: exp.Expression) -> Optional[exp.Select]:
|
|
308
|
+
"""
|
|
309
|
+
Get the first SELECT node from a SELECT or set operation expression.
|
|
310
|
+
|
|
311
|
+
For set operations (UNION, INTERSECT, EXCEPT), returns the leftmost
|
|
312
|
+
SELECT branch.
|
|
313
|
+
|
|
314
|
+
Args:
|
|
315
|
+
node: A SELECT or set operation (UNION/INTERSECT/EXCEPT) expression
|
|
316
|
+
|
|
317
|
+
Returns:
|
|
318
|
+
The first SELECT node, or None if not found
|
|
319
|
+
"""
|
|
320
|
+
if isinstance(node, exp.Select):
|
|
321
|
+
return node
|
|
322
|
+
elif isinstance(node, (exp.Union, exp.Intersect, exp.Except)):
|
|
323
|
+
return self._get_first_select(node.left)
|
|
324
|
+
return None
|
|
325
|
+
|
|
326
|
+
    def analyze_queries(
        self,
        level: AnalysisLevel = AnalysisLevel.COLUMN,
        column: Optional[str] = None,
        source_column: Optional[str] = None,
        table_filter: Optional[str] = None,
    ) -> List[QueryLineageResult]:
        """
        Unified lineage analysis for single or multi-query files.

        This method replaces all previous analysis methods (analyze_column_lineage,
        analyze_reverse_lineage, analyze_table_lineage, analyze_all_queries, etc.)
        with a single unified interface.

        Args:
            level: Analysis level ("column" or "table")
            column: Target output column for forward lineage
            source_column: Source column for reverse lineage (impact analysis)
            table_filter: Filter queries to those referencing this table

        Returns:
            List of QueryLineageResult objects (one per query that matches filters)

        Raises:
            ValueError: If column or source_column is specified but not found

        Examples:
            # Forward lineage for all columns
            results = analyzer.analyze_queries(level="column")

            # Forward lineage for specific column
            results = analyzer.analyze_queries(level="column", column="customers.id")

            # Reverse lineage (impact analysis)
            results = analyzer.analyze_queries(level="column", source_column="orders.customer_id")

            # Table-level lineage
            results = analyzer.analyze_queries(level="table")

            # Filter by table (multi-query files)
            results = analyzer.analyze_queries(table_filter="customers")
        """
        results = []
        self._skipped_queries = []  # Reset skipped queries for this analysis
        self._file_schema = {}  # Reset file schema for this analysis run

        # NOTE(review): _iterate_queries is defined elsewhere in this class;
        # presumably yields (index, expression, preview) per statement,
        # optionally filtered — verify against its definition.
        for query_index, expr, preview in self._iterate_queries(table_filter):
            # Temporarily swap self.expr to analyze this query; the internal
            # _analyze_* helpers all operate on self.expr.
            original_expr = self.expr
            self.expr = expr

            try:
                lineage_items: List[LineageItem] = []

                if level == AnalysisLevel.COLUMN:
                    if source_column:
                        # Reverse lineage (impact analysis)
                        lineage_items = self._analyze_reverse_lineage_internal(
                            source_column
                        )
                        if not lineage_items:
                            # Source column not found in this query - skip it
                            # (continue still runs the finally block below)
                            continue
                    else:
                        # Forward lineage
                        lineage_items = self._analyze_column_lineage_internal(column)
                        if not lineage_items:
                            # Column not found in this query (if column was specified) - skip it
                            if column:
                                continue
                else:  # table
                    lineage_items = self._analyze_table_lineage_internal()

                # Create query result
                results.append(
                    QueryLineageResult(
                        metadata=QueryMetadata(
                            query_index=query_index,
                            query_preview=preview,
                        ),
                        lineage_items=lineage_items,
                        level=level,
                    )
                )
            except ValueError as e:
                # Unsupported statement type - track it and continue
                stmt_type = self._get_statement_type(expr)
                self._skipped_queries.append(
                    SkippedQuery(
                        query_index=query_index,
                        statement_type=stmt_type,
                        reason=str(e),
                        query_preview=preview,
                    )
                )
            finally:
                # Extract schema from this statement AFTER analysis
                # This builds up context for subsequent statements to use
                self._extract_schema_from_statement(expr)
                # Restore original expression
                self.expr = original_expr

        # Validate: if a specific column or source_column was specified and we got no results,
        # raise ValueError to preserve backward compatibility
        if not results:
            if column:
                raise ValueError(
                    f"Column '{column}' not found in any query. "
                    "Please check the column name and try again."
                )
            elif source_column:
                raise ValueError(
                    f"Source column '{source_column}' not found in any query. "
                    "Please check the column name and try again."
                )

        return results
|
|
443
|
+
|
|
444
|
+
def analyze_tables(
|
|
445
|
+
self,
|
|
446
|
+
table_filter: Optional[str] = None,
|
|
447
|
+
) -> List[QueryTablesResult]:
|
|
448
|
+
"""
|
|
449
|
+
Analyze all tables involved in SQL queries.
|
|
450
|
+
|
|
451
|
+
This method extracts information about all tables referenced in the SQL,
|
|
452
|
+
including their usage (INPUT, OUTPUT, or BOTH) and object type (TABLE, VIEW,
|
|
453
|
+
CTE, or UNKNOWN).
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
table_filter: Filter queries to those referencing this table
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
List of QueryTablesResult objects (one per query that matches filters)
|
|
460
|
+
|
|
461
|
+
Examples:
|
|
462
|
+
# Get all tables from SQL
|
|
463
|
+
results = analyzer.analyze_tables()
|
|
464
|
+
|
|
465
|
+
# Filter by table (multi-query files)
|
|
466
|
+
results = analyzer.analyze_tables(table_filter="customers")
|
|
467
|
+
"""
|
|
468
|
+
results = []
|
|
469
|
+
|
|
470
|
+
for query_index, expr, preview in self._iterate_queries(table_filter):
|
|
471
|
+
# Temporarily swap self.expr to analyze this query
|
|
472
|
+
original_expr = self.expr
|
|
473
|
+
self.expr = expr
|
|
474
|
+
|
|
475
|
+
try:
|
|
476
|
+
tables = self._extract_tables_from_query()
|
|
477
|
+
|
|
478
|
+
# Create query result
|
|
479
|
+
results.append(
|
|
480
|
+
QueryTablesResult(
|
|
481
|
+
metadata=QueryMetadata(
|
|
482
|
+
query_index=query_index,
|
|
483
|
+
query_preview=preview,
|
|
484
|
+
),
|
|
485
|
+
tables=tables,
|
|
486
|
+
)
|
|
487
|
+
)
|
|
488
|
+
finally:
|
|
489
|
+
# Restore original expression
|
|
490
|
+
self.expr = original_expr
|
|
491
|
+
|
|
492
|
+
return results
|
|
493
|
+
|
|
494
|
+
def _extract_tables_from_query(self) -> List[TableInfo]:
|
|
495
|
+
"""
|
|
496
|
+
Extract all tables from the current query with usage and type information.
|
|
497
|
+
|
|
498
|
+
Returns:
|
|
499
|
+
List of TableInfo objects for all tables in the query.
|
|
500
|
+
"""
|
|
501
|
+
# Track tables by name to consolidate INPUT/OUTPUT into BOTH
|
|
502
|
+
tables_dict: dict[str, TableInfo] = {}
|
|
503
|
+
|
|
504
|
+
# Extract CTEs first (they're INPUT only)
|
|
505
|
+
cte_names = self._extract_cte_names()
|
|
506
|
+
for cte_name in cte_names:
|
|
507
|
+
tables_dict[cte_name] = TableInfo(
|
|
508
|
+
name=cte_name,
|
|
509
|
+
usage=TableUsage.INPUT,
|
|
510
|
+
object_type=ObjectType.CTE,
|
|
511
|
+
)
|
|
512
|
+
|
|
513
|
+
# Determine target table and its type based on statement type
|
|
514
|
+
target_table, target_type = self._get_target_table_info()
|
|
515
|
+
|
|
516
|
+
# Get all table references in the query (except CTEs)
|
|
517
|
+
input_tables = self._get_all_input_tables(cte_names)
|
|
518
|
+
|
|
519
|
+
# Add target table as OUTPUT
|
|
520
|
+
if target_table:
|
|
521
|
+
if target_table in tables_dict:
|
|
522
|
+
# Table is both input and output (e.g., UPDATE with self-reference)
|
|
523
|
+
tables_dict[target_table] = TableInfo(
|
|
524
|
+
name=target_table,
|
|
525
|
+
usage=TableUsage.BOTH,
|
|
526
|
+
object_type=target_type,
|
|
527
|
+
)
|
|
528
|
+
else:
|
|
529
|
+
tables_dict[target_table] = TableInfo(
|
|
530
|
+
name=target_table,
|
|
531
|
+
usage=TableUsage.OUTPUT,
|
|
532
|
+
object_type=target_type,
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
# Add input tables
|
|
536
|
+
for table_name in input_tables:
|
|
537
|
+
if table_name in tables_dict:
|
|
538
|
+
# Already exists - might need to upgrade to BOTH
|
|
539
|
+
existing = tables_dict[table_name]
|
|
540
|
+
if existing.usage == TableUsage.OUTPUT:
|
|
541
|
+
tables_dict[table_name] = TableInfo(
|
|
542
|
+
name=table_name,
|
|
543
|
+
usage=TableUsage.BOTH,
|
|
544
|
+
object_type=existing.object_type,
|
|
545
|
+
)
|
|
546
|
+
# If INPUT or BOTH, keep as-is
|
|
547
|
+
else:
|
|
548
|
+
tables_dict[table_name] = TableInfo(
|
|
549
|
+
name=table_name,
|
|
550
|
+
usage=TableUsage.INPUT,
|
|
551
|
+
object_type=ObjectType.UNKNOWN,
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
# Return sorted list by name for consistent output
|
|
555
|
+
return sorted(tables_dict.values(), key=lambda t: t.name.lower())
|
|
556
|
+
|
|
557
|
+
def _extract_cte_names(self) -> Set[str]:
|
|
558
|
+
"""
|
|
559
|
+
Extract all CTE (Common Table Expression) names from the query.
|
|
560
|
+
|
|
561
|
+
Returns:
|
|
562
|
+
Set of CTE names defined in the WITH clause.
|
|
563
|
+
"""
|
|
564
|
+
cte_names: Set[str] = set()
|
|
565
|
+
|
|
566
|
+
# Look for WITH clause
|
|
567
|
+
if hasattr(self.expr, "args") and self.expr.args.get("with"):
|
|
568
|
+
with_clause = self.expr.args["with"]
|
|
569
|
+
for cte in with_clause.expressions:
|
|
570
|
+
if isinstance(cte, exp.CTE) and cte.alias:
|
|
571
|
+
cte_names.add(cte.alias)
|
|
572
|
+
|
|
573
|
+
return cte_names
|
|
574
|
+
|
|
575
|
+
def _get_target_table_info(self) -> Tuple[Optional[str], ObjectType]:
|
|
576
|
+
"""
|
|
577
|
+
Get the target table name and its object type for DML/DDL statements.
|
|
578
|
+
|
|
579
|
+
Returns:
|
|
580
|
+
Tuple of (target_table_name, object_type) or (None, UNKNOWN) for SELECT.
|
|
581
|
+
"""
|
|
582
|
+
# INSERT INTO table
|
|
583
|
+
if isinstance(self.expr, exp.Insert):
|
|
584
|
+
target = self.expr.this
|
|
585
|
+
if isinstance(target, exp.Table):
|
|
586
|
+
return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
|
|
587
|
+
|
|
588
|
+
# CREATE TABLE / CREATE VIEW
|
|
589
|
+
elif isinstance(self.expr, exp.Create):
|
|
590
|
+
kind = getattr(self.expr, "kind", "").upper()
|
|
591
|
+
target = self.expr.this
|
|
592
|
+
|
|
593
|
+
# Handle Schema wrapper (CREATE TABLE with columns)
|
|
594
|
+
if isinstance(target, exp.Schema):
|
|
595
|
+
target = target.this
|
|
596
|
+
|
|
597
|
+
if isinstance(target, exp.Table):
|
|
598
|
+
table_name = self._get_qualified_table_name(target)
|
|
599
|
+
if kind == "VIEW":
|
|
600
|
+
return (table_name, ObjectType.VIEW)
|
|
601
|
+
elif kind == "TABLE":
|
|
602
|
+
return (table_name, ObjectType.TABLE)
|
|
603
|
+
else:
|
|
604
|
+
return (table_name, ObjectType.UNKNOWN)
|
|
605
|
+
|
|
606
|
+
# UPDATE table
|
|
607
|
+
elif isinstance(self.expr, exp.Update):
|
|
608
|
+
target = self.expr.this
|
|
609
|
+
if isinstance(target, exp.Table):
|
|
610
|
+
return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
|
|
611
|
+
|
|
612
|
+
# MERGE INTO table
|
|
613
|
+
elif isinstance(self.expr, exp.Merge):
|
|
614
|
+
target = self.expr.this
|
|
615
|
+
if isinstance(target, exp.Table):
|
|
616
|
+
return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
|
|
617
|
+
|
|
618
|
+
# DELETE FROM table
|
|
619
|
+
elif isinstance(self.expr, exp.Delete):
|
|
620
|
+
target = self.expr.this
|
|
621
|
+
if isinstance(target, exp.Table):
|
|
622
|
+
return (self._get_qualified_table_name(target), ObjectType.UNKNOWN)
|
|
623
|
+
|
|
624
|
+
# DROP TABLE / DROP VIEW
|
|
625
|
+
elif isinstance(self.expr, exp.Drop):
|
|
626
|
+
kind = getattr(self.expr, "kind", "").upper()
|
|
627
|
+
target = self.expr.this
|
|
628
|
+
if isinstance(target, exp.Table):
|
|
629
|
+
table_name = self._get_qualified_table_name(target)
|
|
630
|
+
if kind == "VIEW":
|
|
631
|
+
return (table_name, ObjectType.VIEW)
|
|
632
|
+
elif kind == "TABLE":
|
|
633
|
+
return (table_name, ObjectType.TABLE)
|
|
634
|
+
else:
|
|
635
|
+
return (table_name, ObjectType.UNKNOWN)
|
|
636
|
+
|
|
637
|
+
# SELECT (no target table)
|
|
638
|
+
return (None, ObjectType.UNKNOWN)
|
|
639
|
+
|
|
640
|
+
def _get_all_input_tables(self, exclude_ctes: Set[str]) -> Set[str]:
|
|
641
|
+
"""
|
|
642
|
+
Get all tables used as input (FROM, JOIN, subqueries, etc.).
|
|
643
|
+
|
|
644
|
+
Args:
|
|
645
|
+
exclude_ctes: Set of CTE names to exclude from results.
|
|
646
|
+
|
|
647
|
+
Returns:
|
|
648
|
+
Set of fully qualified table names that are used as input.
|
|
649
|
+
"""
|
|
650
|
+
input_tables: Set[str] = set()
|
|
651
|
+
|
|
652
|
+
# Find all Table nodes in the expression tree
|
|
653
|
+
for table_node in self.expr.find_all(exp.Table):
|
|
654
|
+
table_name = self._get_qualified_table_name(table_node)
|
|
655
|
+
|
|
656
|
+
# Skip CTEs (they're tracked separately)
|
|
657
|
+
if table_name in exclude_ctes:
|
|
658
|
+
continue
|
|
659
|
+
|
|
660
|
+
# Skip the target table for certain statement types
|
|
661
|
+
# (it will be added separately as OUTPUT)
|
|
662
|
+
if self._is_target_table(table_node):
|
|
663
|
+
continue
|
|
664
|
+
|
|
665
|
+
input_tables.add(table_name)
|
|
666
|
+
|
|
667
|
+
return input_tables
|
|
668
|
+
|
|
669
|
+
def _is_target_table(self, table_node: exp.Table) -> bool:
|
|
670
|
+
"""
|
|
671
|
+
Check if a table node is the target of a DML/DDL statement.
|
|
672
|
+
|
|
673
|
+
This helps distinguish the target table (OUTPUT) from source tables (INPUT)
|
|
674
|
+
in statements like INSERT, UPDATE, MERGE, DELETE.
|
|
675
|
+
|
|
676
|
+
Args:
|
|
677
|
+
table_node: The table node to check.
|
|
678
|
+
|
|
679
|
+
Returns:
|
|
680
|
+
True if this is the target table, False otherwise.
|
|
681
|
+
"""
|
|
682
|
+
# For INSERT, the target is self.expr.this
|
|
683
|
+
if isinstance(self.expr, exp.Insert):
|
|
684
|
+
return table_node is self.expr.this
|
|
685
|
+
|
|
686
|
+
# For UPDATE, the target is self.expr.this
|
|
687
|
+
elif isinstance(self.expr, exp.Update):
|
|
688
|
+
return table_node is self.expr.this
|
|
689
|
+
|
|
690
|
+
# For MERGE, the target is self.expr.this
|
|
691
|
+
elif isinstance(self.expr, exp.Merge):
|
|
692
|
+
return table_node is self.expr.this
|
|
693
|
+
|
|
694
|
+
# For DELETE, the target is self.expr.this
|
|
695
|
+
elif isinstance(self.expr, exp.Delete):
|
|
696
|
+
return table_node is self.expr.this
|
|
697
|
+
|
|
698
|
+
# For CREATE TABLE/VIEW, check if it's in the schema
|
|
699
|
+
elif isinstance(self.expr, exp.Create):
|
|
700
|
+
target = self.expr.this
|
|
701
|
+
if isinstance(target, exp.Schema):
|
|
702
|
+
return table_node is target.this
|
|
703
|
+
return table_node is target
|
|
704
|
+
|
|
705
|
+
# For DROP, the target is self.expr.this
|
|
706
|
+
elif isinstance(self.expr, exp.Drop):
|
|
707
|
+
return table_node is self.expr.this
|
|
708
|
+
|
|
709
|
+
return False
|
|
710
|
+
|
|
711
|
+
def _analyze_column_lineage_internal(
|
|
712
|
+
self, column: Optional[str] = None
|
|
713
|
+
) -> List[LineageItem]:
|
|
714
|
+
"""
|
|
715
|
+
Internal method for analyzing column lineage. Returns flat list of LineageItem.
|
|
716
|
+
|
|
717
|
+
Args:
|
|
718
|
+
column: Optional specific column to analyze. If None, analyzes all columns.
|
|
719
|
+
|
|
720
|
+
Returns:
|
|
721
|
+
List of LineageItem objects (one per output-source relationship)
|
|
722
|
+
"""
|
|
723
|
+
output_columns = self.get_output_columns()
|
|
724
|
+
|
|
725
|
+
if column:
|
|
726
|
+
# Analyze only the specified column (case-insensitive matching)
|
|
727
|
+
matched_column = None
|
|
728
|
+
column_lower = column.lower()
|
|
729
|
+
for output_col in output_columns:
|
|
730
|
+
if output_col.lower() == column_lower:
|
|
731
|
+
matched_column = output_col
|
|
732
|
+
break
|
|
733
|
+
|
|
734
|
+
if matched_column is None:
|
|
735
|
+
# Column not found - return empty list (caller will skip this query)
|
|
736
|
+
return []
|
|
737
|
+
columns_to_analyze = [matched_column]
|
|
738
|
+
else:
|
|
739
|
+
# Analyze all columns
|
|
740
|
+
columns_to_analyze = output_columns
|
|
741
|
+
|
|
742
|
+
lineage_items = []
|
|
743
|
+
# Get SQL for current expression only (not full multi-query SQL)
|
|
744
|
+
current_query_sql = self.expr.sql(dialect=self.dialect)
|
|
745
|
+
|
|
746
|
+
for col in columns_to_analyze:
|
|
747
|
+
try:
|
|
748
|
+
# Get the column name that lineage expects
|
|
749
|
+
lineage_col = self._column_mapping.get(col, col)
|
|
750
|
+
|
|
751
|
+
# Get lineage tree for this column using current query SQL only
|
|
752
|
+
# Pass file schema to enable SELECT * expansion for known tables/views
|
|
753
|
+
node = lineage(
|
|
754
|
+
lineage_col,
|
|
755
|
+
current_query_sql,
|
|
756
|
+
dialect=self.dialect,
|
|
757
|
+
schema=self._file_schema if self._file_schema else None,
|
|
758
|
+
)
|
|
759
|
+
|
|
760
|
+
# Collect all source columns
|
|
761
|
+
sources: Set[str] = set()
|
|
762
|
+
self._collect_source_columns(node, sources)
|
|
763
|
+
|
|
764
|
+
# Convert to flat LineageItem list (one item per source)
|
|
765
|
+
for source in sorted(sources):
|
|
766
|
+
lineage_items.append(
|
|
767
|
+
LineageItem(output_name=col, source_name=source)
|
|
768
|
+
)
|
|
769
|
+
|
|
770
|
+
# If no sources found, add single item with empty source
|
|
771
|
+
if not sources:
|
|
772
|
+
lineage_items.append(LineageItem(output_name=col, source_name=""))
|
|
773
|
+
except Exception:
|
|
774
|
+
# If lineage fails for a column, add item with empty source
|
|
775
|
+
lineage_items.append(LineageItem(output_name=col, source_name=""))
|
|
776
|
+
|
|
777
|
+
return lineage_items
|
|
778
|
+
|
|
779
|
+
def _analyze_table_lineage_internal(self) -> List[LineageItem]:
    """
    Internal method for analyzing table lineage. Returns flat list of LineageItem.

    Returns:
        List of LineageItem objects (one per output-source table relationship)
    """
    source_tables: Set[str] = set()

    # Find all Table nodes in the AST and record their qualified names.
    # Use _get_qualified_table_name (as _get_query_tables does) instead of
    # table_node.sql(): rendering the node as SQL leaks aliases into the
    # name (e.g. "orders AS o") and applies dialect quoting, producing
    # inconsistent source identifiers.
    for table_node in self.expr.find_all(exp.Table):
        source_tables.add(self._get_qualified_table_name(table_node))

    # The output table would typically be defined in INSERT/CREATE statements.
    # For SELECT statements, we use a placeholder.
    output_table = "query_result"

    # Convert to flat LineageItem list (one item per source table)
    return [
        LineageItem(output_name=output_table, source_name=source)
        for source in sorted(source_tables)
    ]
|
|
806
|
+
|
|
807
|
+
def _analyze_reverse_lineage_internal(
    self, source_column: str
) -> List[LineageItem]:
    """
    Internal method for analyzing reverse lineage. Returns flat list of LineageItem.

    Args:
        source_column: Source column to analyze (e.g., "orders.customer_id")

    Returns:
        List of LineageItem objects (source column -> affected outputs)
    """
    # Step 1: Run forward lineage on all output columns
    forward_items = self._analyze_column_lineage_internal(column=None)

    # Step 2: Build reverse mapping (source -> {affected outputs}).
    # Typed with typing.Dict/Set for consistency with the module's imports
    # (lowercase generics would require Python >= 3.9).
    reverse_map: Dict[str, Set[str]] = {}
    all_outputs: Set[str] = set()

    for item in forward_items:
        all_outputs.add(item.output_name)
        if item.source_name:  # Skip empty sources
            # setdefault replaces the manual "if key not in map" dance
            reverse_map.setdefault(item.source_name, set()).add(item.output_name)

    # Step 3: Find matching source (case-insensitive)
    matched_source = None
    affected_outputs: Set[str] = set()
    source_column_lower = source_column.lower()

    # First check if it's in reverse_map (derived columns)
    for source in reverse_map:
        if source.lower() == source_column_lower:
            matched_source = source
            affected_outputs = reverse_map[matched_source]
            break

    # If not found, check if it's an output column (base table column)
    if matched_source is None:
        for output in all_outputs:
            if output.lower() == source_column_lower:
                matched_source = output
                affected_outputs = {output}  # It affects itself
                break

    if matched_source is None:
        # Source column not found - return empty list (caller will skip this query)
        return []

    # Step 4: Return with semantic swap (source as output, affected as sources).
    # This maintains the LineageItem structure where output_name is what we're
    # looking at and source_name is what it affects.
    return [
        LineageItem(output_name=matched_source, source_name=affected)
        for affected in sorted(affected_outputs)
    ]
|
|
867
|
+
|
|
868
|
+
def _get_statement_type(self, expr: Optional[exp.Expression] = None) -> str:
    """
    Get a human-readable name for the SQL statement type.

    Args:
        expr: Expression to check (uses self.expr if not provided)

    Returns:
        Statement type name (e.g., "CREATE FUNCTION", "SELECT", "DELETE")
    """
    target = self.expr if expr is None else expr
    kind = type(target).__name__

    # Map common expression class names to readable statement names.
    # CREATE/DROP carry the object kind (TABLE, VIEW, FUNCTION, ...)
    # when the node exposes one.
    readable = {
        "Select": "SELECT",
        "Insert": "INSERT",
        "Update": "UPDATE",
        "Delete": "DELETE",
        "Merge": "MERGE",
        "Create": f"CREATE {getattr(target, 'kind', '')}".strip(),
        "Drop": f"DROP {getattr(target, 'kind', '')}".strip(),
        "Alter": "ALTER",
        "Truncate": "TRUNCATE",
        "Command": "COMMAND",
    }

    # Unknown expression classes fall back to their upper-cased class name.
    return readable.get(kind, kind.upper())
|
|
896
|
+
|
|
897
|
+
def _get_target_and_select(
    self,
) -> Optional[
    tuple[Optional[str], Union[exp.Select, exp.Union, exp.Intersect, exp.Except]]
]:
    """
    Detect if this is a DML/DDL statement and extract the target table and SELECT node.

    Returns:
        Tuple of (target_table_name, select_node) where:
        - target_table_name is the fully qualified target table for DML/DDL, or None for pure SELECT
        - select_node is the SELECT statement that provides the data
        - Returns None if the statement type doesn't contain a SELECT (e.g., CREATE FUNCTION)

    Handles:
        - INSERT INTO table SELECT ...
        - CREATE TABLE table AS SELECT ...
        - MERGE INTO table ...
        - UPDATE table SET ... FROM (SELECT ...)
        - Pure SELECT (returns None as target)

    NOTE: the inner isinstance guards intentionally fall through. If, say, an
    INSERT's target is not a plain exp.Table, control drops past the whole
    elif chain to the "pure SELECT" default below, which returns the first
    SELECT found anywhere in the statement with no target name.
    """
    # Check for INSERT statement
    if isinstance(self.expr, exp.Insert):
        target = self.expr.this
        if isinstance(target, exp.Table):
            target_name = self._get_qualified_table_name(target)
            # Find the SELECT within the INSERT (may be a set operation)
            select_node = self.expr.expression
            if isinstance(
                select_node, (exp.Select, exp.Union, exp.Intersect, exp.Except)
            ):
                return (target_name, select_node)

    # Check for CREATE TABLE AS SELECT (CTAS) or CREATE VIEW AS SELECT
    elif isinstance(self.expr, exp.Create):
        if self.expr.kind in ("TABLE", "VIEW"):
            target = self.expr.this
            if isinstance(target, exp.Schema):
                # Get the table from schema (CREATE TABLE t (cols...) AS ...)
                target = target.this
            if isinstance(target, exp.Table):
                target_name = self._get_qualified_table_name(target)
                # Find the SELECT in the expression (may be a set operation)
                select_node = self.expr.expression
                if isinstance(
                    select_node, (exp.Select, exp.Union, exp.Intersect, exp.Except)
                ):
                    return (target_name, select_node)

    # Check for MERGE statement
    elif isinstance(self.expr, exp.Merge):
        target = self.expr.this
        if isinstance(target, exp.Table):
            target_name = self._get_qualified_table_name(target)
            # For MERGE, we need to find the SELECT in the USING clause.
            # This is more complex; for now take the first SELECT anywhere
            # in the statement.
            select_nodes = list(self.expr.find_all(exp.Select))
            if select_nodes:
                return (target_name, select_nodes[0])

    # Check for UPDATE with subquery
    elif isinstance(self.expr, exp.Update):
        target = self.expr.this
        if isinstance(target, exp.Table):
            target_name = self._get_qualified_table_name(target)
            # For UPDATE, find the SELECT if there is one
            select_nodes = list(self.expr.find_all(exp.Select))
            if select_nodes:
                return (target_name, select_nodes[0])

    # Default: Pure SELECT (DQL) — also reached when one of the branches
    # above matched the statement type but failed an inner guard.
    select_nodes = list(self.expr.find_all(exp.Select))
    if select_nodes:
        return (None, select_nodes[0])

    # Fallback: return the expression as-is if it's a SELECT
    # (defensive; a SELECT would normally be caught by find_all above)
    if isinstance(self.expr, exp.Select):
        return (None, self.expr)

    # No SELECT found - return None to indicate unsupported statement
    return None
|
|
978
|
+
|
|
979
|
+
def _get_qualified_table_name(self, table: exp.Table) -> str:
    """
    Get the fully qualified name for a table.

    Args:
        table: SQLGlot Table expression

    Returns:
        Fully qualified table name (database.table or catalog.database.table)
    """
    # Keep only the qualifiers that are actually present, in
    # catalog -> database -> table order.
    qualifiers = [part for part in (table.catalog, table.db) if part]
    return ".".join(qualifiers + [table.name])
|
|
996
|
+
|
|
997
|
+
def _resolve_table_reference(self, ref: str, select_node: exp.Select) -> str:
    """
    Resolve a table reference (alias, CTE name, or actual table) to its canonical name.

    This works at any nesting level by only looking at the immediate SELECT context.
    For CTEs and subqueries, returns their alias name (which is the "table name" in
    that context). For actual tables with aliases, returns the actual table name.

    Args:
        ref: The table reference to resolve (could be alias, CTE name, or table name)
        select_node: The SELECT node containing the FROM/JOIN clauses

    Returns:
        The canonical table name (actual table for real tables, alias for CTEs/subqueries)
    """
    # Check if this is a CTE reference first.
    # CTEs are defined in the WITH clause and referenced by their alias;
    # walk up the tree since the WITH may live on an enclosing SELECT/UNION.
    parent = select_node
    while parent:
        if isinstance(parent, (exp.Select, exp.Union)) and parent.args.get("with"):
            cte_node = parent.args["with"]
            for cte in cte_node.expressions:
                if isinstance(cte, exp.CTE) and cte.alias == ref:
                    # This is a CTE - the alias IS the table name in this scope
                    return ref
        parent = parent.parent if hasattr(parent, "parent") else None

    # Look for table references in FROM and JOIN clauses.
    # Qualified-name construction is delegated to _get_qualified_table_name
    # for consistency with the rest of the analyzer (the previous inline
    # code built the same catalog.db.name string by hand, three times).
    for table_ref in select_node.find_all(exp.Table):
        # Either the table carries the matching alias, or it is an unaliased
        # table whose bare name matches the reference.
        if table_ref.alias == ref or (table_ref.name == ref and not table_ref.alias):
            return self._get_qualified_table_name(table_ref)

    # Check for subquery aliases in the FROM clause
    if select_node.args.get("from"):
        from_clause = select_node.args["from"]
        if isinstance(from_clause, exp.From):
            source = from_clause.this
            if isinstance(source, exp.Subquery) and source.alias == ref:
                # Subquery alias is the "table name" in this scope
                return ref
            elif isinstance(source, exp.Table) and source.alias == ref:
                return self._get_qualified_table_name(source)

    # Check JOIN clauses for subqueries and aliased tables
    for join in select_node.find_all(exp.Join):
        if isinstance(join.this, exp.Subquery) and join.this.alias == ref:
            return ref
        elif isinstance(join.this, exp.Table) and join.this.alias == ref:
            return self._get_qualified_table_name(join.this)

    # If we can't resolve, return the reference as-is
    return ref
|
|
1080
|
+
|
|
1081
|
+
def _infer_single_table_source(self, select_node: exp.Select) -> Optional[str]:
    """
    Infer the table name when there's only one table in FROM clause.

    This handles cases like "SELECT col FROM table" where col has no table prefix.

    Args:
        select_node: The SELECT node

    Returns:
        The table name if there's exactly one source, None otherwise
    """
    from_clause = select_node.args.get("from")
    if not isinstance(from_clause, exp.From):
        # No FROM clause (or an unexpected node) - nothing to infer
        return None

    source = from_clause.this

    # Check for JOINs - if there are joins, we can't infer a single source
    if next(select_node.find_all(exp.Join), None) is not None:
        return None

    # Single table or CTE/subquery
    if isinstance(source, exp.Table):
        if source.alias:
            # If the table has an alias, the alias is the name used in scope
            return source.alias
        # Delegate to the shared helper instead of rebuilding the
        # catalog.db.name string inline (consistency fix).
        return self._get_qualified_table_name(source)
    elif isinstance(source, (exp.Subquery, exp.CTE)):
        # Return the subquery/CTE alias
        return source.alias if source.alias else None

    return None
|
|
1123
|
+
|
|
1124
|
+
def _collect_source_columns(self, node: Node, sources: Set[str]) -> None:
    """
    Collect all source columns from a lineage tree.

    Walks the lineage tree, gathering leaf nodes, which represent the
    actual source columns.

    Args:
        node: The current lineage node
        sources: Set to accumulate source column names
    """
    # Depth-first traversal with an explicit stack; visit order is
    # irrelevant because results accumulate into a set.
    pending = [node]
    while pending:
        current = pending.pop()
        if current.downstream:
            # Interior node - keep walking toward the leaves
            pending.extend(current.downstream)
        elif current.name.isdigit():
            # SQLGlot names literal projections by column position; recover
            # a readable representation of the literal instead.
            sources.add(self._extract_literal_representation(current))
        else:
            # Leaf column - lineage may report an alias-qualified name, so
            # resolve aliases back to actual table names.
            sources.add(self._resolve_source_column_alias(current.name))
|
|
1151
|
+
|
|
1152
|
+
def _extract_literal_representation(self, node: Node) -> str:
    """
    Extract a human-readable representation of a literal value from a lineage node.

    When SQLGlot encounters a literal value in a UNION branch, it returns the
    column position as the node name. This method extracts the actual literal
    value from the node's expression.

    Args:
        node: A lineage node where node.name is a digit (position number)

    Returns:
        A string like "<literal: NULL>" or "<literal: 'value'>" or "<literal: 0>"
    """
    try:
        wrapped = node.expression
        # The literal projection is typically wrapped in an Alias; unwrap it
        # when present, otherwise render the expression itself.
        inner = wrapped.this if isinstance(wrapped, exp.Alias) else wrapped
        return f"<literal: {inner.sql(dialect=self.dialect)}>"
    except Exception:
        # If extraction fails for any reason, fall back to a generic marker
        return "<literal>"
|
|
1179
|
+
|
|
1180
|
+
def _get_query_tables(self) -> List[str]:
    """
    Get all table names referenced in the current query.

    Returns:
        List of fully qualified table names used in the query
    """
    # One qualified name per Table node found anywhere in the AST
    return [
        self._get_qualified_table_name(table_node)
        for table_node in self.expr.find_all(exp.Table)
    ]
|
|
1192
|
+
|
|
1193
|
+
def _resolve_source_column_alias(self, column_name: str) -> str:
|
|
1194
|
+
"""
|
|
1195
|
+
Resolve table aliases in source column names.
|
|
1196
|
+
|
|
1197
|
+
This searches through ALL SELECT nodes in the query (including nested ones)
|
|
1198
|
+
to find and resolve table aliases, CTEs, and subqueries.
|
|
1199
|
+
|
|
1200
|
+
Args:
|
|
1201
|
+
column_name: Column name like "alias.column" or "table.column"
|
|
1202
|
+
|
|
1203
|
+
Returns:
|
|
1204
|
+
Fully qualified column name with actual table name
|
|
1205
|
+
"""
|
|
1206
|
+
# Parse the column name (format: table.column or db.table.column)
|
|
1207
|
+
parts = column_name.split(".")
|
|
1208
|
+
|
|
1209
|
+
if len(parts) < 2:
|
|
1210
|
+
# No table qualifier, return as-is
|
|
1211
|
+
return column_name
|
|
1212
|
+
|
|
1213
|
+
# The table part might be an alias, CTE name, or actual table
|
|
1214
|
+
table_part = parts[0] if len(parts) == 2 else parts[-2]
|
|
1215
|
+
column_part = parts[-1]
|
|
1216
|
+
|
|
1217
|
+
# Try to resolve by searching through ALL SELECT nodes (including nested)
|
|
1218
|
+
# This handles cases where the alias is defined deep in a subquery/CTE
|
|
1219
|
+
for select_node in self.expr.find_all(exp.Select):
|
|
1220
|
+
resolved = self._resolve_table_reference(table_part, select_node)
|
|
1221
|
+
# If resolution changed the name, we found it
|
|
1222
|
+
if resolved != table_part:
|
|
1223
|
+
# Reconstruct with resolved table name
|
|
1224
|
+
if len(parts) == 2:
|
|
1225
|
+
return f"{resolved}.{column_part}"
|
|
1226
|
+
else:
|
|
1227
|
+
# Has database part
|
|
1228
|
+
return f"{parts[0]}.{resolved}.{column_part}"
|
|
1229
|
+
|
|
1230
|
+
# If we couldn't resolve in any SELECT, return as-is
|
|
1231
|
+
return column_name
|
|
1232
|
+
|
|
1233
|
+
def _generate_query_preview(self, expr: exp.Expression) -> str:
    """
    Generate a preview string for a query (first 100 chars, normalized).

    Args:
        expr: The SQL expression to generate a preview for

    Returns:
        Preview string (first 100 chars with "..." if truncated)
    """
    # Collapse all whitespace runs into single spaces exactly once
    # (the previous version performed this normalization twice).
    normalized = " ".join(expr.sql(dialect=self.dialect).split())
    preview = normalized[:100]
    if len(normalized) > 100:
        preview += "..."
    return preview
|
|
1248
|
+
|
|
1249
|
+
def _filter_by_table(self, expr: exp.Expression, table_filter: str) -> bool:
    """
    Check if a query references a specific table.

    Args:
        expr: The SQL expression to check
        table_filter: Table name to filter by (case-insensitive partial match)

    Returns:
        True if the query references the table, False otherwise
    """
    # _get_query_tables reads self.expr, so temporarily point it at the
    # expression under test; always restore the original afterwards.
    saved_expr = self.expr
    self.expr = expr
    try:
        needle = table_filter.lower()
        return any(needle in name.lower() for name in self._get_query_tables())
    finally:
        self.expr = saved_expr
|
|
1269
|
+
|
|
1270
|
+
def _iterate_queries(
    self, table_filter: Optional[str] = None
) -> Iterator[Tuple[int, exp.Expression, str]]:
    """
    Iterate over queries with filtering and preview generation.

    Args:
        table_filter: Optional table name to filter queries by

    Yields:
        Tuple of (query_index, expression, query_preview)
    """
    for position, statement in enumerate(self.expressions):
        # Skip statements that don't reference the filter table (if given).
        # Note the index reflects the statement's position in the file, so
        # filtered-out statements leave gaps in the yielded indices.
        if table_filter and not self._filter_by_table(statement, table_filter):
            continue
        yield position, statement, self._generate_query_preview(statement)
|
|
1291
|
+
|
|
1292
|
+
# -------------------------------------------------------------------------
|
|
1293
|
+
# File-scoped schema context methods
|
|
1294
|
+
# -------------------------------------------------------------------------
|
|
1295
|
+
|
|
1296
|
+
def _extract_schema_from_statement(self, expr: exp.Expression) -> None:
    """
    Extract column definitions from CREATE VIEW/TABLE AS SELECT statements.

    Builds up file-scoped schema context as statements are processed, enabling
    SQLGlot to correctly expand SELECT * and trace cross-statement references.

    Args:
        expr: The SQL expression to extract schema from
    """
    # Only handle CREATE VIEW or CREATE TABLE (AS SELECT)
    if not isinstance(expr, exp.Create) or expr.kind not in ("VIEW", "TABLE"):
        return

    # Get target table/view name, unwrapping a Schema node when the CREATE
    # carries an explicit column list.
    target = expr.this
    if isinstance(target, exp.Schema):
        target = target.this
    if not isinstance(target, exp.Table):
        return
    target_name = self._get_qualified_table_name(target)

    # Get the SELECT node from the CREATE statement
    select_node = expr.expression
    if select_node is None:
        return
    # Handle Subquery wrapper (e.g., CREATE VIEW AS (SELECT ...))
    if isinstance(select_node, exp.Subquery):
        select_node = select_node.this
    if not isinstance(
        select_node, (exp.Select, exp.Union, exp.Intersect, exp.Except)
    ):
        return

    # Record the output columns. SQLGlot only needs column names to expand
    # SELECT *, so every column gets an UNKNOWN placeholder type.
    columns = self._extract_columns_from_select(select_node)
    if columns:
        self._file_schema[target_name] = dict.fromkeys(columns, "UNKNOWN")
|
|
1341
|
+
|
|
1342
|
+
def _extract_columns_from_select(
    self, select_node: Union[exp.Select, exp.Union, exp.Intersect, exp.Except]
) -> List[str]:
    """
    Extract column names from a SELECT statement.

    Handles aliases, direct column references, and SELECT * by resolving
    against the known file schema.

    Args:
        select_node: The SELECT or set operation expression

    Returns:
        List of column names
    """
    result: List[str] = []

    # Get projections (for UNION and friends, use the first branch)
    projections = self._get_select_projections(select_node)
    first_select = self._get_first_select(select_node)

    for item in projections:
        if isinstance(item, exp.Alias):
            # Aliased expression: the alias is the output column name
            result.append(item.alias)
        elif isinstance(item, exp.Column):
            if isinstance(item.this, exp.Star):
                # Table-qualified star (e.g., t.*): expand from known schema.
                # If the qualifier or SELECT context is missing, nothing is
                # added (same as a failed expansion).
                qualifier = item.table
                if qualifier and first_select:
                    result.extend(
                        self._resolve_qualified_star(qualifier, first_select)
                    )
            else:
                result.append(item.name)
        elif isinstance(item, exp.Star):
            # Bare SELECT *: expand from known schema
            if first_select:
                result.extend(self._resolve_star_columns(first_select))
        else:
            # Expression without alias: fall back to its SQL text
            result.append(item.sql(dialect=self.dialect))

    return result
|
|
1391
|
+
|
|
1392
|
+
def _resolve_star_columns(self, select_node: exp.Select) -> List[str]:
    """
    Resolve SELECT * to actual column names from known file schema or CTEs.

    Args:
        select_node: The SELECT node containing the * reference

    Returns:
        List of column names if source is known, empty list otherwise
    """
    resolved: List[str] = []

    # Require a well-formed FROM clause to have anything to expand
    from_clause = select_node.args.get("from")
    if not isinstance(from_clause, exp.From):
        return resolved

    # Columns from the primary FROM source
    resolved.extend(self._resolve_source_columns(from_clause.this, select_node))

    # Columns from each JOINed source. SEMI and ANTI joins are skipped:
    # SELECT * only returns left-table columns for those join kinds.
    for join in select_node.args.get("joins") or ():
        if isinstance(join, exp.Join):
            if join.kind in ("SEMI", "ANTI"):
                continue
            resolved.extend(self._resolve_source_columns(join.this, select_node))

    # Columns generated by LATERAL VIEW clauses (e.g., explode/posexplode)
    for lateral in select_node.args.get("laterals") or ():
        if isinstance(lateral, exp.Lateral):
            resolved.extend(self._resolve_lateral_columns(lateral))

    return resolved
|
|
1439
|
+
|
|
1440
|
+
def _resolve_lateral_columns(self, lateral: exp.Lateral) -> List[str]:
    """
    Extract generated column names from a LATERAL VIEW clause.

    Args:
        lateral: The Lateral expression node

    Returns:
        List of generated column names (e.g., ['elem'] for explode,
        ['pos', 'elem'] for posexplode)
    """
    # SQLGlot exposes the generated column list directly; normalize a
    # falsy result to an empty list.
    names = lateral.alias_column_names
    return names if names else []
|
|
1453
|
+
|
|
1454
|
+
def _resolve_source_columns(
    self, source: exp.Expression, select_node: exp.Select
) -> List[str]:
    """
    Resolve columns from a single source (table, subquery, etc.).

    Args:
        source: The source expression (Table, Subquery, etc.)
        select_node: The containing SELECT node for CTE resolution

    Returns:
        List of column names from the source
    """
    # Table reference: prefer the file-scoped schema (views/tables created
    # by earlier statements), then fall back to CTEs in this statement.
    if isinstance(source, exp.Table):
        name = self._get_qualified_table_name(source)
        if name in self._file_schema:
            return list(self._file_schema[name].keys())
        return list(self._resolve_cte_columns(name, select_node))

    # Aliased subquery: try a schema lookup by alias first, otherwise
    # inspect the subquery's own projection list.
    if isinstance(source, exp.Subquery):
        if source.alias and source.alias in self._file_schema:
            return list(self._file_schema[source.alias].keys())
        inner = source.this
        if isinstance(inner, exp.Select):
            return list(self._extract_subquery_columns(inner))

    # Unsupported source type (or subquery without a SELECT body)
    return []
|
|
1494
|
+
|
|
1495
|
+
def _resolve_qualified_star(
    self, table_name: str, select_node: exp.Select
) -> List[str]:
    """
    Resolve a table-qualified star (e.g., t.*) to actual column names.

    Args:
        table_name: The table/alias name qualifying the star
        select_node: The SELECT node for context

    Returns:
        List of column names from the specified table
    """
    # 1) Direct hit in the file-scoped schema
    if table_name in self._file_schema:
        return list(self._file_schema[table_name])

    # 2) CTE defined in the same statement
    cte_cols = self._resolve_cte_columns(table_name, select_node)
    if cte_cols:
        return cte_cols

    # 3) The name may be an alias for a real table in FROM: resolve the
    #    alias to the actual table and retry the schema lookup.
    from_clause = select_node.args.get("from")
    if isinstance(from_clause, exp.From):
        candidate = from_clause.this
        if isinstance(candidate, exp.Table) and candidate.alias == table_name:
            actual = self._get_qualified_table_name(candidate)
            if actual in self._file_schema:
                return list(self._file_schema[actual])

    # 4) Same alias resolution for JOINed tables
    for join in select_node.args.get("joins") or ():
        if isinstance(join, exp.Join):
            joined = join.this
            if isinstance(joined, exp.Table) and joined.alias == table_name:
                actual = self._get_qualified_table_name(joined)
                if actual in self._file_schema:
                    return list(self._file_schema[actual])

    return []
|
|
1541
|
+
|
|
1542
|
+
def _extract_subquery_columns(self, subquery_select: exp.Select) -> List[str]:
    """
    Extract column names from a subquery's SELECT statement.

    Args:
        subquery_select: The SELECT expression within the subquery

    Returns:
        List of column names
    """
    result: List[str] = []

    for item in subquery_select.expressions:
        if isinstance(item, exp.Alias):
            # Aliased expression: the alias is the output column name.
            result.append(item.alias)
        elif isinstance(item, exp.Column):
            if isinstance(item.this, exp.Star):
                # Table-qualified star (t.*): expand via the qualifier.
                qualifier = item.table
                if qualifier:
                    result.extend(
                        self._resolve_qualified_star(qualifier, subquery_select)
                    )
            else:
                result.append(item.name)
        elif isinstance(item, exp.Star):
            # Bare SELECT * in the subquery: expand from source tables.
            result.extend(self._resolve_star_columns(subquery_select))
        else:
            # Fallback: use the rendered SQL of the projection expression.
            result.append(item.sql(dialect=self.dialect))

    return result
def _resolve_cte_columns(self, cte_name: str, select_node: exp.Select) -> List[str]:
    """
    Resolve columns from a CTE definition within the same statement.

    Args:
        cte_name: Name of the CTE to resolve
        select_node: The SELECT node that references the CTE

    Returns:
        List of column names from the CTE, empty if CTE not found
    """
    # Climb toward the root of the expression tree, inspecting each
    # ancestor's WITH clause for a CTE whose alias matches.
    node = select_node
    while node is not None:
        with_clause = node.args.get("with") if hasattr(node, "args") else None
        if with_clause:
            for cte in with_clause.expressions:
                if not isinstance(cte, exp.CTE) or cte.alias != cte_name:
                    continue
                # Found the CTE definition; extract columns from its SELECT.
                inner = cte.this
                if isinstance(inner, exp.Select):
                    return self._extract_cte_select_columns(inner)
        node = getattr(node, "parent", None)

    return []
def _extract_cte_select_columns(self, cte_select: exp.Select) -> List[str]:
    """
    Extract column names from a CTE's SELECT statement.

    This handles SELECT * within the CTE by resolving against file schema,
    and table-qualified stars (t.*) by expanding the qualifying table's
    columns via _resolve_qualified_star.

    Args:
        cte_select: The SELECT expression within the CTE

    Returns:
        List of column names
    """
    columns: List[str] = []

    for projection in cte_select.expressions:
        if isinstance(projection, exp.Alias):
            columns.append(projection.alias)
        elif isinstance(projection, exp.Column):
            if isinstance(projection.this, exp.Star):
                # Table-qualified star (t.*) parses as Column(this=Star);
                # previously this fell through to projection.name and
                # appended a bogus column entry instead of expanding.
                table_name = projection.table
                if table_name:
                    columns.extend(
                        self._resolve_qualified_star(table_name, cte_select)
                    )
            else:
                columns.append(projection.name)
        elif isinstance(projection, exp.Star):
            # Resolve SELECT * in CTE from file schema
            star_columns = self._resolve_star_columns(cte_select)
            columns.extend(star_columns)
        else:
            col_sql = projection.sql(dialect=self.dialect)
            columns.append(col_sql)

    return columns