iflow-mcp_niclasolofsson-dbt-core-mcp 1.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dbt_core_mcp/__init__.py +18 -0
- dbt_core_mcp/__main__.py +436 -0
- dbt_core_mcp/context.py +459 -0
- dbt_core_mcp/cte_generator.py +601 -0
- dbt_core_mcp/dbt/__init__.py +1 -0
- dbt_core_mcp/dbt/bridge_runner.py +1361 -0
- dbt_core_mcp/dbt/manifest.py +781 -0
- dbt_core_mcp/dbt/runner.py +67 -0
- dbt_core_mcp/dependencies.py +50 -0
- dbt_core_mcp/server.py +381 -0
- dbt_core_mcp/tools/__init__.py +77 -0
- dbt_core_mcp/tools/analyze_impact.py +78 -0
- dbt_core_mcp/tools/build_models.py +190 -0
- dbt_core_mcp/tools/demo/__init__.py +1 -0
- dbt_core_mcp/tools/demo/hello.html +267 -0
- dbt_core_mcp/tools/demo/ui_demo.py +41 -0
- dbt_core_mcp/tools/get_column_lineage.py +1988 -0
- dbt_core_mcp/tools/get_lineage.py +89 -0
- dbt_core_mcp/tools/get_project_info.py +96 -0
- dbt_core_mcp/tools/get_resource_info.py +134 -0
- dbt_core_mcp/tools/install_deps.py +102 -0
- dbt_core_mcp/tools/list_resources.py +84 -0
- dbt_core_mcp/tools/load_seeds.py +179 -0
- dbt_core_mcp/tools/query_database.py +459 -0
- dbt_core_mcp/tools/run_models.py +234 -0
- dbt_core_mcp/tools/snapshot_models.py +120 -0
- dbt_core_mcp/tools/test_models.py +238 -0
- dbt_core_mcp/utils/__init__.py +1 -0
- dbt_core_mcp/utils/env_detector.py +186 -0
- dbt_core_mcp/utils/process_check.py +130 -0
- dbt_core_mcp/utils/tool_utils.py +411 -0
- dbt_core_mcp/utils/warehouse_adapter.py +82 -0
- dbt_core_mcp/utils/warehouse_databricks.py +297 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/METADATA +784 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/RECORD +38 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/WHEEL +4 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/entry_points.txt +2 -0
- iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1988 @@
"""Get column-level lineage through SQL transformations.

This module implements the get_column_lineage tool for dbt Core MCP.
Uses sqlglot to parse SQL and trace column dependencies through CTEs and transformations.

Architecture:
- Unified SQL parsing approach for both upstream and downstream directions
- Two-stage process: (1) resolve output columns, (2) analyze lineage per column
- Direction-agnostic helpers: _prepare_model_analysis, _analyze_column_lineage
- Consistent fallback order: SQL-derived → warehouse → manifest → none
"""

import logging
from typing import Any

from fastmcp.dependencies import Depends  # type: ignore[reportAttributeAccessIssue]
from fastmcp.server.context import Context
from sqlglot import exp, parse_one
from sqlglot.errors import SqlglotError
from sqlglot.lineage import lineage
from sqlglot.optimizer.scope import build_scope

from ..context import DbtCoreServerContext
from ..dbt.manifest import ManifestLoader
from ..dependencies import get_state
from . import dbtTool

logger = logging.getLogger(__name__)


def _wrap_final_select(sql: str, column_name: str, dialect: str = "databricks") -> exp.Expression:
    """Wrap the final SELECT in a CTE to enable lineage tracing through SELECT *.

    Transforms:
        SELECT * FROM final
    Into:
        WITH __lineage_final__ AS (SELECT * FROM final)
        SELECT column_name FROM __lineage_final__

    This allows lineage() to trace specific columns through schemaless queries.

    Args:
        sql: SQL query to wrap
        column_name: Column to trace
        dialect: SQL dialect

    Returns:
        Modified AST with wrapped final SELECT
    """
    ast = parse_one(sql, dialect=dialect)

    # Get the root SELECT
    root_select = ast if isinstance(ast, exp.Select) else ast.find(exp.Select)

    if not root_select:
        return ast

    # Save the WITH clause before clearing
    with_clause = root_select.args.get("with_")

    # Copy entire root SELECT and strip its WITH
    wrapper_select = root_select.copy()
    wrapper_select.set("with_", None)

    wrapper_cte = exp.CTE(this=wrapper_select, alias=exp.TableAlias(this=exp.Identifier(this="__lineage_final__")))

    # Clear ALL args from root
    for key in list(root_select.args.keys()):
        root_select.set(key, None)

    # Add wrapper to with_clause or create new one
    if with_clause:
        with_clause.expressions.append(wrapper_cte)
    else:
        with_clause = exp.With(expressions=[wrapper_cte])

    # Rebuild root SELECT with only what we need
    root_select.set("with_", with_clause)
    root_select.set("expressions", [exp.Column(this=exp.Identifier(this=column_name))])
    root_select.set("from_", exp.From(this=exp.Table(this=exp.Identifier(this="__lineage_final__"))))

    return ast


# Export for testing
__all__ = [
    "implementation",
    "get_column_lineage",
    "_prepare_model_analysis",
    "_analyze_column_lineage",
    "_build_schema_mapping",
    "_resolve_output_columns",
    "_resolve_wildcard_column_in_table",
    "_resolve_unresolved_table_reference",
    "_extract_transformations_with_sources",
    "_extract_dependencies_from_lineage",
    "_format_lineage_response",
    "_map_dbt_adapter_to_sqlglot_dialect",
]


def _map_dbt_adapter_to_sqlglot_dialect(adapter_type: str) -> str:
    """Map dbt adapter type to sqlglot dialect.

    Args:
        adapter_type: The dbt adapter type from manifest metadata

    Returns:
        sqlglot dialect name
    """
    # Direct matches between dbt adapter and sqlglot dialect
    ADAPTER_TO_DIALECT = {
        "athena": "athena",
        "bigquery": "bigquery",
        "clickhouse": "clickhouse",
        "databricks": "databricks",
        "doris": "doris",
        "dremio": "dremio",
        "duckdb": "duckdb",
        "fabric": "fabric",
        "hive": "hive",
        "materialize": "materialize",
        "mysql": "mysql",
        "oracle": "oracle",
        "postgres": "postgres",
        "postgresql": "postgres",  # Some adapters use postgresql
        "redshift": "redshift",
        "risingwave": "risingwave",
        "singlestore": "singlestore",
        "snowflake": "snowflake",
        "spark": "spark",
        "sqlite": "sqlite",
        "starrocks": "starrocks",
        "teradata": "teradata",
        "trino": "trino",
        # Adapters that need dialect mapping
        "synapse": "tsql",  # Azure Synapse uses T-SQL
        "sqlserver": "tsql",  # SQL Server uses T-SQL
        "glue": "spark",  # AWS Glue uses Spark
        "fabricspark": "spark",  # Fabric Lakehouse uses Spark
    }

    adapter_lower = adapter_type.lower()
    dialect = ADAPTER_TO_DIALECT.get(adapter_lower)

    if dialect:
        logger.debug(f"Mapped dbt adapter '{adapter_type}' to sqlglot dialect '{dialect}'")
        return dialect

    # Fallback: use adapter type as-is (might work for some cases)
    logger.warning(f"No explicit mapping for dbt adapter '{adapter_type}', using as-is for sqlglot")
    return adapter_lower
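

# Illustrative usage sketch (not part of the original module): the adapter-to-dialect
# mapping above is a plain dict lookup with a lowercase fallback, so for example:
#
#     _map_dbt_adapter_to_sqlglot_dialect("sqlserver")  # -> "tsql"
#     _map_dbt_adapter_to_sqlglot_dialect("Glue")       # -> "spark"
#     _map_dbt_adapter_to_sqlglot_dialect("exoticdb")   # -> "exoticdb" (fallback, logs a warning)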


# ========== Unified Helpers (Direction-Agnostic) ==========
# These helpers are used by both upstream and downstream lineage analysis,
# ensuring consistent behavior regardless of direction:
#
# 1. _prepare_model_analysis(): Gets resource_info + compiled_sql + schema_mapping
# 2. _analyze_column_lineage(): Runs sqlglot lineage with consistent error handling
# 3. _resolve_output_columns(): Resolves output columns with fallback order
#
# This unified approach ensures:
# - Same SQL parsing logic for both directions
# - Consistent error handling and logging
# - No code duplication
# - Easier to maintain and test


def _build_schema_mapping(manifest: ManifestLoader, upstream_lineage: dict[str, Any]) -> dict[str, Any]:
    """Build schema mapping from upstream models for sqlglot.

    Args:
        manifest: ManifestLoader instance
        upstream_lineage: Upstream lineage dict from manifest.get_lineage()

    Returns:
        Schema mapping in format: {database: {schema: {table: {column: type}}}}
    """
    schema_mapping: dict[str, Any] = {}

    if "upstream" not in upstream_lineage:
        return schema_mapping

    for upstream_node in upstream_lineage["upstream"]:
        try:
            # Use node name, not unique_id (get_resource_info expects name)
            node_name = upstream_node.get("name")
            unique_id = upstream_node.get("unique_id")
            if not node_name or not unique_id:
                continue

            # Extract resource type from unique_id (e.g. "seed.project.name" -> "seed")
            resource_type = unique_id.split(".", 1)[0] if "." in unique_id else "model"

            node_info = manifest.get_resource_info(node_name, resource_type=resource_type, include_database_schema=True, include_compiled_sql=False)

            database = node_info.get("database", "").lower()
            schema = node_info.get("schema", "").lower()
            # For sources, use identifier; for models/seeds, use alias or name
            table = node_info.get("identifier") or node_info.get("alias") or node_info.get("name", "").lower()

            if database and schema and table:
                # Add columns with their types
                columns = node_info.get("database_columns", [])

                if not columns:
                    manifest_columns = node_info.get("columns", {})
                    if manifest_columns:
                        columns = {col_name: {"type": (col_info.get("data_type") or col_info.get("type") or "string")} for col_name, col_info in manifest_columns.items()}

                if isinstance(columns, list):
                    # List format from the warehouse: [{"col_name": ..., "type": ...}]
                    column_map = {col_info.get("col_name", "").lower(): col_info.get("type", "string").lower() for col_info in columns}
                else:
                    # Dict format: {"customer_id": {"type": "INTEGER"}}
                    column_map = {col_name.lower(): col_info.get("type", "string").lower() for col_name, col_info in columns.items()}

                if not column_map:
                    continue

                if database not in schema_mapping:
                    schema_mapping[database] = {}
                if schema not in schema_mapping[database]:
                    schema_mapping[database][schema] = {}

                schema_mapping[database][schema][table] = column_map
        except Exception as e:
            logger.warning(f"Could not load schema for upstream node {upstream_node.get('unique_id')}: {e}")
            continue

    return schema_mapping
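

# Illustrative sketch (not part of the original module): for a hypothetical upstream
# seed "raw_customers" living in analytics.staging, the mapping built above would look
# roughly like this (all names are made up for illustration):
#
#     {"analytics": {"staging": {"raw_customers": {"customer_id": "int", "first_name": "string"}}}}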


def _find_table_columns(schema_mapping: dict[str, Any], table_name: str) -> list[str]:
    """Find column names for a table in the schema mapping.

    Args:
        schema_mapping: Schema mapping in format {db: {schema: {table: {column: type}}}}
        table_name: Table name to look up

    Returns:
        List of column names for the table
    """
    table_lower = table_name.lower()
    for database_mapping in schema_mapping.values():
        for schema_mapping_for_db in database_mapping.values():
            if table_lower in schema_mapping_for_db:
                return list(schema_mapping_for_db[table_lower].keys())
    return []


def _normalize_relation_name(value: str) -> str:
    """Normalize relation names for matching."""
    return value.replace('"', "").replace("`", "").replace("[", "").replace("]", "").strip().lower()


def _add_relation_keys(lookup: dict[str, str], database: str | None, schema: str | None, identifier: str | None, unique_id: str) -> None:
    if not identifier:
        return

    if database and schema:
        lookup[_normalize_relation_name(f"{database}.{schema}.{identifier}")] = unique_id

    if schema:
        lookup[_normalize_relation_name(f"{schema}.{identifier}")] = unique_id

    lookup[_normalize_relation_name(identifier)] = unique_id


def _build_relation_lookup(manifest: ManifestLoader) -> dict[str, str]:
    """Build relation name -> unique_id lookup from the manifest."""
    lookup: dict[str, str] = {}
    manifest_dict = manifest.get_manifest_dict()

    for node in manifest_dict.get("nodes", {}).values():
        if not isinstance(node, dict):
            continue

        unique_id = node.get("unique_id")
        if not unique_id:
            continue

        relation_name = node.get("relation_name")
        if relation_name:
            lookup[_normalize_relation_name(relation_name)] = unique_id

        database = node.get("database")
        schema = node.get("schema")
        identifier = node.get("alias") or node.get("name")
        _add_relation_keys(lookup, database, schema, identifier, unique_id)

    for source in manifest_dict.get("sources", {}).values():
        if not isinstance(source, dict):
            continue

        unique_id = source.get("unique_id")
        if not unique_id:
            continue

        relation_name = source.get("relation_name")
        if relation_name:
            lookup[_normalize_relation_name(relation_name)] = unique_id

        database = source.get("database")
        schema = source.get("schema")
        identifier = source.get("identifier") or source.get("name")
        _add_relation_keys(lookup, database, schema, identifier, unique_id)

    return lookup
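

# Illustrative sketch (not part of the original module): the helpers above register each
# relation under progressively shorter normalized keys, e.g. for a hypothetical model at
# "analytics"."marts"."dim_customers":
#
#     _normalize_relation_name('"analytics"."marts"."dim_customers"')
#     # -> "analytics.marts.dim_customers"
#     # keys added by _add_relation_keys: "analytics.marts.dim_customers",
#     # "marts.dim_customers", "dim_customers" -> all map to the node's unique_id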


def _get_output_columns_from_sql(compiled_sql: str, schema_mapping: dict[str, Any], dialect: str = "databricks") -> list[str]:
    """Extract output column names from compiled SQL.

    Handles SELECT * from a single CTE or table by expanding from schema mapping
    or the CTE's projections.

    Args:
        compiled_sql: Compiled SQL string
        schema_mapping: Schema mapping for table expansion
        dialect: SQL dialect for parsing (default: databricks)

    Returns:
        List of output column names
    """
    try:
        ast = parse_one(compiled_sql, dialect=dialect)
    except Exception:
        return []

    root_scope = build_scope(ast)
    if not root_scope:
        return []

    select = root_scope.expression if isinstance(root_scope.expression, exp.Select) else root_scope.expression.find(exp.Select)
    if not select:
        return []

    projections = list(select.expressions)
    if projections and all(isinstance(p, exp.Star) for p in projections):
        if len(root_scope.selected_sources) == 1:
            _, (_, source) = next(iter(root_scope.selected_sources.items()))
            if isinstance(source, exp.Table):
                return _find_table_columns(schema_mapping, source.name)

            # CTE or subquery scope
            if hasattr(source, "expression"):
                cte_select = source.expression if isinstance(source.expression, exp.Select) else source.expression.find(exp.Select)
                if cte_select:
                    return [proj.alias_or_name for proj in cte_select.expressions if proj.alias_or_name]
        return []

    return [proj.alias_or_name for proj in projections if proj.alias_or_name]
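

# Illustrative sketch (not part of the original module): for ordinary projections the helper
# above just returns the aliases, e.g. "SELECT o.id AS order_id, o.amount FROM orders o"
# yields ["order_id", "amount"]; a bare "SELECT * FROM some_cte" is instead expanded from
# the CTE's own projections or from the schema mapping.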


def _resolve_output_columns(
    compiled_sql: str,
    schema_mapping: dict[str, Any],
    resource_info: dict[str, Any],
    dialect: str = "databricks",
) -> tuple[dict[str, Any], str]:
    """Resolve output columns for a model using resource-specific strategies.

    For Sources & Seeds (Database Tables):
        1) Warehouse columns (database_columns) - database is truth
        2) Manifest columns (schema.yml) - fallback
        3) Wildcard "*" - table not built yet, but table reference known

    For Models (SQL Transformations):
        1) SQL-derived projections ONLY - compiled SQL is truth
           (No fallbacks - if SQL parsing fails, no columns available)

    Args:
        compiled_sql: Compiled SQL string
        schema_mapping: Schema mapping for table expansion
        resource_info: Resource information dict
        dialect: SQL dialect for parsing (default: databricks)

    Returns:
        (output_columns_dict, source_label)
    """
    output_columns_dict: dict[str, Any] = {}

    # For sources AND seeds: database-first with wildcard fallback
    if resource_info.get("resource_type") in ["source", "seed"]:
        database_columns = resource_info.get("database_columns", [])
        if isinstance(database_columns, list) and database_columns:
            output_columns_dict = {col.get("col_name", ""): {} for col in database_columns if col.get("col_name")}
            if output_columns_dict:
                return output_columns_dict, "warehouse"
        elif isinstance(database_columns, dict) and database_columns:
            return {col_name: {} for col_name in database_columns.keys()}, "warehouse"

        # Fallback to manifest columns
        manifest_columns = resource_info.get("columns", {})
        output_columns_dict = {col_name: {} for col_name in manifest_columns.keys()}
        if output_columns_dict:
            return output_columns_dict, "manifest"

        # Final fallback: table not built yet, return wildcard
        return {"*": {}}, "wildcard"

    # For models: SQL parsing ONLY (compiled SQL is truth)
    output_columns = _get_output_columns_from_sql(compiled_sql, schema_mapping, dialect)
    if output_columns:
        return {col: {} for col in output_columns}, "sql"

    return {}, "none"
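

# Illustrative sketch (not part of the original module): the (columns, source_label) pair
# above records where the column list came from, for example:
#
#     ({"customer_id": {}, "order_count": {}}, "sql")   # model, parsed from compiled SQL
#     ({"id": {}, "name": {}}, "warehouse")             # source/seed, columns read from the database
#     ({"*": {}}, "wildcard")                           # source/seed known but not built yet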


def _resolve_wildcard_column_in_table(table_name: str, schema_mapping: dict[str, Any]) -> str | None:
    """Resolve wildcard (*) column to specific column name within a known table.

    When we have a resolved table but wildcard column (e.g., "column": "*", "table": "stg_customers"),
    this function attempts to find the specific column being traced in that table.

    Args:
        table_name: Name of the table (without quotes)
        schema_mapping: Schema mapping with table and column information

    Returns:
        Resolved column name if found, None otherwise
    """
    if table_name not in schema_mapping:
        return None

    table_info = schema_mapping[table_name]
    resource_type = table_info.get("resource_type", "")

    # For sources and seeds: use database-first approach
    if resource_type in ["source", "seed"]:
        # Check database columns first
        database_columns = table_info.get("database_columns", [])
        if isinstance(database_columns, list) and database_columns:
            # Return first database column as representative
            first_col = database_columns[0]
            if isinstance(first_col, dict):
                return first_col.get("col_name")
        elif isinstance(database_columns, dict) and database_columns:
            # Return first key from database columns
            return next(iter(database_columns.keys()))

        # Fallback to manifest columns
        manifest_columns = table_info.get("columns", {})
        if manifest_columns:
            return next(iter(manifest_columns.keys()))

    # For models: would use SQL parsing (handled elsewhere)
    return None


def _resolve_unresolved_table_reference(col_name: str, schema_mapping: dict[str, Any]) -> dict[str, Any] | None:
    """Resolve unresolved table references using schema mapping and database-first approach.

    When sqlglot can't resolve a table reference (returns None), this function
    attempts to find which table in the schema mapping contains the requested column,
    using database-first resolution for sources and seeds.

    Args:
        col_name: Column name to search for
        schema_mapping: Schema mapping with table and column information

    Returns:
        Dictionary with resolved table info if found, None otherwise:
        {
            "table_name": str,
            "database": str | None,
            "schema": str | None,
            "resolved_column": str | None  # If resolved from wildcard
        }
    """
    for table_name, table_info in schema_mapping.items():
        resource_type = table_info.get("resource_type", "")

        # For sources and seeds: use database-first approach
        if resource_type in ["source", "seed"]:
            # Check database columns first
            database_columns = table_info.get("database_columns", [])
            if isinstance(database_columns, list):
                for db_col in database_columns:
                    if isinstance(db_col, dict) and db_col.get("col_name", "").lower() == col_name.lower():
                        return {"table_name": f'"{table_name}"', "database": table_info.get("database"), "schema": table_info.get("schema"), "resolved_column": db_col.get("col_name")}
            elif isinstance(database_columns, dict):
                if col_name.lower() in [k.lower() for k in database_columns.keys()]:
                    # Find the actual key with proper case
                    actual_col = next(k for k in database_columns.keys() if k.lower() == col_name.lower())
                    return {"table_name": f'"{table_name}"', "database": table_info.get("database"), "schema": table_info.get("schema"), "resolved_column": actual_col}

            # Fallback to manifest columns for sources/seeds
            manifest_columns = table_info.get("columns", {})
            if col_name.lower() in [k.lower() for k in manifest_columns.keys()]:
                actual_col = next(k for k in manifest_columns.keys() if k.lower() == col_name.lower())
                return {"table_name": f'"{table_name}"', "database": table_info.get("database"), "schema": table_info.get("schema"), "resolved_column": actual_col}

    return None


def _extract_transformations_with_sources(lineage_node: Any) -> list[dict[str, Any]]:
    """Extract transformations with namespaced IDs, types, and source references.

    Uses a two-pass approach:
    1. First pass: Create all transformations and build lookup map
    2. Second pass: Extract sources by parsing expressions for CTE/table references

    Args:
        lineage_node: Result from sqlglot.lineage()

    Returns:
        List of transformation dicts with structure:
        {
            "id": "cte:name" or "table:name",  # Namespaced identifier
            "type": "cte" or "table",  # Type of transformation
            "column": "column_name",  # Column being transformed
            "expression": "...",  # Optional: how it's computed (if not pass-through)
            "sources": ["cte:other", ...]  # Upstream dependencies
        }
    """
    transform_map: dict[str, dict[str, Any]] = {}  # id -> transform dict
    cte_to_id: dict[str, str] = {}  # cte_name -> transform id (for lookup)
    nodes_with_data: list[tuple[str, Any, str]] = []  # (transform_id, node, expression)
    outer_query_sources: set[str] = set()  # Track what the outer query references
    union_branches: dict[str, list[dict[str, Any]]] = {}  # reference_node_name -> list of branch info

    # FIRST PASS: Create all transformations and build lookup map
    for node in lineage_node.walk():
        if not hasattr(node, "name") or not node.name:
            continue

        # Handle numeric nodes (UNION branches)
        if "." not in node.name:
            # Check if this is a UNION branch node (has reference_node_name attribute)
            if hasattr(node, "reference_node_name") and node.reference_node_name:
                ref_cte = node.reference_node_name

                # Extract expression and sources from this branch
                branch_info: dict[str, Any] = {}
                if hasattr(node, "expression") and node.expression:
                    expr_str = str(node.expression)
                    branch_info["expression"] = expr_str if len(expr_str) <= 200 else expr_str[:197] + "..."

                    # Extract column name (will be used for the union transformation)
                    # Expression is like "all_segments_union.customer_id AS customer_id"
                    if " AS " in expr_str:
                        branch_info["column"] = expr_str.split(" AS ")[-1].strip()

                    # Store full expression for sources extraction later
                    branch_info["_full_expr"] = expr_str

                # Group branches by their parent CTE
                union_branches.setdefault(ref_cte, []).append(branch_info)
            continue

        parts = node.name.split(".", 1)
        cte_or_table = parts[0]
        col_name = parts[1] if len(parts) > 1 else node.name

        # Special handling for wrapper CTE - extract what the outer query references
        if cte_or_table == "__lineage_final__":
            # Extract what __lineage_final__ references (could be CTE or table)
            if hasattr(node, "expression") and node.expression:
                expr_str = str(node.expression)
                # Expression will be like "final.customer_id" (CTE) or "si.gold_itemkey" (table alias)
                # Look for references in the expression
                for potential_ref in expr_str.split():
                    if "." in potential_ref:
                        # Could be "final.customer_id" or "si.gold_itemkey"
                        ref_name = potential_ref.split(".")[0].strip("(),")
                        if ref_name and not ref_name.isdigit():
                            # Try as CTE first, then as table
                            # We'll resolve the actual type in the second pass
                            outer_query_sources.add(ref_name)
            continue

        source_type = type(node.source).__name__ if hasattr(node, "source") else None

        # Get the actual CTE/table name (not the alias)
        # node.name might be "alias.column" but node.source.this has the real name
        actual_cte_or_table = cte_or_table  # default to parsed name
        if hasattr(node.source, "this") and node.source.this:
            actual_cte_or_table = str(node.source.this)

        # Create namespaced ID using the actual CTE/table name
        if source_type == "Table":
            transform_id = f"table:{actual_cte_or_table}"
            transform_type = "table"
        else:
            transform_id = f"cte:{actual_cte_or_table}"
            transform_type = "cte"

        # Track CTE name -> ID for lookup in second pass (use actual name, not alias)
        cte_to_id[actual_cte_or_table] = transform_id

        # Skip if we've already processed this transformation
        if transform_id in transform_map:
            # But save node data for sources extraction in second pass
            if hasattr(node, "expression") and node.expression:
                nodes_with_data.append((transform_id, node, str(node.expression)))
            continue

        # Build transformation
        transform: dict[str, Any] = {
            "id": transform_id,
            "type": transform_type,
            "column": col_name,
        }

        # Add expression if present and meaningful
        if hasattr(node, "expression") and node.expression:
            expr_sql = str(node.expression)
            if expr_sql and expr_sql.strip() != col_name and len(expr_sql) <= 200:
                transform["expression"] = expr_sql
            elif len(expr_sql) > 200:
                transform["expression"] = expr_sql[:197] + "..."

        transform_map[transform_id] = transform

        # Save for second pass
        if hasattr(node, "expression") and node.expression:
            expr_sql_str = str(node.expression)
            nodes_with_data.append((transform_id, node, expr_sql_str if len(expr_sql_str) <= 200 else expr_sql_str[:200]))

    # Process UNION branches in two passes
    # FIRST: Register all union CTEs in cte_to_id
    for ref_cte, branches in union_branches.items():
        if not branches:
            continue
        transform_id = f"cte:{ref_cte}"
        cte_to_id[ref_cte] = transform_id

    # SECOND: Create union transformations with sources
    for ref_cte, branches in union_branches.items():
        if not branches:
            continue

        # Create union transformation
        transform_id = f"cte:{ref_cte}"
        column = branches[0].get("column", "")  # All branches have same column

        # Build branches array with sources
        formatted_branches: list[dict[str, Any]] = []
        for branch_info in branches:
            branch_entry: dict[str, Any] = {}
            if "expression" in branch_info:
                branch_entry["expression"] = branch_info["expression"]

            # Extract sources for this branch (now all unions are in cte_to_id)
            source_ids: set[str] = set()
            full_expr = branch_info.get("_full_expr", "")
            if full_expr:
                for cte_name in cte_to_id.keys():
                    if f"{cte_name}." in full_expr:
                        source_ids.add(cte_to_id[cte_name])

            branch_entry["sources"] = sorted(list(source_ids))
            formatted_branches.append(branch_entry)

        # Create the union transformation
        union_transform: dict[str, Any] = {
            "id": transform_id,
            "type": "union",
            "column": column,
            "branches": formatted_branches,
        }

        transform_map[transform_id] = union_transform

    # SECOND PASS: Extract sources by looking for CTE references in expressions
    sources_map: dict[str, set[str]] = {}  # transform_id -> set of source ids

    for transform_id, node, expr_str in nodes_with_data:
        source_ids: set[str] = set()

        # If it's a wildcard (SELECT *), get the table from the FROM clause
        if expr_str.strip() == "*" and hasattr(node.source, "find"):
            table_node = node.source.find(exp.Table)
            if table_node and hasattr(table_node, "this"):
                table_name = str(table_node.this)
                source_ids.add(f"table:{table_name}")

        # Look for CTE/table references in the expression (e.g., "orders.order_id")
        for cte_name, cte_id in cte_to_id.items():
            # Don't mark self-reference as a source
            if cte_id != transform_id and f"{cte_name}." in expr_str:
                source_ids.add(cte_id)

        # Also check if expression uses an alias - resolve to actual CTE/table
        # e.g., expression "seg.customer_id" where seg is alias for all_segments
        if "." in expr_str and hasattr(node, "source"):
            # Extract the prefix (potential alias)
            alias_candidate = expr_str.split(".")[0].strip()

            # Try to resolve it from the source
            if hasattr(node.source, "find"):
                table_node = node.source.find(exp.Table)
                if table_node:
                    actual_name = None
                    # Get the actual table/CTE name
                    if hasattr(table_node, "this") and table_node.this:
                        actual_name = str(table_node.this)

                    # Get the alias
                    if hasattr(table_node, "alias_or_name"):
                        alias = str(table_node.alias_or_name)

                        # If expression uses this alias, resolve to actual name
                        # Include alias in output as cte:name[alias] or table:name[alias]
                        if alias == alias_candidate and actual_name and actual_name in cte_to_id:
                            base_id = cte_to_id[actual_name]
                            # Add alias notation if alias differs from actual name
                            if alias != actual_name:
                                source_ids.add(f"{base_id}[{alias}]")
                            else:
                                source_ids.add(base_id)

        # Merge sources for this transformation (handle duplicates from multiple nodes)
        sources_map.setdefault(transform_id, set()).update(source_ids)

    # Build final list with sources
    transformations: list[dict[str, Any]] = []
    for trans in transform_map.values():
        # Union types already have sources in branches, don't add top-level sources
        if trans.get("type") != "union":
            trans["sources"] = sorted(list(sources_map.get(trans["id"], set())))
        transformations.append(trans)

    # Add outer query transformation if we detected it
    if outer_query_sources:
        # Extract column name from any transformation (they all have the same column)
        column_for_query = next((t.get("column") for t in transformations), "")

        # Resolve outer query source references to namespaced IDs
        resolved_sources: list[str] = []
        for ref_name in outer_query_sources:
            # Check if it's a known CTE or table
            if ref_name in cte_to_id:
                resolved_sources.append(cte_to_id[ref_name])
            else:
                # Not in cte_to_id, so it must be a table reference
                # Use the same logic as in FIRST PASS to determine if it's a table
                resolved_sources.append(f"table:{ref_name}")

        outer_query: dict[str, Any] = {
            "id": "query",
            "type": "outer_query",
            "column": column_for_query,
            "sources": sorted(resolved_sources),
        }
        # Insert at the beginning (outer query is first in the chain)
        transformations.insert(0, outer_query)

    return transformations
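

# Illustrative sketch (not part of the original module): for a model shaped like
# "WITH base AS (SELECT id FROM raw_orders) SELECT base.id FROM base", tracing "id"
# would produce a chain of transformations roughly of the form (names made up):
#
#     [
#         {"id": "query", "type": "outer_query", "column": "id", "sources": ["cte:base"]},
#         {"id": "cte:base", "type": "cte", "column": "id", "sources": ["table:raw_orders"]},
#         {"id": "table:raw_orders", "type": "table", "column": "id", "sources": []},
#     ]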


def _extract_dependencies_from_lineage(
    lineage_node: Any,
    manifest: ManifestLoader | None,
    schema_mapping: dict[str, Any],
    depth: int | None,
    column_name: str = "",
) -> list[dict[str, Any]]:
    """Extract column dependencies from sqlglot lineage node.

    Args:
        lineage_node: Result from sqlglot.lineage()
        manifest: Optional ManifestLoader for dbt resource lookup
        schema_mapping: Schema mapping for table and column resolution
        depth: Maximum depth to traverse
        column_name: Original column being traced (used to replace wildcards)

    Returns:
        List of dependency dicts with column, table, optional dbt_resource,
        and internal CTE transformation path (via_ctes, transformations)
    """
    dependencies: list[dict[str, Any]] = []
    last_cte_column: dict[str, str] = {}  # Track last CTE column before each table

    def walk_dependencies(node: Any, depth_current: int = 0) -> None:
        """Recursively walk lineage tree."""
        if depth is not None and depth_current >= depth:
            return

        for dep in node.walk():
            # Check if this is a table reference (external dependency)
            if hasattr(dep, "source") and hasattr(dep.source, "this"):
                table_name = str(dep.source.this)
                col_name = dep.name

                # Skip our artificial wrapper CTE
                if table_name == "__lineage_final__":
                    continue

                # Track CTE columns (non-Table sources)
                source_type = type(dep.source).__name__
                if source_type != "Table":
                    # This is a CTE/Select - track the column name for the next Table dependency
                    # Extract column name from CTE-qualified references like "orders.order_id"
                    if "." in col_name and col_name != "*":
                        parts = col_name.split(".")
                        cte_col = parts[-1]
                        # Store the most recent non-wildcard column for the next table we'll encounter
                        last_cte_column["__next_table__"] = cte_col

                    # Extract from_table if the Select has a FROM clause with a table
                    if hasattr(dep.source, "find") and callable(dep.source.find):
                        from_table_node = dep.source.find(exp.Table)
                        if from_table_node and hasattr(from_table_node, "this") and from_table_node.this:  # type: ignore[reportAttributeAccessIssue]
                            from_table = str(from_table_node.this)  # type: ignore[reportAttributeAccessIssue]
                            last_cte_column["__next_table_source__"] = from_table

                    continue  # Don't add to dependencies yet

                dependency_info: dict[str, Any] = {
                    "column": col_name,
                    "table": table_name,
                }

                # Always extract db/schema metadata if available
                db = getattr(dep.source, "catalog", None)
                schema_name = getattr(dep.source, "db", None)

                if db:
                    dependency_info["database"] = str(db)
                if schema_name:
                    dependency_info["schema"] = str(schema_name)

                # This is a Table reference - add to dependencies
                # Extract just the column name from potentially qualified references
                # e.g., "customers.first_name" -> "first_name", "id" -> "id"
                final_col_name = col_name.split(".")[-1] if col_name else col_name

                # Replace wildcard with the actual column being traced
                if final_col_name == "*" and column_name:
                    final_col_name = column_name

                # Check if we tracked a CTE column that feeds this table reference
                if "__next_table__" in last_cte_column:
                    source_column = last_cte_column["__next_table__"]
                    # Use the source column as the actual column in this dependency
                    final_col_name = source_column
                    # Clear it after use
                    del last_cte_column["__next_table__"]

                # Update dependency with clean column name
                dependency_info["column"] = final_col_name

                # Try to find the corresponding dbt resource
                if manifest:
                    try:
                        matching_node = manifest.get_resource_node(table_name)
                        if not matching_node.get("multiple_matches"):
                            dependency_info["dbt_resource"] = matching_node.get("unique_id", "")
                    except Exception:
                        pass  # Resource lookup failed, continue with table name

                dependencies.append(dependency_info)

    walk_dependencies(lineage_node)

    # Deduplicate dependencies while preserving order
    seen = set()
    deduplicated = []
    for dep in dependencies:
        # Create a key for deduplication
        key = (dep.get("column"), dep.get("table"), dep.get("database"), dep.get("schema"))
        if key not in seen:
            seen.add(key)
            deduplicated.append(dep)

    return deduplicated


def _extract_cte_path(lineage_node: Any) -> dict[str, Any]:
    """Extract the CTE transformation path from a lineage node.

    Sqlglot lineage nodes have names like:
    - "final.tier" (CTE reference)
    - "aggregated.order_count" (CTE reference)
    - "customer_id" (direct column, no CTE)
    - "base.customer_id" (CTE reference)

    The part before the dot is the CTE name. We walk up the lineage to collect
    all CTEs in the transformation chain.

    Args:
        lineage_node: A single dependency node from sqlglot lineage walk

    Returns:
        Dictionary with:
        - via_ctes: List of CTE names in transformation order (root to leaf)
        - transformations: List of transformation details per step
    """
    via_ctes: list[str] = []
    transformations: list[dict[str, str]] = []
    visited: set[int] = set()  # Track visited nodes by id to prevent cycles

    # Walk up the lineage chain (using downstream references)
    current = lineage_node
    max_iterations = 100  # Hard limit to prevent infinite loops
    iteration = 0

    while current is not None and iteration < max_iterations:
        iteration += 1

        # Prevent cycles by tracking visited nodes
        node_id = id(current)
        if node_id in visited:
            break
        visited.add(node_id)

        # Extract CTE name from node name (format: "cte_name.column_name" or "column_name")
        if hasattr(current, "name") and current.name:
            node_name = current.name

            # Check if this is a CTE reference (has a dot separator)
            if "." in node_name:
                parts = node_name.split(".", 1)
                cte_name = parts[0]
                column_name = parts[1] if len(parts) > 1 else node_name

                # Only add if it's not already in the list (dedup) and looks like a CTE
                # (CTEs typically don't have schema qualifiers like "database.schema.table")
                if cte_name and cte_name not in via_ctes:
                    # Skip if it looks like a database or schema qualifier
                    # (these typically have catalog/db attributes on source)
                    is_table_ref = False
                    if hasattr(current, "source"):
                        source = current.source
                        is_table_ref = hasattr(source, "catalog") or getattr(source, "db", None) is not None

                    if not is_table_ref:
                        via_ctes.append(cte_name)

                        # Capture transformation info
                        transform_info: dict[str, str] = {
                            "cte": cte_name,
                            "column": column_name,
                        }

                        # Try to get expression if available
                        if hasattr(current, "expression") and current.expression is not None:
                            expr_sql = str(current.expression)
                            # Only include if it's not just a simple column reference
                            if expr_sql and expr_sql.strip() != column_name:
                                # Limit expression length to avoid huge outputs
                                if len(expr_sql) > 200:
                                    expr_sql = expr_sql[:197] + "..."
                                transform_info["expression"] = expr_sql

                        transformations.append(transform_info)

        # Move up the lineage chain (downstream attribute points to parent)
        current = getattr(current, "downstream", None) if hasattr(current, "downstream") else None

    return {
        "via_ctes": via_ctes,
        "transformations": transformations,
    }


def _resolve_dependency_resource(
    dependency: dict[str, Any],
    relation_lookup: dict[str, str],
) -> str | None:
    table_name = dependency.get("table")
    if not table_name:
        return None

    database = dependency.get("database")
    schema = dependency.get("schema")

    if database and schema:
        key = _normalize_relation_name(f"{database}.{schema}.{table_name}")
        if key in relation_lookup:
            return relation_lookup[key]

    if schema:
        key = _normalize_relation_name(f"{schema}.{table_name}")
        if key in relation_lookup:
            return relation_lookup[key]

    key = _normalize_relation_name(table_name)
    return relation_lookup.get(key)
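

# Illustrative sketch (not part of the original module): resolution above probes the
# most-qualified key first, e.g. for {"table": "stg_orders", "schema": "staging",
# "database": "analytics"} it tries "analytics.staging.stg_orders", then
# "staging.stg_orders", then "stg_orders" against the relation lookup (names made up).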


def _check_column_in_lineage(lineage_node: Any, source_model: str, source_column: str) -> bool:
    """Check if a source column appears in the lineage tree.

    Args:
        lineage_node: Result from sqlglot.lineage()
        source_model: Model name to look for
        source_column: Column name to look for

    Returns:
        True if the column appears in the lineage
    """
    logger.debug(f"[LINEAGE_CHECK] Looking for {source_model}.{source_column}")

    for dep in lineage_node.walk():
        if hasattr(dep, "name"):
            # dep.name can be either "column_name" or "table.column_name"
            name_parts = dep.name.lower().split(".")
            col_name = name_parts[-1]  # Get the last part (column name)
            table_name = name_parts[0] if len(name_parts) > 1 else None

            logger.debug(f"[LINEAGE_CHECK] Checking dep.name='{dep.name}' -> col='{col_name}', table='{table_name}'")

            # Check if column matches
            if col_name == source_column.lower():
                logger.debug("[LINEAGE_CHECK] Column matches! Checking table...")
                # If table name in dep.name, check it matches source_model
                if table_name and source_model.lower() in table_name:
                    logger.debug(f"[LINEAGE_CHECK] ✓ Table in dep.name matches: {table_name} contains {source_model}")
                    return True
                # If no table name in dep.name, check source attribute
                elif hasattr(dep, "source") and hasattr(dep.source, "this"):
                    source_table = str(dep.source.this).strip('"').lower()
                    logger.debug(f"[LINEAGE_CHECK] Checking dep.source.this='{source_table}'")
                    if source_model.lower() in source_table:
                        logger.debug(f"[LINEAGE_CHECK] ✓ Source attribute matches: {source_table} contains {source_model}")
                        return True
                else:
                    logger.debug("[LINEAGE_CHECK] ✗ Column matches but no table info found")

    logger.debug(f"[LINEAGE_CHECK] ✗ No match found for {source_model}.{source_column}")
    return False


async def _trace_downstream_column(
    manifest: ManifestLoader,
    model_name: str,
    column_name: str,
    depth: int | None,
    dialect: str = "databricks",
    current_depth: int = 0,
) -> list[dict[str, Any]]:
    """Recursively trace where a column is used downstream.

    Args:
        manifest: ManifestLoader instance
        model_name: Model name to start from
        column_name: Column name to trace
        depth: Maximum depth to traverse (None for unlimited)
        dialect: SQL dialect for parsing (default: databricks)
        current_depth: Current recursion depth

    Returns:
        List of downstream usage dictionaries
    """
    logger.info(f"[DOWNSTREAM] Tracing {model_name}.{column_name} at depth {current_depth}")

    if depth is not None and current_depth >= depth:
        return []

    results: list[dict[str, Any]] = []

    # Get downstream models (distance 1)
    try:
        lineage_data = manifest.get_lineage(model_name, resource_type="model", direction="downstream", depth=1)
    except Exception as e:
        logger.warning(f"Could not get downstream lineage for {model_name}: {e}")
        return []

    downstream_models = lineage_data.get("downstream", [])
    logger.info(f"[DOWNSTREAM] Found {len(downstream_models)} downstream models: {[m.get('name') for m in downstream_models]}")

    for downstream_model in downstream_models:
        # Only process models (skip tests, snapshots, etc.)
        if not downstream_model.get("unique_id", "").startswith("model."):
            continue

        # Skip CTE unit test helper models (auto-generated by CTE test generator)
        # Pattern: ends with __<6-char-hash> like customers_enriched__customer_agg__259035
        model_name_check = downstream_model.get("name", "")
        if model_name_check and len(model_name_check) > 8:
            # Check if ends with __<6 hex chars>
            suffix = model_name_check[-8:]
            if suffix[:2] == "__" and all(c in "0123456789abcdef" for c in suffix[2:]):
                logger.debug(f"Skipping CTE test model: {model_name_check}")
                continue

        try:
            # Get downstream model info (schema + SQL) using unified helper
            model_name_downstream = downstream_model["name"]
            downstream_info, compiled_sql, schema_mapping = _prepare_model_analysis(
                manifest,
                model_name_downstream,
            )

            logger.info(f"[DOWNSTREAM] {model_name_downstream}: has_sql={compiled_sql is not None}")

            # Resolve output columns using centralized logic
            output_columns_dict, output_source = _resolve_output_columns(compiled_sql, schema_mapping, downstream_info, dialect)
            if output_columns_dict:
                logger.info(f"[DOWNSTREAM] {model_name_downstream}: using {len(output_columns_dict)} columns from {output_source}")
            else:
                # Fallback: use string search if no column metadata available
                logger.debug(f"[DOWNSTREAM] {model_name_downstream}: No column metadata, using string search")
                source_column_ref_patterns = [
                    f"{model_name}.{column_name}",
                    f".{column_name}",
                    f" {column_name} ",
                    f" {column_name},",
                    f"({column_name}",
                ]
                sql_lower = compiled_sql.lower()
                found_reference = any(pattern.lower() in sql_lower for pattern in source_column_ref_patterns)

                if found_reference:
                    logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream} references {model_name}.{column_name} (string search)")
                    results.append({"model": model_name_downstream, "column": column_name, "distance": current_depth + 1})
                    further_downstream = await _trace_downstream_column(manifest, model_name_downstream, column_name, depth, dialect, current_depth + 1)
                    results.extend(further_downstream)
                continue

            # Check each output column to see if it uses our source column
            logger.debug(f"[DOWNSTREAM] {model_name_downstream}: checking {len(output_columns_dict)} output columns")
            sql_lower = compiled_sql.lower()
            source_column_ref_patterns = [
                f"{model_name}.{column_name}",
                f".{column_name}",
                f" {column_name} ",
                f" {column_name},",
                f"({column_name}",
            ]

            for output_col_name in output_columns_dict.keys():
                try:
                    # Trace this output column's lineage using unified helper
                    column_lineage_result = _analyze_column_lineage(
                        output_col_name,
                        compiled_sql,
                        schema_mapping,
                        model_name_downstream,
                        dialect,
                    )

                    logger.debug(f"[DOWNSTREAM] Check if {output_col_name} uses {model_name}.{column_name}")

                    # Check if our source column appears in the dependencies
                    if _check_column_in_lineage(column_lineage_result, model_name, column_name):
                        logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream}.{output_col_name} USES {model_name}.{column_name}")
                        # This output column uses our source column!
                        results.append({"model": model_name_downstream, "column": output_col_name, "distance": current_depth + 1})

                        # Recurse: trace this column further downstream
                        further_downstream = await _trace_downstream_column(manifest, model_name_downstream, output_col_name, depth, dialect, current_depth + 1)
                        results.extend(further_downstream)
                    else:
                        # Heuristic fallback when lineage doesn't resolve dependencies
                        if output_col_name.lower() == column_name.lower() and any(pattern.lower() in sql_lower for pattern in source_column_ref_patterns):
                            logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream}.{output_col_name} USES {model_name}.{column_name} (heuristic)")
                            results.append({"model": model_name_downstream, "column": output_col_name, "distance": current_depth + 1})

                            further_downstream = await _trace_downstream_column(manifest, model_name_downstream, output_col_name, depth, dialect, current_depth + 1)
                            results.extend(further_downstream)
                        else:
                            logger.debug(f"[DOWNSTREAM] ✗ {model_name_downstream}.{output_col_name} does NOT use {model_name}.{column_name}")

                except SqlglotError as e:
                    # Heuristic fallback when sqlglot fails
                    if output_col_name.lower() == column_name.lower() and any(pattern.lower() in sql_lower for pattern in source_column_ref_patterns):
                        logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream}.{output_col_name} USES {model_name}.{column_name} (heuristic)")
                        results.append({"model": model_name_downstream, "column": output_col_name, "distance": current_depth + 1})

                        further_downstream = await _trace_downstream_column(manifest, model_name_downstream, output_col_name, depth, dialect, current_depth + 1)
                        results.extend(further_downstream)
                    else:
                        logger.warning(f"Could not trace {model_name_downstream}.{output_col_name}: {e}")
                    continue

        except Exception as e:
            # Use get() to avoid UnboundLocalError if exception occurs before model_name_downstream is set
            model_name_for_log = downstream_model.get("name", "unknown")
            logger.warning(f"Error analyzing downstream model {model_name_for_log}: {e}")
            continue

    return results
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
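# Illustrative sketch of the heuristic fallback above (model and column names
# are made up): when sqlglot cannot resolve the dependency, a simple substring
# check against the compiled SQL decides whether the column is referenced.
#
#   sql_lower = "select o.customer_id from dim_customers as o".lower()
#   patterns = [
#       "dim_customers.customer_id",
#       ".customer_id",
#       " customer_id ",
#       " customer_id,",
#       "(customer_id",
#   ]
#   hit = any(p in sql_lower for p in patterns)  # True, matches ".customer_id"
#
# This can over-match (any table's customer_id qualifies), which is why it is
# only applied when lineage parsing fails or finds nothing.
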
def _prepare_model_analysis(
    manifest: ManifestLoader,
    model_name: str,
) -> tuple[dict[str, Any], str, dict[str, Any]]:
    """Prepare model for column lineage analysis (unified helper).

    Gets all necessary data in one call:
    - Resource info (metadata, columns)
    - Compiled SQL
    - Schema mapping for sqlglot context

    Args:
        manifest: ManifestLoader instance
        model_name: Model name to analyze

    Returns:
        Tuple of (resource_info, compiled_sql, schema_mapping)

    Raises:
        ValueError: If model not found or has no compiled SQL
    """
    # Get resource info with compiled SQL and database schema
    resource_info = manifest.get_resource_info(
        model_name,
        resource_type="model",
        include_compiled_sql=True,
        include_database_schema=True,
    )

    compiled_sql = resource_info.get("compiled_sql")
    if not compiled_sql:
        raise ValueError(f"No compiled SQL found for model '{model_name}'")

    # Build schema mapping from upstream models
    try:
        upstream_lineage = manifest.get_lineage(
            model_name,
            resource_type="model",
            direction="upstream",
            depth=1,
        )
        schema_mapping = _build_schema_mapping(manifest, upstream_lineage)
    except (ValueError, KeyError, AttributeError) as e:
        logger.warning(f"Could not build schema mapping for {model_name}: {e}")
        schema_mapping = {}

    return resource_info, compiled_sql, schema_mapping

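# Illustrative usage sketch (model and column names are hypothetical): callers
# unpack the tuple and feed it straight into the sqlglot-based analysis.
#
#   info, sql, schema = _prepare_model_analysis(manifest, "fct_orders")
#   node = _analyze_column_lineage("order_total", sql, schema, "fct_orders", "databricks")
#
# A failed schema-mapping build is non-fatal (an empty mapping is returned);
# a missing compiled SQL raises ValueError and aborts the analysis.
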
def _clean_static_union_branches(ast: exp.Expression, column_name: str, dialect: str) -> exp.Expression:
    """Remove UNION branches that contain only static literal values.

    When a UNION fails during lineage tracing due to index errors, it's often
    because one branch has all static values (e.g., '-1' as gold_itemkey) while
    another has dynamic column references. This function removes all-static branches
    to allow lineage tracing to proceed.

    Args:
        ast: The SQL AST to clean
        column_name: Column being traced (for logging)
        dialect: SQL dialect

    Returns:
        Cleaned AST with static UNION branches removed
    """
    # Find all UNION nodes (they appear as Union expressions)
    unions_found = list(ast.find_all(exp.Union))

    if not unions_found:
        return ast  # No UNIONs to clean

    logger.debug(f"Found {len(unions_found)} UNION nodes to analyze")

    for union_node in unions_found:
        # UNION in sqlglot is binary: has 'this' (left) and 'expression' (right)
        left_branch = union_node.this
        right_branch = union_node.expression

        # Check if each branch is all-static (contains only literals)
        left_is_static = _is_all_static_branch(left_branch)
        right_is_static = _is_all_static_branch(right_branch)

        logger.debug(f"UNION branch analysis: left_static={left_is_static}, right_static={right_is_static}")

        if left_is_static and not right_is_static:
            # Keep only right branch
            logger.info(f"Removing static left UNION branch for column {column_name}")
            union_node.replace(right_branch)
        elif right_is_static and not left_is_static:
            # Keep only left branch
            logger.info(f"Removing static right UNION branch for column {column_name}")
            union_node.replace(left_branch)
        elif left_is_static and right_is_static:
            # Both static - this is unusual but keep left by convention
            logger.warning(f"Both UNION branches are static for column {column_name}, keeping left")
            union_node.replace(left_branch)
        # else: both dynamic, keep the UNION as-is

    return ast

def _is_all_static_branch(branch: exp.Expression) -> bool:
    """Check if a SELECT branch contains only static literal expressions.

    Args:
        branch: A SELECT expression to analyze

    Returns:
        True if all SELECT expressions are literals, False otherwise
    """
    # Find the SELECT node in this branch
    if isinstance(branch, exp.Select):
        select_node = branch
    else:
        select_node = branch.find(exp.Select)

    if not select_node:
        return False

    # Check all expressions in the SELECT clause
    for select_expr in select_node.expressions:
        # Get the actual expression (might be wrapped in Alias)
        if isinstance(select_expr, exp.Alias):
            actual_expr = select_expr.this
        else:
            actual_expr = select_expr

        # If we find any non-literal, this branch is dynamic
        if not isinstance(actual_expr, exp.Literal):
            # Also check for NULL which is represented differently
            if not isinstance(actual_expr, exp.Null):
                return False

    # All expressions are literals
    return True

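# Illustrative sketch (hypothetical SQL, not called anywhere in the module):
# shows the two helpers above cooperating to drop an all-literal UNION branch.
def _example_static_union_cleanup_sketch() -> str:
    """Parse a made-up default-row UNION and remove its static branch."""
    sample_sql = "SELECT item_key FROM (SELECT 'unknown' AS item_key UNION ALL SELECT item_key FROM items) AS u"
    ast = parse_one(sample_sql, dialect="databricks")
    cleaned = _clean_static_union_branches(ast, "item_key", "databricks")
    # The left branch selects only a string literal, so only the dynamic
    # branch (SELECT item_key FROM items) survives inside the subquery.
    return cleaned.sql(dialect="databricks")
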
def _analyze_column_lineage(
    column_name: str,
    compiled_sql: str,
    schema_mapping: dict[str, Any],
    model_name: str,
    dialect: str = "databricks",
) -> Any:
    """Run sqlglot column lineage analysis with automatic UNION cleanup.

    Tries lineage analysis, and if it fails due to UNION structure issues (IndexError),
    automatically cleans static UNION branches and retries. This handles the common
    data warehouse pattern of UNION ALL with hardcoded default rows.

    Args:
        column_name: Column to analyze
        compiled_sql: Compiled SQL
        schema_mapping: Schema context for sqlglot
        model_name: Model name (for error messages)
        dialect: SQL dialect for parsing (default: databricks)

    Returns:
        sqlglot lineage node

    Raises:
        ValueError: If SQL parsing fails or UNION cleanup doesn't resolve issues
    """
    try:
        # Wrap the SQL to enable lineage tracing through SELECT *
        wrapped_ast = _wrap_final_select(compiled_sql, column_name, dialect)

        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                result = lineage(
                    column=column_name,
                    sql=wrapped_ast,
                    schema=schema_mapping,
                    dialect=dialect,
                )

                # Test the result by walking the tree to trigger any IndexError
                for _ in result.walk():
                    pass

                if attempt > 0:
                    logger.info(f"Lineage succeeded after {attempt} UNION cleanup attempt(s) for {model_name}.{column_name}")

                return result  # Success!

            except IndexError as e:
                if "index out of range" in str(e).lower():
                    if attempt == max_attempts - 1:
                        raise ValueError(f"UNION structure issue in {model_name}.{column_name} after {max_attempts} cleanup attempts: {e}")

                    logger.info(f"UNION index error detected for {model_name}.{column_name}, cleaning static branches (attempt {attempt + 1})")
                    # Clean the AST and retry
                    wrapped_ast = _clean_static_union_branches(wrapped_ast, column_name, dialect)
                else:
                    # Different kind of IndexError, re-raise
                    raise

    except SqlglotError as e:
        raise ValueError(f"Failed to parse SQL for column lineage: {e}\nModel: {model_name}, Column: {column_name}")
    except Exception as e:
        logger.exception("Unexpected error in column lineage analysis")
        raise ValueError(f"Column lineage analysis failed: {e}")

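# Illustrative sketch (hand-written SQL and schema, not called by the tool):
# the underlying sqlglot call that _analyze_column_lineage() wraps, including
# the tree walk that surfaces latent IndexErrors from awkward UNION shapes.
def _example_lineage_walk_sketch() -> list[str]:
    """Run sqlglot lineage on a tiny sample query and collect node names."""
    sample_sql = "WITH o AS (SELECT order_id, amount FROM orders) SELECT amount FROM o"
    sample_schema = {"orders": {"order_id": "int", "amount": "double"}}
    node = lineage(column="amount", sql=sample_sql, schema=sample_schema, dialect="databricks")
    # Walking the full tree forces every lineage step to be materialized.
    return [n.name for n in node.walk()]
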
def _trace_upstream_recursive(
    manifest: ManifestLoader,
    model_name: str,
    column_name: str,
    depth: int | None,
    dialect: str = "databricks",
    current_depth: int = 0,
    relation_lookup: dict[str, str] | None = None,
    visited: set[str] | None = None,
) -> list[dict[str, Any]]:
    """Recursively trace upstream dependencies for a column.

    Uses unified SQL parsing approach (same as downstream):
    1. Prepare model (get resource_info + compiled_sql + schema_mapping)
    2. Resolve output columns (SQL → warehouse → manifest)
    3. Run lineage per resolved column
    4. Recurse on dependencies

    Args:
        manifest: ManifestLoader instance
        model_name: Model name to analyze
        column_name: Column name to trace
        depth: Maximum depth (None for unlimited)
        dialect: SQL dialect for parsing (default: databricks)
        current_depth: Current recursion depth
        relation_lookup: FQN → unique_id mapping (built once, reused)
        visited: Set of visited unique_ids (prevents cycles)

    Returns:
        List of upstream dependencies with dbt resource mapping
    """
    # Skip wildcards
    if column_name.strip() == "*":
        return []

    # Check depth limit
    if depth is not None and current_depth >= depth:
        return []

    # Initialize tracking on first call
    if relation_lookup is None:
        relation_lookup = _build_relation_lookup(manifest)

    if visited is None:
        visited = set()

    # Prepare model for analysis (unified helper - Phase 1)
    try:
        _, compiled_sql, schema_mapping = _prepare_model_analysis(
            manifest,
            model_name,
        )
    except ValueError as e:
        logger.warning(f"Could not prepare model {model_name}: {e}")
        return []

    # Run lineage analysis (unified helper - Phase 1)
    try:
        column_lineage_result = _analyze_column_lineage(
            column_name,
            compiled_sql,
            schema_mapping,
            model_name,
            dialect,
        )
    except ValueError as e:
        logger.warning(f"Could not analyze column {model_name}.{column_name}: {e}")
        return []

    # Extract root model transformations using new format (id, type, sources)
    root_transformations = _extract_transformations_with_sources(column_lineage_result)

    # Extract dependencies and resolve to dbt resources
    dependencies = _extract_dependencies_from_lineage(
        column_lineage_result,
        manifest,
        schema_mapping,
        depth,
        column_name,
    )

    # Resolve dependency FQNs to dbt resources
    for dependency in dependencies:
        if dependency.get("dbt_resource"):
            continue

        resolved = _resolve_dependency_resource(dependency, relation_lookup)
        if resolved:
            dependency["dbt_resource"] = resolved

    # Enrich model dependencies with their internal transformations
    for dependency in dependencies:
        dbt_resource = dependency.get("dbt_resource")
        if not dbt_resource:
            continue

        node = manifest.get_node_by_unique_id(dbt_resource)
        if not node or node.get("resource_type") != "model":
            continue

        # Get the dependency model's internal transformations
        dep_model_name = node.get("name")
        dep_column = dependency.get("column", "")

        if dep_model_name and dep_column:
            try:
                # Analyze the dependency model to get its internal CTEs
                _, dep_compiled_sql, dep_schema_mapping = _prepare_model_analysis(manifest, dep_model_name)
                dep_lineage = _analyze_column_lineage(dep_column, dep_compiled_sql, dep_schema_mapping, dep_model_name, dialect)

                # Extract internal transformations using new format (id, type, sources)
                internal_transformations = _extract_transformations_with_sources(dep_lineage)

                # Replace the accumulated transformations with internal ones
                if internal_transformations:
                    dependency["transformations"] = internal_transformations
                    # Update via_ctes to match (extract CTE ids)
                    via_ctes = []
                    for t in internal_transformations:
                        if t.get("type") == "cte":
                            # Remove namespace prefix for via_ctes backward compatibility
                            cte_id = t.get("id", "")
                            if cte_id.startswith("cte:"):
                                cte_name = cte_id[4:]  # Strip "cte:" prefix
                                if cte_name not in via_ctes:
                                    via_ctes.append(cte_name)
                    dependency["via_ctes"] = via_ctes

            except Exception as e:
                logger.warning(f"Could not extract internal transformations for {dep_model_name}.{dep_column}: {e}")
                # Keep existing transformations if extraction fails

    # Store results with root model transformations
    results = list(dependencies)

    # Add root transformations metadata to each result (will be extracted by formatter)
    if root_transformations and results:
        for result in results:
            result["__root_transformations__"] = root_transformations
    elif root_transformations:
        # No dependencies found, but we have root transformations - return them anyway
        results = [{"__root_transformations__": root_transformations}]

    # Recurse on each dependency
    for i, dependency in enumerate(dependencies):
        dbt_resource = dependency.get("dbt_resource")
        if not dbt_resource or dbt_resource in visited:
            continue

        node = manifest.get_node_by_unique_id(dbt_resource)
        if not node:
            continue

        if node.get("resource_type") != "model":
            # This is a source/seed - already in results, don't recurse
            continue

        next_model = node.get("name")
        if not next_model:
            continue

        visited.add(dbt_resource)

        # Extract column name from dependency
        dep_column = dependency.get("column", "")

        # Handle wildcards: resolve upstream model's output columns
        if dep_column.endswith(".*") or dep_column.strip() == "*":
            # Use unified helper to prepare upstream model
            try:
                upstream_info, upstream_sql, upstream_schema_mapping = _prepare_model_analysis(
                    manifest,
                    next_model,
                )
            except ValueError:
                continue

            # Resolve output columns using unified logic
            output_columns_dict, _ = _resolve_output_columns(
                upstream_sql,
                upstream_schema_mapping,
                upstream_info,
                dialect,
            )

            # Determine which column to trace through the wildcard:
            # check the previous dependency for the actual column name
            if i > 0:
                prev_dep = dependencies[i - 1]
                # Try transformations first, then fall back to column field
                if prev_dep.get("transformations"):
                    column_to_trace = prev_dep["transformations"][0].get("column", column_name)
                else:
                    # Strip CTE prefix if present (e.g., "customers.customer_id" -> "customer_id")
                    prev_column = prev_dep.get("column", "")
                    column_to_trace = prev_column.split(".")[-1] if prev_column else column_name
            else:
                # No previous dependency - use original target column
                column_to_trace = column_name

            # Only trace the specific column we're looking for, not all columns
            if column_to_trace in output_columns_dict:
                results.extend(
                    _trace_upstream_recursive(
                        manifest,
                        next_model,
                        column_to_trace,
                        depth,
                        dialect,
                        current_depth + 1,
                        relation_lookup,
                        visited,
                    )
                )
            continue

        # Regular column: extract name and recurse
        next_column = dep_column.split(".")[-1] if dep_column else column_name

        results.extend(
            _trace_upstream_recursive(
                manifest,
                next_model,
                next_column,
                depth,
                dialect,
                current_depth + 1,
                relation_lookup,
                visited,
            )
        )

    return results

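# Shape sketch of a single upstream dependency entry produced above (all
# values are hypothetical; the exact keys come from
# _extract_dependencies_from_lineage plus the enrichment loop):
#
#   {
#       "column": "customer_id",
#       "table": "stg_customers",
#       "dbt_resource": "model.my_project.stg_customers",
#       "via_ctes": ["customers"],
#       "transformations": [{"id": "cte:customers", "type": "cte", ...}],
#   }
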
async def _build_downstream_usages(
    manifest: ManifestLoader,
    model_name: str,
    column_name: str,
    depth: int | None,
    dialect: str,
    current_depth: int = 0,
) -> list[dict[str, Any]]:
    """Build downstream usages with transformations (reversed order).

    For each downstream model using this column, trace how it transforms the
    column from input to output. Transformations are in reversed order
    (bottom-to-top: table input → CTEs → outer query).

    Args:
        manifest: ManifestLoader instance
        model_name: Source model name
        column_name: Source column name
        depth: Maximum depth to traverse
        dialect: SQL dialect for parsing
        current_depth: Current recursion depth

    Returns:
        List of usage dicts with transformations
    """
    if depth is not None and current_depth >= depth:
        return []

    try:
        lineage_data = manifest.get_lineage(
            model_name,
            resource_type="model",
            direction="downstream",
            depth=1,
        )
    except (ValueError, KeyError, AttributeError) as e:
        logger.warning(f"Could not get downstream lineage for {model_name}: {e}")
        return []

    downstream_models = lineage_data.get("downstream", [])
    usages = []

    for downstream_model in downstream_models:
        if not downstream_model.get("unique_id", "").startswith("model."):
            continue

        # Skip CTE unit test helper models
        model_name_downstream = downstream_model.get("name", "")
        if model_name_downstream and len(model_name_downstream) > 8:
            suffix = model_name_downstream[-8:]
            if suffix[:2] == "__" and all(c in "0123456789abcdef" for c in suffix[2:]):
                continue

        try:
            # Get downstream model's SQL and trace transformations
            downstream_info, compiled_sql, schema_mapping = _prepare_model_analysis(
                manifest,
                model_name_downstream,
            )

            # Resolve output columns to find which column(s) use our source column
            output_columns_dict, _ = _resolve_output_columns(compiled_sql, schema_mapping, downstream_info, dialect)

            # Find output columns that reference our source column
            for output_col_name in output_columns_dict:
                try:
                    # Trace transformations for this output column
                    column_lineage_result = _analyze_column_lineage(
                        column_name=output_col_name,
                        compiled_sql=compiled_sql,
                        schema_mapping=schema_mapping,
                        model_name=model_name_downstream,
                        dialect=dialect,
                    )

                    # Extract transformations (shows how downstream model builds the column)
                    transformations = _extract_transformations_with_sources(column_lineage_result)

                    # Reverse transformations for downstream (bottom-to-top: input → output).
                    # Upstream shows top-to-bottom (query → CTEs → table);
                    # downstream shows bottom-to-top (table → CTEs → query).
                    transformations = list(reversed(transformations))

                    # Extract dependencies (to find if they reference our source column)
                    deps = _extract_dependencies_from_lineage(
                        column_lineage_result,
                        manifest,
                        schema_mapping,
                        depth=1,  # Only immediate deps
                        column_name=output_col_name,
                    )

                    # Check if any dependency references our source model + column
                    uses_source_column = False
                    for dep in deps:
                        dep_table = dep.get("table", "")
                        dep_column = dep.get("column", "")
                        # Match table name (handle schema-qualified names)
                        if model_name.lower() in dep_table.lower() and dep_column.lower() == column_name.lower():
                            uses_source_column = True
                            break

                    if uses_source_column:
                        # Transformations were already extracted above.
                        # Recursively get usages of this downstream column.
                        nested_usages = await _build_downstream_usages(
                            manifest,
                            model_name_downstream,
                            output_col_name,
                            depth,
                            dialect,
                            current_depth + 1,
                        )

                        usage = {
                            "model": model_name_downstream,
                            "column": output_col_name,
                            "distance": current_depth + 1,
                            "transformations": transformations,
                        }

                        if nested_usages:
                            usage["usages"] = nested_usages

                        usages.append(usage)

                except Exception as e:
                    logger.debug(f"Could not trace {model_name_downstream}.{output_col_name}: {e}")
                    continue

        except Exception as e:
            logger.warning(f"Error analyzing downstream model {model_name_downstream}: {e}")
            continue

    return usages

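# Shape sketch of one downstream usage entry built above (values hypothetical):
#
#   {
#       "model": "fct_orders",
#       "column": "customer_id",
#       "distance": 1,
#       "transformations": [...],  # reversed: table input -> CTEs -> outer query
#       "usages": [...],           # present only when further downstream usages exist
#   }
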
def _format_lineage_response(model_name: str, column_name: str, direction: str, dependencies: list[dict[str, Any]], downstream_usage: list[dict[str, Any]] | None = None) -> dict[str, Any]:
    """Format the final lineage response.

    Args:
        model_name: Model name
        column_name: Column name
        direction: Direction of lineage
        dependencies: Upstream dependencies
        downstream_usage: Optional downstream usage info (legacy format or new usages format)

    Returns:
        Formatted response dict
    """
    result: dict[str, Any] = {
        "model": model_name,
        "column": column_name,
        "direction": direction,
    }

    # Handle downstream-only format (new unified structure with transformations)
    if direction == "downstream" and downstream_usage is not None:
        # New format: usages with transformations (no root transformations, we're the source)
        result["usages"] = downstream_usage
        return result

    # Extract root transformations from dependencies (if present)
    root_transformations = None
    if dependencies and "__root_transformations__" in dependencies[0]:
        root_transformations = dependencies[0]["__root_transformations__"]
        # Remove from all dependencies (cleanup)
        for dep in dependencies:
            dep.pop("__root_transformations__", None)

    # Add root transformations if present (upstream or both)
    if root_transformations:
        result["transformations"] = root_transformations
        # Extract via_ctes from transformations (new format with id/type)
        via_ctes = []
        for t in root_transformations:
            if t.get("type") == "cte":
                # Extract CTE name from namespaced id (e.g., "cte:orders" -> "orders")
                cte_id = t.get("id", "")
                if cte_id.startswith("cte:"):
                    cte_name = cte_id[4:]  # Strip "cte:" prefix
                    if cte_name not in via_ctes:
                        via_ctes.append(cte_name)
        if via_ctes:
            result["via_ctes"] = via_ctes

    result["dependencies"] = dependencies
    result["dependency_count"] = len(dependencies)

    if downstream_usage is not None and direction != "downstream":
        # For "both" direction: use new usages format
        result["usages"] = downstream_usage

    return result

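# Shape sketch of a formatted upstream response (values hypothetical):
#
#   {
#       "model": "fct_sales",
#       "column": "revenue",
#       "direction": "upstream",
#       "transformations": [...],        # root model transformations, if any
#       "via_ctes": ["orders", "final"],
#       "dependencies": [...],
#       "dependency_count": 3,
#   }
#
# For direction="downstream" the response carries only model/column/direction
# plus "usages"; for "both" it additionally includes "usages".
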
async def implementation(
    ctx: Context | None,
    model_name: str,
    column_name: str,
    direction: str,
    depth: int | None,
    state: DbtCoreServerContext,
    force_parse: bool = True,
) -> dict[str, Any]:
    """Implementation function for the get_column_lineage tool.

    Separated for testing purposes - tests call this directly with explicit state.
    The @tool() decorated get_column_lineage() function calls this with injected dependencies.
    """
    # Initialize state if needed
    await state.ensure_initialized(ctx, force_parse)

    # Verify manifest is available
    if state.manifest is None:
        raise RuntimeError("Manifest not initialized")

    # FIRST: Check if we have compiled code, compile ALL models if needed.
    # Test the source model to see if compilation is needed.
    resource_info = state.manifest.get_resource_info(model_name, resource_type="model", include_compiled_sql=True)
    compiled_sql = resource_info.get("compiled_sql")
    if not compiled_sql:
        logger.info("No compiled SQL found - compiling entire project")
        runner = await state.get_runner()
        # Compile ALL models (dbt compile with no selector)
        compile_result = await runner.invoke(["compile"])

        if compile_result.success:
            # Reload manifest to get compiled code
            await state.manifest.load()
            # Re-fetch the resource to get updated compiled_code
            resource_info = state.manifest.get_resource_info(
                model_name,
                resource_type="model",
                include_database_schema=False,
                include_compiled_sql=True,
            )
            compiled_sql = resource_info.get("compiled_sql")
        else:
            raise RuntimeError(f"Failed to compile project: {compile_result}")

        logger.info("Project compiled successfully")

    # We always need compiled SQL from this point on.
    if not compiled_sql:
        raise ValueError(f"No compiled SQL found for model '{model_name}'. Model may not contain SQL code.")

    # Handle multiple matches (check AFTER compilation, like get_resource_info does)
    if resource_info.get("multiple_matches"):
        raise ValueError(f"Multiple models found matching '{model_name}'. Please use unique_id: {[m['unique_id'] for m in resource_info['matches']]}")

    # Get SQL dialect from manifest metadata
    project_info = state.manifest.get_project_info()
    adapter_type = project_info.get("adapter_type", "databricks")
    dialect = _map_dbt_adapter_to_sqlglot_dialect(adapter_type)
    logger.info(f"Using SQL dialect '{dialect}' for adapter type '{adapter_type}'")

    # Validate that the requested column exists in the model's output.
    # Build schema mapping for column resolution.
    upstream_lineage = state.manifest.get_lineage(
        model_name,
        resource_type="model",
        direction="upstream",
        depth=1,
    )
    schema_mapping = _build_schema_mapping(state.manifest, upstream_lineage)

    # Resolve output columns
    output_columns, _ = _resolve_output_columns(
        compiled_sql=compiled_sql,
        schema_mapping=schema_mapping,
        resource_info=resource_info,
        dialect=dialect,
    )

    # Check if the requested column exists
    if column_name not in output_columns:
        available_columns = ", ".join(f"'{col}'" for col in sorted(output_columns.keys()))
        raise ValueError(f"Column '{column_name}' not found in output of model '{model_name}'. Available columns: {available_columns}")

    if direction == "downstream":
        # Downstream only - new format with transformations
        usages = await _build_downstream_usages(state.manifest, model_name, column_name, depth, dialect)
        return _format_lineage_response(model_name, column_name, direction, [], usages)

    if direction == "upstream":
        # Upstream only
        dependencies = _trace_upstream_recursive(state.manifest, model_name, column_name, depth, dialect)
        return _format_lineage_response(model_name, column_name, direction, dependencies, None)

    else:  # direction == "both"
        # Both upstream and downstream
        dependencies = _trace_upstream_recursive(state.manifest, model_name, column_name, depth, dialect)
        usages = await _build_downstream_usages(state.manifest, model_name, column_name, depth, dialect)
        return _format_lineage_response(model_name, column_name, direction, dependencies, usages)

@dbtTool()
async def get_column_lineage(
    ctx: Context,
    model_name: str,
    column_name: str,
    direction: str = "upstream",
    depth: int | None = None,
    state: DbtCoreServerContext = Depends(get_state),
) -> dict[str, Any]:
    """Trace column-level lineage through SQL transformations.

    Uses sqlglot to parse compiled SQL and track how columns flow through:
    - CTEs and subqueries
    - JOINs and aggregations
    - Transformations (calculations, CASE statements, etc.)
    - Window functions

    This provides detailed column-to-column dependencies that model-level
    lineage cannot capture.

    Args:
        model_name: Name or unique_id of the dbt model to analyze
        column_name: Name of the column to trace
        direction: Direction to trace lineage:
            - "upstream": Which source columns feed into this column
            - "downstream": Which downstream columns use this column
            - "both": Full bidirectional column lineage
        depth: Maximum levels to traverse (None for unlimited)
            - depth=1: Immediate column dependencies only
            - depth=2: Dependencies + their dependencies
            - None: Full dependency tree

    Returns:
        Column lineage information including:
        - Source columns this column depends on (upstream)
        - Downstream columns that depend on this column
        - Transformations and derivations
        - CTE transformation paths (via_ctes, transformations)
        - dbt resource mapping where available

        Each dependency includes:
        - column: Column name
        - table: Source table name
        - schema: Source schema (if available)
        - database: Source database (if available)
        - via_ctes: List of CTE names in transformation order
        - transformations: Transformation details per CTE step
            - cte: CTE name
            - column: Column name at this step
            - expression: SQL expression (truncated to 200 chars)

    Raises:
        ValueError: If model not found, column not found, or SQL parse fails
        RuntimeError: If sqlglot is not installed

    Examples:
        # Find which source columns feed into revenue
        get_column_lineage("fct_sales", "revenue", "upstream")

        # See what downstream models use customer_id
        get_column_lineage("dim_customers", "customer_id", "downstream")

        # Full bidirectional lineage for a column
        get_column_lineage("fct_orders", "order_total", "both")

    Note:
        Requires sqlglot package. Install with: pip install sqlglot
        The model must be compiled (run 'dbt compile' first).
    """
    return await implementation(ctx, model_name, column_name, direction, depth, state)