sqlprism 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlprism/__init__.py +1 -0
- sqlprism/cli.py +625 -0
- sqlprism/core/__init__.py +0 -0
- sqlprism/core/graph.py +1547 -0
- sqlprism/core/indexer.py +677 -0
- sqlprism/core/mcp_tools.py +982 -0
- sqlprism/languages/__init__.py +28 -0
- sqlprism/languages/dbt.py +199 -0
- sqlprism/languages/sql.py +1031 -0
- sqlprism/languages/sqlmesh.py +203 -0
- sqlprism/languages/utils.py +73 -0
- sqlprism/types.py +190 -0
- sqlprism-1.0.0.dist-info/METADATA +429 -0
- sqlprism-1.0.0.dist-info/RECORD +17 -0
- sqlprism-1.0.0.dist-info/WHEEL +4 -0
- sqlprism-1.0.0.dist-info/entry_points.txt +2 -0
- sqlprism-1.0.0.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,1031 @@
|
|
|
1
|
+
"""SQL parser using sqlglot.
|
|
2
|
+
|
|
3
|
+
This is the richest parser in the system. sqlglot provides semantic analysis
|
|
4
|
+
beyond what tree-sitter can offer for SQL: CTE scope tracking, column-level
|
|
5
|
+
lineage via the Scope module, multi-dialect awareness, and proper resolution
|
|
6
|
+
of aliased references.
|
|
7
|
+
|
|
8
|
+
CTEs are tracked as first-class nodes, not flattened into the parent query.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
|
|
13
|
+
import sqlglot
|
|
14
|
+
from sqlglot import exp
|
|
15
|
+
from sqlglot.lineage import lineage as sqlglot_lineage
|
|
16
|
+
from sqlglot.optimizer.qualify_columns import qualify_columns
|
|
17
|
+
from sqlglot.optimizer.scope import build_scope
|
|
18
|
+
|
|
19
|
+
from sqlprism.types import (
|
|
20
|
+
ColumnLineageResult,
|
|
21
|
+
ColumnUsageResult,
|
|
22
|
+
EdgeResult,
|
|
23
|
+
LineageHop,
|
|
24
|
+
NodeResult,
|
|
25
|
+
ParseResult,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SqlParser:
    """Parses SQL files into nodes, edges, column usage, and column lineage using sqlglot.

    Handles multi-statement files, CTE extraction, column-level scope analysis,
    transform detection, and end-to-end column lineage tracing. Dialect-aware
    identifier normalisation ensures consistent casing across Postgres, Snowflake,
    DuckDB, and other engines.
    """

    # Identifier-folding tables. NOTE(review): presumably consulted by the
    # identifier normalisation helper (not visible in this chunk) — confirm.
    # Dialects that fold unquoted identifiers to lowercase
    _LOWERCASE_DIALECTS = frozenset({"postgres", "postgresql", "redshift", "duckdb"})
    # Dialects that fold unquoted identifiers to uppercase
    _UPPERCASE_DIALECTS = frozenset({"snowflake", "oracle", "db2"})
|
|
42
|
+
|
|
43
|
+
def __init__(self, dialect: str | None = None):
|
|
44
|
+
"""Initialise with an optional SQL dialect.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
dialect: sqlglot dialect string (e.g., 'postgres', 'mysql', 'duckdb').
|
|
48
|
+
None for auto-detection.
|
|
49
|
+
"""
|
|
50
|
+
self.dialect = dialect
|
|
51
|
+
|
|
52
|
+
    def parse(self, file_path: str, file_content: str, schema: dict | None = None) -> ParseResult:
        """Parse a SQL file into nodes, edges, column usage, and column lineage.

        Handles multiple statements per file. Each statement is parsed
        independently. Errors in one statement don't prevent parsing others.

        Args:
            file_path: Path to the SQL file (used for naming nodes).
            file_content: Raw SQL content.
            schema: Optional schema catalog ``{table: {col: type}}`` for
                expanding ``SELECT *`` in lineage tracing via
                ``qualify_columns``.

        Returns:
            A ``ParseResult`` containing all extracted nodes, edges,
            column usage records, column lineage chains, and any
            non-fatal parse errors.
        """
        nodes: list[NodeResult] = []
        edges: list[EdgeResult] = []
        column_usage: list[ColumnUsageResult] = []
        column_lineage: list[ColumnLineageResult] = []
        errors: list[str] = []

        file_stem = Path(file_path).stem

        try:
            statements = sqlglot.parse(file_content, dialect=self.dialect)
        except (sqlglot.errors.ParseError, sqlglot.errors.TokenError) as e:
            # A tokenizer/parser failure at this level is fatal for the whole
            # file — there are no statement ASTs to salvage.
            return ParseResult(language="sql", errors=[f"Parse error: {e}"])

        # Persistent dedup sets across all statements in this file
        seen_nodes: set[tuple[str, str, str | None]] = set()
        seen_ctes: set[str] = set()

        for stmt_idx, stmt in enumerate(statements):
            if stmt is None:
                # sqlglot yields None for empty statements (e.g. stray semicolons).
                continue

            try:
                self._process_statement(
                    stmt,
                    file_stem,
                    file_path,
                    nodes,
                    edges,
                    column_usage,
                    seen_nodes=seen_nodes,
                    seen_ctes=seen_ctes,
                )
            except Exception as e:
                # Record and move on: one bad statement must not sink the file.
                errors.append(f"Statement {stmt_idx}: {type(e).__name__}: {e}")
                continue

            # Column lineage via sqlglot.lineage — separate pass
            try:
                self._extract_column_lineage(stmt, file_stem, file_content, column_lineage, schema=schema)
            except Exception as e:
                errors.append(f"Lineage stmt {stmt_idx}: {type(e).__name__}: {e}")
                continue

        return ParseResult(
            language="sql",
            nodes=nodes,
            edges=edges,
            column_usage=column_usage,
            column_lineage=column_lineage,
            errors=errors,
        )
|
|
121
|
+
|
|
122
|
+
    def _process_statement(
        self,
        stmt: exp.Expression,
        file_stem: str,
        file_path: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
        column_usage: list[ColumnUsageResult],
        seen_nodes: set[tuple[str, str, str | None]] | None = None,
        seen_ctes: set[str] | None = None,
    ) -> None:
        """Process a single SQL statement.

        Dispatches to the specialised extractors: CREATE handling, table
        references, CTE nodes, scope-based column usage, and (for INSERT)
        the INSERT...SELECT column mapping.

        Note: ``file_path`` is accepted but not used in this body —
        presumably kept for interface parity with other parsers; confirm.
        """
        # Use persistent dedup sets across statements, or create fresh ones
        if seen_nodes is None:
            seen_nodes = {(n.name, n.kind, (n.metadata or {}).get("schema")) for n in nodes}
        # Edge dedup is per-statement: a fresh set on every call.
        seen_edges: set[tuple[str, str, str]] = set()

        # CREATE TABLE / CREATE VIEW
        if isinstance(stmt, exp.Create):
            self._process_create(stmt, file_stem, nodes, edges)

        # Extract table references from any statement type
        self._extract_table_references(stmt, file_stem, nodes, edges, seen_nodes, seen_edges)

        # Extract CTEs as first-class nodes
        self._extract_ctes(stmt, file_stem, nodes, edges, seen_ctes=seen_ctes)

        # Column-level lineage via sqlglot's scope analysis
        self._extract_column_usage(stmt, file_stem, nodes, column_usage)

        # INSERT...SELECT column mapping
        if isinstance(stmt, exp.Insert):
            self._extract_insert_select_mapping(stmt, file_stem, column_usage)
|
|
155
|
+
|
|
156
|
+
    def _process_create(
        self,
        stmt: exp.Create,
        file_stem: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
    ) -> None:
        """Handle CREATE TABLE / CREATE VIEW statements.

        Appends a ``table``/``view`` node for the CREATE target and a
        ``defines`` edge from the file's query node to that target.
        """
        kind_expr = stmt.args.get("kind")
        if not kind_expr:
            # No CREATE kind recorded — nothing to classify.
            return

        # ``kind`` may be a plain string or an expression node; normalise to
        # an uppercase string either way.
        kind_str = kind_expr.upper() if isinstance(kind_expr, str) else str(kind_expr).upper()

        table_expr = stmt.this
        if not isinstance(table_expr, exp.Table):
            # Could be a Schema wrapping a Table
            if isinstance(table_expr, exp.Schema):
                table_expr = table_expr.this
            if not isinstance(table_expr, exp.Table):
                return

        name = self._normalize_identifier(table_expr.name, self._is_quoted_identifier(table_expr))
        if not name:
            return

        # Anything whose CREATE kind mentions VIEW (e.g. MATERIALIZED VIEW)
        # is classified as a view; everything else as a table.
        node_kind = "view" if "VIEW" in kind_str else "table"
        metadata = self._build_table_metadata(table_expr)
        metadata["dialect"] = self.dialect
        metadata["create_type"] = kind_str

        nodes.append(
            NodeResult(
                kind=node_kind,
                name=name,
                line_start=None,  # sqlglot doesn't track line numbers reliably
                metadata=metadata,
            )
        )
        edges.append(
            EdgeResult(
                source_name=file_stem,
                source_kind="query",
                target_name=name,
                target_kind=node_kind,
                relationship="defines",
                context="CREATE statement",
            )
        )
|
|
205
|
+
|
|
206
|
+
    def _extract_table_references(
        self,
        stmt: exp.Expression,
        file_stem: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
        seen_nodes: set[tuple[str, str, str | None]] | None = None,
        seen_edges: set[tuple[str, str, str]] | None = None,
    ) -> None:
        """Extract all table references from a statement.

        Emits one ``table`` node per distinct (name, schema) pair and one
        edge per distinct (source, target, context) triple. For INSERT
        statements the relationship is ``inserts_into``; otherwise
        ``references``.
        """
        if seen_nodes is None:
            # Rebuild the dedup set from what's already collected.
            seen_nodes = {(n.name, n.kind, (n.metadata or {}).get("schema")) for n in nodes}
        if seen_edges is None:
            seen_edges = set()

        # Identify the CREATE target so we don't double-count it as a reference
        create_target: str | None = None
        if isinstance(stmt, exp.Create):
            target_expr = stmt.this
            if isinstance(target_expr, exp.Schema):
                target_expr = target_expr.this
            if isinstance(target_expr, exp.Table) and target_expr.name:
                create_target = self._normalize_identifier(
                    target_expr.name,
                    self._is_quoted_identifier(target_expr),
                )

        for table in stmt.find_all(exp.Table):
            name = self._normalize_identifier(table.name, self._is_quoted_identifier(table))
            if not name:
                continue

            # Skip the CREATE target — it's already handled by _process_create
            if name == create_target:
                # Check if this is the actual CREATE target (direct child of Create/Schema)
                parent = table.parent
                if isinstance(parent, (exp.Create, exp.Schema)):
                    continue

            # Avoid duplicating nodes for the same table+schema within one file (O(1) check)
            metadata = self._build_table_metadata(table)
            table_schema = metadata.get("schema")
            node_key = (name, "table", table_schema)
            if node_key not in seen_nodes:
                seen_nodes.add(node_key)
                nodes.append(NodeResult(kind="table", name=name, metadata=metadata or None))

            # Determine context from parent expression
            context = self._get_table_context(table)

            relationship = "inserts_into" if isinstance(stmt, exp.Insert) else "references"

            # Skip duplicate edges with the same (source, target, context)
            edge_key = (file_stem, name, context)
            if edge_key in seen_edges:
                continue
            seen_edges.add(edge_key)

            edges.append(
                EdgeResult(
                    source_name=file_stem,
                    source_kind="query",
                    target_name=name,
                    target_kind="table",
                    relationship=relationship,
                    context=context,
                )
            )
|
|
274
|
+
|
|
275
|
+
    def _extract_ctes(
        self,
        stmt: exp.Expression,
        file_stem: str,
        nodes: list[NodeResult],
        edges: list[EdgeResult],
        seen_ctes: set[str] | None = None,
    ) -> None:
        """Extract CTEs as first-class nodes with their own edges.

        When a CTE references another CTE from the same statement, the edge
        uses target_kind='cte' so trace queries follow CTE chains correctly.

        Args:
            seen_ctes: Set of CTE names already added across statements in this file.
                Used to deduplicate CTEs with the same name across statements.
        """
        if seen_ctes is None:
            seen_ctes = set()

        # Collect all CTE names in this statement first
        # (needed below to distinguish CTE-to-CTE edges from table edges).
        cte_names: set[str] = set()
        for cte in stmt.find_all(exp.CTE):
            if cte.alias:
                alias_node = cte.args.get("alias")
                quoted = self._is_quoted_identifier(alias_node) if alias_node else False
                cte_names.add(self._normalize_identifier(cte.alias, quoted))

        for cte in stmt.find_all(exp.CTE):
            alias_node = cte.args.get("alias")
            cte_quoted = self._is_quoted_identifier(alias_node) if alias_node else False
            cte_name = self._normalize_identifier(cte.alias, cte_quoted) if cte.alias else None
            if not cte_name:
                continue

            # Deduplicate CTEs across statements in the same file
            if cte_name in seen_ctes:
                continue
            seen_ctes.add(cte_name)

            nodes.append(
                NodeResult(
                    kind="cte",
                    name=cte_name,
                    metadata={"parent_query": file_stem},
                )
            )

            # Find tables referenced within this CTE
            for table in cte.find_all(exp.Table):
                table_name = self._normalize_identifier(
                    table.name,
                    self._is_quoted_identifier(table),
                )
                # Skip self-references (recursive CTEs name themselves).
                if not table_name or table_name == cte_name:
                    continue

                # If the reference is to another CTE, use target_kind='cte'
                target_kind = "cte" if table_name in cte_names else "table"

                edges.append(
                    EdgeResult(
                        source_name=cte_name,
                        source_kind="cte",
                        target_name=table_name,
                        target_kind=target_kind,
                        relationship="cte_references",
                        context=self._get_table_context(table),
                    )
                )
|
|
345
|
+
|
|
346
|
+
    def _extract_column_usage(
        self,
        stmt: exp.Expression,
        file_stem: str,
        nodes: list[NodeResult],
        column_usage: list[ColumnUsageResult],
    ) -> None:
        """Extract column-level usage via sqlglot's scope analysis.

        This is where sqlglot's investment pays off — scope-aware column
        resolution that understands aliases, CTEs, and subqueries.
        Also captures wrapping transforms (CAST, COALESCE, etc.) and
        extracts WHERE clause filters as node metadata.
        """
        # Only works on SELECT-like statements
        select = stmt
        if not isinstance(stmt, (exp.Select, exp.Union)):
            select = stmt.find(exp.Select)
            if select is None:
                return

        try:
            root_scope = build_scope(select)
        except Exception:
            # Scope building is best-effort; bail out silently on failure.
            return

        if root_scope is None:
            return

        # Guard against the root scope appearing again in traverse().
        seen_scopes = set()
        for scope in [root_scope] + list(root_scope.traverse()):
            scope_id = id(scope)
            if scope_id in seen_scopes:
                continue
            seen_scopes.add(scope_id)

            # Determine scope name
            scope_name = file_stem
            scope_kind = "query"
            parent_expr = scope.expression.parent
            if scope.is_cte:
                # Extract CTE name from the expression's parent
                if isinstance(parent_expr, exp.CTE) and parent_expr.alias:
                    alias_node = parent_expr.args.get("alias")
                    quoted = self._is_quoted_identifier(alias_node) if alias_node else False
                    scope_name = self._normalize_identifier(parent_expr.alias, quoted)
                    scope_kind = "cte"
            elif isinstance(parent_expr, exp.Subquery) and parent_expr.alias:
                # Derived table (subquery in FROM/JOIN)
                alias_node = parent_expr.args.get("alias")
                quoted = self._is_quoted_identifier(alias_node) if alias_node else False
                scope_name = self._normalize_identifier(parent_expr.alias, quoted)
                scope_kind = "subquery"
                # Create a node for the subquery alias so column_usage can resolve
                nodes.append(
                    NodeResult(
                        kind="subquery",
                        name=scope_name,
                        metadata={"parent_query": file_stem},
                    )
                )
            elif isinstance(parent_expr, exp.Create):
                # Root scope inside CREATE TABLE/VIEW — use the table name
                table_expr = parent_expr.this
                if isinstance(table_expr, exp.Schema):
                    table_expr = table_expr.this
                if isinstance(table_expr, exp.Table) and table_expr.name:
                    scope_name = self._normalize_identifier(
                        table_expr.name,
                        self._is_quoted_identifier(table_expr),
                    )
            elif scope_kind == "query" and (scope_name, "query", None) not in {
                (n.name, n.kind, (n.metadata or {}).get("schema")) for n in nodes
            }:
                # Bare SELECT root scope — create a query node so column_usage resolves
                nodes.append(NodeResult(kind="query", name=scope_name, metadata={"bare_query": True}))

            # Build alias → real table name mapping
            alias_map: dict[str, str] = {}
            for source_name, source in scope.sources.items():
                if isinstance(source, exp.Table):
                    alias_map[source_name] = self._normalize_identifier(
                        source.name,
                        self._is_quoted_identifier(source),
                    )

            # When there's exactly one source and no table qualifier, infer the table
            single_table = ""
            if len(alias_map) == 1:
                single_table = next(iter(alias_map.values()))

            for col in scope.columns:
                if not isinstance(col, exp.Column):
                    continue
                col_name = self._normalize_identifier(col.name, self._is_quoted_identifier(col))
                if not col_name:
                    continue

                # Resolve alias to real table name
                table_alias = col.table or ""
                table_name = alias_map.get(table_alias, table_alias)
                if not table_name and single_table:
                    table_name = single_table

                usage_type = self._classify_column_context(col)
                transform = self._extract_transform(col)
                alias = self._extract_alias(col)

                column_usage.append(
                    ColumnUsageResult(
                        node_name=scope_name,
                        node_kind=scope_kind,
                        table_name=table_name,
                        column_name=col_name,
                        usage_type=usage_type,
                        alias=alias,
                        transform=transform,
                    )
                )

            # Handle SELECT * — emit usage for each source table
            select_expr = scope.expression
            if isinstance(select_expr, exp.Select):
                for expr in select_expr.expressions:
                    if isinstance(expr, exp.Star):
                        # Unqualified * — emit for each source table
                        for source_name, source in scope.sources.items():
                            table_name = source.name if isinstance(source, exp.Table) else source_name
                            column_usage.append(
                                ColumnUsageResult(
                                    node_name=scope_name,
                                    node_kind=scope_kind,
                                    table_name=table_name,
                                    column_name="*",
                                    usage_type="select",
                                )
                            )
                    elif isinstance(expr, exp.Column) and isinstance(expr.this, exp.Star):
                        # Qualified table.* — emit for that specific table
                        table_alias = expr.table or ""
                        table_name = alias_map.get(table_alias, table_alias)
                        column_usage.append(
                            ColumnUsageResult(
                                node_name=scope_name,
                                node_kind=scope_kind,
                                table_name=table_name,
                                column_name="*",
                                usage_type="select",
                            )
                        )

            # Extract WHERE filters as metadata on the scope's node
            self._extract_where_filters(scope, scope_name, scope_kind, nodes)
|
|
499
|
+
|
|
500
|
+
def _classify_column_context(self, col: exp.Column) -> str:
|
|
501
|
+
"""Determine how a column is used based on its AST position.
|
|
502
|
+
|
|
503
|
+
Distinguishes window function sub-clauses (PARTITION BY, window ORDER BY)
|
|
504
|
+
from regular usage types.
|
|
505
|
+
"""
|
|
506
|
+
parent = col.parent
|
|
507
|
+
|
|
508
|
+
while parent:
|
|
509
|
+
# Window function sub-clauses — check before general Order
|
|
510
|
+
if isinstance(parent, exp.Window):
|
|
511
|
+
# Determine if column is in PARTITION BY or ORDER BY within window
|
|
512
|
+
return self._classify_window_position(col, parent)
|
|
513
|
+
if isinstance(parent, exp.Where):
|
|
514
|
+
return "where"
|
|
515
|
+
if isinstance(parent, exp.Join):
|
|
516
|
+
return "join_on"
|
|
517
|
+
if isinstance(parent, exp.Group):
|
|
518
|
+
return "group_by"
|
|
519
|
+
if isinstance(parent, exp.Order):
|
|
520
|
+
# Check if this Order is inside a Window (window ORDER BY)
|
|
521
|
+
order_parent = parent.parent
|
|
522
|
+
if isinstance(order_parent, exp.Window):
|
|
523
|
+
return "window_order"
|
|
524
|
+
return "order_by"
|
|
525
|
+
if isinstance(parent, exp.Having):
|
|
526
|
+
return "having"
|
|
527
|
+
if isinstance(parent, exp.Qualify):
|
|
528
|
+
return "qualify"
|
|
529
|
+
if isinstance(parent, exp.Select):
|
|
530
|
+
return "select"
|
|
531
|
+
parent = parent.parent
|
|
532
|
+
|
|
533
|
+
return "unknown"
|
|
534
|
+
|
|
535
|
+
    def _classify_window_position(self, col: exp.Column, window: exp.Window) -> str:
        """Classify a column's position within a window function.

        Returns 'window_order' if the column sits inside the window's ORDER BY,
        'partition_by' if it appears in the PARTITION BY list, otherwise
        'select' (the column is an argument to the windowed aggregate).
        """
        # Walk from column up to the window, checking if we pass through
        # partition_by or order clause
        parent = col.parent
        while parent and parent is not window:
            if isinstance(parent, exp.Order):
                return "window_order"
            parent = parent.parent

        # Check if column is in the partition_by list
        partition_by = window.args.get("partition_by")
        if partition_by:
            for partition_col in partition_by:
                # walk() yields the expression and all descendants, so this
                # catches columns nested inside partition expressions too.
                if col in partition_col.walk():
                    return "partition_by"

        return "select"  # fallback — column is in the aggregate part of the window
|
|
553
|
+
|
|
554
|
+
    def _extract_transform(self, col: exp.Column) -> str | None:
        """Extract the wrapping transform expression around a column.

        Walks up from the Column node to find wrapping functions like
        CAST, COALESCE, IF, CASE, arithmetic, etc. Returns the SQL string
        of the outermost meaningful wrapper, or None if the column is bare.
        """
        # Wrapping expression types that constitute a "transform"
        transform_types = (
            exp.Cast,
            exp.TryCast,
            exp.Coalesce,
            exp.If,
            exp.Case,
            exp.Anonymous,  # function calls like NVL, IFNULL, etc.
            exp.Func,  # base class for all functions (UPPER, LOWER, etc.)
            exp.Add,
            exp.Sub,
            exp.Mul,
            exp.Div,
            exp.Mod,
            exp.Concat,
            exp.DPipe,  # || concat operator
            exp.Substring,
            exp.Trim,
            exp.Extract,  # EXTRACT(YEAR FROM ...)
            exp.DateAdd,
            exp.DateSub,
            exp.DateDiff,
            exp.Between,
            exp.In,
            exp.Like,
            exp.Neg,  # unary minus
        )

        # Comparison types — include as transforms but don't traverse past
        comparison_types = (
            exp.EQ,
            exp.NEQ,
            exp.GT,
            exp.GTE,
            exp.LT,
            exp.LTE,
            exp.Is,
            exp.Not,
        )

        parent = col.parent
        outermost = None

        # Climb ancestors, remembering the widest wrapper found so far.
        while parent:
            if isinstance(parent, transform_types):
                outermost = parent
            elif isinstance(parent, comparison_types):
                outermost = parent
                break  # comparisons are the natural boundary for WHERE/JOIN
            elif isinstance(parent, (exp.And, exp.Or)):
                break  # don't capture the full AND/OR chain
            elif isinstance(
                parent,
                (
                    exp.Select,
                    exp.Where,
                    exp.Group,
                    exp.Order,
                    exp.Having,
                    exp.Join,
                    exp.From,
                    exp.Subquery,
                    exp.CTE,
                ),
            ):
                # Stop at clause boundaries
                break
            parent = parent.parent

        if outermost is None:
            return None

        try:
            sql = outermost.sql(dialect=self.dialect)
            # Skip if the transform is just the column itself
            col_sql = col.sql(dialect=self.dialect)
            if sql == col_sql:
                return None
            return sql
        except Exception:
            # Serialisation is best-effort; treat failures as "no transform".
            return None
|
|
642
|
+
|
|
643
|
+
def _extract_alias(self, col: exp.Column) -> str | None:
|
|
644
|
+
"""Extract the output alias for a column (AS name)."""
|
|
645
|
+
parent = col.parent
|
|
646
|
+
while parent:
|
|
647
|
+
if isinstance(parent, exp.Alias):
|
|
648
|
+
return parent.alias
|
|
649
|
+
if isinstance(parent, (exp.Select, exp.Where, exp.Group, exp.Order, exp.Having)):
|
|
650
|
+
break
|
|
651
|
+
parent = parent.parent
|
|
652
|
+
return None
|
|
653
|
+
|
|
654
|
+
    def _extract_where_filters(
        self,
        scope,
        scope_name: str,
        scope_kind: str,
        nodes: list[NodeResult],
    ) -> None:
        """Extract WHERE clause conditions and attach as metadata to the scope's node.

        Finds the WHERE clause in the scope expression and extracts each
        top-level condition as a string. These are stored as node metadata
        so they're searchable in the graph.
        """
        try:
            # Use .args["where"] to get only the direct WHERE, not from subqueries
            where = scope.expression.args.get("where")
        except Exception:
            return

        if not where:
            return

        filters = []
        # Split AND conditions into individual filters
        conditions = self._split_conditions(where.this)
        for cond in conditions:
            try:
                sql = cond.sql(dialect=self.dialect)
                if sql and len(sql) < 500:  # skip absurdly long conditions
                    filters.append(sql)
            except Exception:
                # Unserialisable condition — skip it, keep the rest.
                continue

        if not filters:
            return

        # Find the matching node and update its metadata
        # Try exact match first, then match by name only (handles query→table/view mapping)
        # Use enumerate to avoid O(N) nodes.index() and wrong-match-on-duplicates bug
        for idx, node in enumerate(nodes):
            if node.name == scope_name and (node.kind == scope_kind or node.kind in ("table", "view", "cte")):
                existing_meta = dict(node.metadata) if node.metadata else {}
                existing_meta["filters"] = filters
                # NodeResult is frozen, so we need to replace it
                nodes[idx] = NodeResult(
                    kind=node.kind,
                    name=node.name,
                    line_start=node.line_start,
                    line_end=node.line_end,
                    metadata=existing_meta,
                )
                return
|
|
706
|
+
|
|
707
|
+
def _split_conditions(self, expr: exp.Expression) -> list[exp.Expression]:
|
|
708
|
+
"""Split an AND chain into individual conditions."""
|
|
709
|
+
if isinstance(expr, exp.And):
|
|
710
|
+
return self._split_conditions(expr.left) + self._split_conditions(expr.right)
|
|
711
|
+
return [expr]
|
|
712
|
+
|
|
713
|
+
    def _extract_column_lineage(
        self,
        stmt: exp.Expression,
        file_stem: str,
        file_content: str,
        column_lineage: list[ColumnLineageResult],
        schema: dict | None = None,
    ) -> None:
        """Extract end-to-end column lineage using sqlglot.lineage.lineage().

        Traces each output column through CTEs and subqueries back to source tables.
        If a schema catalog is provided, it's passed to sqlglot_lineage to help
        resolve SELECT * and improve lineage accuracy.

        Note: ``file_content`` is accepted but not used in this body —
        presumably kept for interface parity; confirm against callers.
        """
        # Find the output SELECT to get column names
        select = stmt
        output_name = file_stem

        if isinstance(stmt, exp.Create):
            # Get the CREATE target name
            table_expr = stmt.this
            if isinstance(table_expr, exp.Schema):
                table_expr = table_expr.this
            if isinstance(table_expr, exp.Table) and table_expr.name:
                output_name = table_expr.name
            select = stmt.find(exp.Select)
        elif not isinstance(stmt, (exp.Select, exp.Union)):
            select = stmt.find(exp.Select)

        if select is None:
            return

        # If schema available, try qualify_columns to expand SELECT *
        qualified_stmt = stmt
        if schema:
            try:
                # Work on a copy so the original AST stays untouched.
                qualified_stmt = qualify_columns(stmt.copy(), schema=schema, dialect=self.dialect)
                # Re-find the select from the qualified version
                if isinstance(qualified_stmt, exp.Create):
                    select = qualified_stmt.find(exp.Select)
                elif isinstance(qualified_stmt, (exp.Select, exp.Union)):
                    select = qualified_stmt
                else:
                    select = qualified_stmt.find(exp.Select)
                if select is None:
                    return
            except Exception:
                pass  # fall back to unqualified

        # Get output column names from the SELECT
        # For UNION, enumerate output columns from ALL branches
        if isinstance(select, exp.Union):
            output_cols = []
            seen_cols: set[str] = set()
            for branch_select in select.find_all(exp.Select):
                for expr in branch_select.expressions:
                    col_name = None
                    if isinstance(expr, exp.Alias):
                        col_name = expr.alias
                    elif isinstance(expr, exp.Column):
                        col_name = expr.name
                    elif isinstance(expr, exp.Star):
                        col_name = "*"
                    if col_name and col_name not in seen_cols:
                        seen_cols.add(col_name)
                        output_cols.append(col_name)
        elif isinstance(select, exp.Select):
            output_cols = []
            for expr in select.expressions:
                if isinstance(expr, exp.Alias):
                    output_cols.append(expr.alias)
                elif isinstance(expr, exp.Column):
                    output_cols.append(expr.name)
                elif isinstance(expr, exp.Star):
                    # SELECT * — can't trace individual columns without schema
                    output_cols.append("*")
                else:
                    # Complex expression without alias — skip
                    continue
        else:
            return

        # Trace each output column — pass AST directly to avoid re-serializing
        for col_name in output_cols:
            if col_name == "*":
                # Can't trace SELECT * without schema catalog
                continue
            try:
                root = sqlglot_lineage(
                    col_name,
                    qualified_stmt,
                    dialect=self.dialect,
                    schema=schema,
                )
            except Exception:
                # Lineage is best-effort per column; skip failures.
                continue

            # Walk the lineage tree to build hop chains
            chains = self._walk_lineage_tree(root, [])
            for chain in chains:
                if chain:  # skip empty chains
                    column_lineage.append(
                        ColumnLineageResult(
                            output_column=col_name,
                            output_node=output_name,
                            chain=chain,
                        )
                    )
|
|
821
|
+
|
|
822
|
+
def _walk_lineage_tree(
    self,
    node,
    current_chain: list[LineageHop],
    max_depth: int = 50,
    max_chains: int = 1000,
    _chain_count: list | None = None,
) -> list[list[LineageHop]]:
    """Recursively walk a sqlglot lineage node tree into flat chains.

    Each leaf produces one complete chain from output to source.

    Args:
        node: Current lineage node.
        current_chain: Chain built so far.
        max_depth: Maximum recursion depth before treating node as leaf.
        max_chains: Maximum total chains to collect before stopping early.
        _chain_count: Mutable counter shared across recursion to track total chains.

    Returns:
        All hop chains rooted at ``node``; each chain runs output-first.
    """
    if _chain_count is None:
        _chain_count = [0]

    # Stop if depth or chain limit exceeded — treat current node as leaf
    if len(current_chain) >= max_depth or _chain_count[0] >= max_chains:
        return [current_chain] if current_chain else []

    # Extract info from this node.
    # Fix: serialize the source with the active dialect (previously omitted),
    # matching how the expression below is serialized — otherwise
    # dialect-specific syntax could render differently between the two.
    name = node.name if hasattr(node, "name") else ""
    source = node.source.sql(dialect=self.dialect) if hasattr(node, "source") and node.source else ""
    expr_str = (
        node.expression.sql(dialect=self.dialect) if hasattr(node, "expression") and node.expression else None
    )

    # Parse column and table from the node name (format: "table.column" or just "column")
    parts = name.split(".") if name else []
    hop_column = parts[-1] if parts else name
    hop_table = parts[-2] if len(parts) >= 2 else ""

    # If no table from name, try to extract from source
    if not hop_table and source:
        # Source often looks like "table AS alias" or just "table"
        source_parts = source.strip().split()
        if source_parts:
            hop_table = source_parts[0].strip('"').strip("'")

    hop = LineageHop(
        column=hop_column,
        table=hop_table,
        # Only keep an expression when it adds information beyond the bare column name.
        expression=expr_str if expr_str and expr_str != hop_column else None,
    )

    new_chain = current_chain + [hop]

    downstream = node.downstream if hasattr(node, "downstream") else []
    if not downstream:
        # Leaf node — return the completed chain
        _chain_count[0] += 1
        return [new_chain]

    # Recurse into downstream nodes, honoring the global chain budget.
    all_chains = []
    for child in downstream:
        if _chain_count[0] >= max_chains:
            break
        all_chains.extend(self._walk_lineage_tree(child, new_chain, max_depth, max_chains, _chain_count))
    return all_chains
|
|
888
|
+
|
|
889
|
+
def _extract_insert_select_mapping(
    self,
    stmt: exp.Insert,
    file_stem: str,
    column_usage: list[ColumnUsageResult],
) -> None:
    """Extract positional column mapping from INSERT...SELECT.

    When INSERT INTO target (a, b) SELECT x, y FROM source,
    maps source column x -> target column a, y -> b by position.
    Appends one ColumnUsageResult per source column found.
    """
    # Resolve the INSERT destination; a Schema node means an explicit column list.
    dest = stmt.this
    dest_cols: list[str] = []
    if isinstance(dest, exp.Schema):
        # INSERT INTO table (col1, col2) — columns are Identifier nodes
        dest_cols = [c.name for c in dest.expressions if hasattr(c, "name")]
        dest = dest.this

    if not (isinstance(dest, exp.Table) and dest.name):
        return
    dest_name = dest.name

    # The payload must be a plain SELECT to trace columns positionally.
    inner = stmt.expression
    if not isinstance(inner, exp.Select):
        return

    projections = inner.expressions
    if not projections:
        return

    # Map alias and real name → real table name from the SELECT's FROM/JOIN sources.
    resolved: dict[str, str] = {}
    for ref in inner.find_all(exp.Table):
        real = self._normalize_identifier(
            ref.name,
            self._is_quoted_identifier(ref),
        )
        if not real:
            continue
        resolved[real] = real
        if ref.alias:
            resolved[ref.alias] = real

    # Pair each SELECT projection with its destination column by position.
    for position, projection in enumerate(projections):
        dest_col = dest_cols[position] if position < len(dest_cols) else None

        # Record every source column referenced inside this projection.
        for source_col in projection.find_all(exp.Column):
            if not source_col.name:
                continue

            # Resolve a table qualifier (possibly an alias) to the real table name.
            qualifier = source_col.table or ""
            origin = resolved.get(qualifier, qualifier)

            column_usage.append(
                ColumnUsageResult(
                    node_name=file_stem,
                    node_kind="query",
                    table_name=origin or dest_name,
                    column_name=source_col.name,
                    usage_type="insert",
                    transform=self._extract_transform(source_col),
                    alias=dest_col,
                )
            )
|
|
961
|
+
|
|
962
|
+
def _normalize_identifier(self, name: str, quoted: bool = False) -> str:
|
|
963
|
+
"""Normalize an identifier based on the SQL dialect's case folding rules.
|
|
964
|
+
|
|
965
|
+
Unquoted identifiers are folded: lowercase for Postgres/Redshift/DuckDB,
|
|
966
|
+
uppercase for Snowflake/Oracle. Other dialects preserve case.
|
|
967
|
+
|
|
968
|
+
Quoted identifiers are never folded — they preserve the exact case
|
|
969
|
+
the user wrote.
|
|
970
|
+
"""
|
|
971
|
+
if not name or not self.dialect or quoted:
|
|
972
|
+
return name
|
|
973
|
+
d = self.dialect.lower()
|
|
974
|
+
if d in self._LOWERCASE_DIALECTS:
|
|
975
|
+
return name.lower()
|
|
976
|
+
if d in self._UPPERCASE_DIALECTS:
|
|
977
|
+
return name.upper()
|
|
978
|
+
return name
|
|
979
|
+
|
|
980
|
+
@staticmethod
def _is_quoted_identifier(node: exp.Expression) -> bool:
    """Report whether the identifier naming ``node`` was written with quotes.

    Works for Table (node.this is Identifier), Column (node.this is Identifier),
    CTE/Subquery aliases (via TableAlias wrapping an Identifier), etc.
    """
    inner = getattr(node, "this", None)
    # Anything whose inner node isn't an Identifier can't carry a quote flag.
    return isinstance(inner, exp.Identifier) and bool(inner.quoted)
|
|
991
|
+
|
|
992
|
+
def _build_table_metadata(self, table: exp.Table) -> dict:
    """Build metadata dict with catalog/schema from a qualified table reference.

    Catalog and schema values receive the same dialect-aware case folding as
    table/column names. Quoted identifiers keep their original case.
    """
    meta: dict = {}
    # (table attribute, args key, output key) — catalog first, then schema,
    # preserving the insertion order of the resulting dict.
    for attr, out_key in (("catalog", "catalog"), ("db", "schema")):
        value = getattr(table, attr)
        if not value:
            continue
        ident = table.args.get(attr)
        is_quoted = isinstance(ident, exp.Identifier) and bool(ident.quoted)
        meta[out_key] = self._normalize_identifier(value, is_quoted)
    return meta
|
|
1009
|
+
|
|
1010
|
+
def _get_table_context(self, table: exp.Table) -> str:
    """Determine context of a table reference from its AST position."""
    # Ancestor-type → human-readable context label. Order matters: for each
    # ancestor, the first matching type wins (mirrors the original if-chain).
    context_labels = (
        (exp.Join, "JOIN clause"),
        (exp.From, "FROM clause"),
        (exp.Subquery, "subquery"),
        (exp.Insert, "INSERT INTO"),
        (exp.Merge, "MERGE target"),
        (exp.Update, "UPDATE target"),
        (exp.Lateral, "LATERAL subquery"),
    )

    ancestor = table.parent
    while ancestor is not None:
        for node_type, label in context_labels:
            if isinstance(ancestor, node_type):
                return label
        ancestor = ancestor.parent

    # No recognized ancestor — default to a plain FROM reference.
    return "FROM clause"
|