iflow-mcp_niclasolofsson-dbt-core-mcp 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. dbt_core_mcp/__init__.py +18 -0
  2. dbt_core_mcp/__main__.py +436 -0
  3. dbt_core_mcp/context.py +459 -0
  4. dbt_core_mcp/cte_generator.py +601 -0
  5. dbt_core_mcp/dbt/__init__.py +1 -0
  6. dbt_core_mcp/dbt/bridge_runner.py +1361 -0
  7. dbt_core_mcp/dbt/manifest.py +781 -0
  8. dbt_core_mcp/dbt/runner.py +67 -0
  9. dbt_core_mcp/dependencies.py +50 -0
  10. dbt_core_mcp/server.py +381 -0
  11. dbt_core_mcp/tools/__init__.py +77 -0
  12. dbt_core_mcp/tools/analyze_impact.py +78 -0
  13. dbt_core_mcp/tools/build_models.py +190 -0
  14. dbt_core_mcp/tools/demo/__init__.py +1 -0
  15. dbt_core_mcp/tools/demo/hello.html +267 -0
  16. dbt_core_mcp/tools/demo/ui_demo.py +41 -0
  17. dbt_core_mcp/tools/get_column_lineage.py +1988 -0
  18. dbt_core_mcp/tools/get_lineage.py +89 -0
  19. dbt_core_mcp/tools/get_project_info.py +96 -0
  20. dbt_core_mcp/tools/get_resource_info.py +134 -0
  21. dbt_core_mcp/tools/install_deps.py +102 -0
  22. dbt_core_mcp/tools/list_resources.py +84 -0
  23. dbt_core_mcp/tools/load_seeds.py +179 -0
  24. dbt_core_mcp/tools/query_database.py +459 -0
  25. dbt_core_mcp/tools/run_models.py +234 -0
  26. dbt_core_mcp/tools/snapshot_models.py +120 -0
  27. dbt_core_mcp/tools/test_models.py +238 -0
  28. dbt_core_mcp/utils/__init__.py +1 -0
  29. dbt_core_mcp/utils/env_detector.py +186 -0
  30. dbt_core_mcp/utils/process_check.py +130 -0
  31. dbt_core_mcp/utils/tool_utils.py +411 -0
  32. dbt_core_mcp/utils/warehouse_adapter.py +82 -0
  33. dbt_core_mcp/utils/warehouse_databricks.py +297 -0
  34. iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/METADATA +784 -0
  35. iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/RECORD +38 -0
  36. iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/WHEEL +4 -0
  37. iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/entry_points.txt +2 -0
  38. iflow_mcp_niclasolofsson_dbt_core_mcp-1.7.0.dist-info/licenses/LICENSE +21 -0
dbt_core_mcp/tools/get_column_lineage.py
@@ -0,0 +1,1988 @@
1
+ """Get column-level lineage through SQL transformations.
2
+
3
+ This module implements the get_column_lineage tool for dbt Core MCP.
4
+ Uses sqlglot to parse SQL and trace column dependencies through CTEs and transformations.
5
+
6
+ Architecture:
7
+ - Unified SQL parsing approach for both upstream and downstream directions
8
+ - Two-stage process: (1) resolve output columns, (2) analyze lineage per column
9
+ - Direction-agnostic helpers: _prepare_model_analysis, _analyze_column_lineage
10
+ - Consistent fallback order: SQL-derived → warehouse → manifest → none
11
+ """
12
+
13
+ import logging
14
+ from typing import Any
15
+
16
+ from fastmcp.dependencies import Depends # type: ignore[reportAttributeAccessIssue]
17
+ from fastmcp.server.context import Context
18
+ from sqlglot import exp, parse_one
19
+ from sqlglot.errors import SqlglotError
20
+ from sqlglot.lineage import lineage
21
+ from sqlglot.optimizer.scope import build_scope
22
+
23
+ from ..context import DbtCoreServerContext
24
+ from ..dbt.manifest import ManifestLoader
25
+ from ..dependencies import get_state
26
+ from . import dbtTool
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def _wrap_final_select(sql: str, column_name: str, dialect: str = "databricks") -> exp.Expression:
32
+ """Wrap the final SELECT in a CTE to enable lineage tracing through SELECT *.
33
+
34
+ Transforms:
35
+ SELECT * FROM final
36
+ Into:
37
+ WITH __lineage_final__ AS (SELECT * FROM final)
38
+ SELECT column_name FROM __lineage_final__
39
+
40
+ This allows lineage() to trace specific columns through schemaless queries.
41
+
42
+ Args:
43
+ sql: SQL query to wrap
44
+ column_name: Column to trace
45
+ dialect: SQL dialect
46
+
47
+ Returns:
48
+ Modified AST with wrapped final SELECT
49
+ """
50
+ ast = parse_one(sql, dialect=dialect)
51
+
52
+ # Get the root SELECT
53
+ root_select = ast if isinstance(ast, exp.Select) else ast.find(exp.Select)
54
+
55
+ if not root_select:
56
+ return ast
57
+
58
+ # Save the WITH clause before clearing
59
+ with_clause = root_select.args.get("with_")
60
+
61
+ # Copy entire root SELECT and strip its WITH
62
+ wrapper_select = root_select.copy()
63
+ wrapper_select.set("with_", None)
64
+
65
+ wrapper_cte = exp.CTE(this=wrapper_select, alias=exp.TableAlias(this=exp.Identifier(this="__lineage_final__")))
66
+
67
+ # Clear ALL args from root
68
+ for key in list(root_select.args.keys()):
69
+ root_select.set(key, None)
70
+
71
+ # Add wrapper to with_clause or create new one
72
+ if with_clause:
73
+ with_clause.expressions.append(wrapper_cte)
74
+ else:
75
+ with_clause = exp.With(expressions=[wrapper_cte])
76
+
77
+ # Rebuild root SELECT with only what we need
78
+ root_select.set("with_", with_clause)
79
+ root_select.set("expressions", [exp.Column(this=exp.Identifier(this=column_name))])
80
+ root_select.set("from_", exp.From(this=exp.Table(this=exp.Identifier(this="__lineage_final__"))))
81
+
82
+ return ast
83
+
84
+
85
+ # Export for testing
86
+ __all__ = [
87
+ "implementation",
88
+ "get_column_lineage",
89
+ "_prepare_model_analysis",
90
+ "_analyze_column_lineage",
91
+ "_build_schema_mapping",
92
+ "_resolve_output_columns",
93
+ "_resolve_wildcard_column_in_table",
94
+ "_resolve_unresolved_table_reference",
95
+ "_extract_transformations_with_sources",
96
+ "_extract_dependencies_from_lineage",
97
+ "_format_lineage_response",
98
+ "_map_dbt_adapter_to_sqlglot_dialect",
99
+ ]
100
+
101
+
102
+ def _map_dbt_adapter_to_sqlglot_dialect(adapter_type: str) -> str:
103
+ """Map dbt adapter type to sqlglot dialect.
104
+
105
+ Args:
106
+ adapter_type: The dbt adapter type from manifest metadata
107
+
108
+ Returns:
109
+ sqlglot dialect name
110
+ """
111
+ # Direct matches between dbt adapter and sqlglot dialect
112
+ ADAPTER_TO_DIALECT = {
113
+ "athena": "athena",
114
+ "bigquery": "bigquery",
115
+ "clickhouse": "clickhouse",
116
+ "databricks": "databricks",
117
+ "doris": "doris",
118
+ "dremio": "dremio",
119
+ "duckdb": "duckdb",
120
+ "fabric": "fabric",
121
+ "hive": "hive",
122
+ "materialize": "materialize",
123
+ "mysql": "mysql",
124
+ "oracle": "oracle",
125
+ "postgres": "postgres",
126
+ "postgresql": "postgres", # Some adapters use postgresql
127
+ "redshift": "redshift",
128
+ "risingwave": "risingwave",
129
+ "singlestore": "singlestore",
130
+ "snowflake": "snowflake",
131
+ "spark": "spark",
132
+ "sqlite": "sqlite",
133
+ "starrocks": "starrocks",
134
+ "teradata": "teradata",
135
+ "trino": "trino",
136
+ # Adapters that need dialect mapping
137
+ "synapse": "tsql", # Azure Synapse uses T-SQL
138
+ "sqlserver": "tsql", # SQL Server uses T-SQL
139
+ "glue": "spark", # AWS Glue uses Spark
140
+ "fabricspark": "spark", # Fabric Lakehouse uses Spark
141
+ }
142
+
143
+ adapter_lower = adapter_type.lower()
144
+ dialect = ADAPTER_TO_DIALECT.get(adapter_lower)
145
+
146
+ if dialect:
147
+ logger.debug(f"Mapped dbt adapter '{adapter_type}' to sqlglot dialect '{dialect}'")
148
+ return dialect
149
+
150
+ # Fallback: use adapter type as-is (might work for some cases)
151
+ logger.warning(f"No explicit mapping for dbt adapter '{adapter_type}', using as-is for sqlglot")
152
+ return adapter_lower
153
+
154
+
155
+ # ========== Unified Helpers (Direction-Agnostic) ==========
156
+ # These helpers are used by both upstream and downstream lineage analysis,
157
+ # ensuring consistent behavior regardless of direction:
158
+ #
159
+ # 1. _prepare_model_analysis(): Gets resource_info + compiled_sql + schema_mapping
160
+ # 2. _analyze_column_lineage(): Runs sqlglot lineage with consistent error handling
161
+ # 3. _resolve_output_columns(): Resolves output columns with fallback order
162
+ #
163
+ # This unified approach ensures:
164
+ # - Same SQL parsing logic for both directions
165
+ # - Consistent error handling and logging
166
+ # - No code duplication
167
+ # - Easier to maintain and test
168
+
169
+
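+ # Illustrative sketch of how these helpers compose (hypothetical model and column names,
+ # not taken from a real project):
+ #   info, sql, schemas = _prepare_model_analysis(manifest, "stg_orders")
+ #   columns, source = _resolve_output_columns(sql, schemas, info)
+ #   node = _analyze_column_lineage("order_id", sql, schemas, "stg_orders")
+ #   deps = _extract_dependencies_from_lineage(node, manifest, schemas, depth=None, column_name="order_id")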
170
+ def _build_schema_mapping(manifest: ManifestLoader, upstream_lineage: dict[str, Any]) -> dict[str, Any]:
171
+ """Build schema mapping from upstream models for sqlglot.
172
+
173
+ Args:
174
+ manifest: ManifestLoader instance
175
+ upstream_lineage: Upstream lineage dict from manifest.get_lineage()
176
+
177
+ Returns:
178
+ Schema mapping in format: {database: {schema: {table: {column: type}}}}
179
+ """
180
+ schema_mapping: dict[str, Any] = {}
181
+
182
+ if "upstream" not in upstream_lineage:
183
+ return schema_mapping
184
+
185
+ for upstream_node in upstream_lineage["upstream"]:
186
+ try:
187
+ # Use node name, not unique_id (get_resource_info expects name)
188
+ node_name = upstream_node.get("name")
189
+ unique_id = upstream_node.get("unique_id")
190
+ if not node_name or not unique_id:
191
+ continue
192
+
193
+ # Extract resource type from unique_id (e.g. "seed.project.name" -> "seed")
194
+ resource_type = unique_id.split(".", 1)[0] if "." in unique_id else "model"
195
+
196
+ node_info = manifest.get_resource_info(node_name, resource_type=resource_type, include_database_schema=True, include_compiled_sql=False)
197
+
198
+ database = node_info.get("database", "").lower()
199
+ schema = node_info.get("schema", "").lower()
200
+ # For sources, use identifier; for models/seeds, use alias or name
201
+ table = node_info.get("identifier") or node_info.get("alias") or node_info.get("name", "").lower()
202
+
203
+ if database and schema and table:
204
+ # Add columns with their types
205
+ columns = node_info.get("database_columns", [])
206
+
207
+ if not columns:
208
+ manifest_columns = node_info.get("columns", {})
209
+ if manifest_columns:
210
+ columns = {col_name: {"type": (col_info.get("data_type") or col_info.get("type") or "string")} for col_name, col_info in manifest_columns.items()}
211
+
212
+ # List format from the warehouse: [{"col_name": ..., "type": ...}, ...]
+ if isinstance(columns, list):
213
+ column_map = {col_info.get("col_name", "").lower(): col_info.get("type", "string").lower() for col_info in columns}
214
+ else:
215
+ # Dict format: {"customer_id": {"type": "INTEGER"}}
216
+ column_map = {col_name.lower(): col_info.get("type", "string").lower() for col_name, col_info in columns.items()}
217
+
218
+ if not column_map:
219
+ continue
220
+
221
+ if database not in schema_mapping:
222
+ schema_mapping[database] = {}
223
+ if schema not in schema_mapping[database]:
224
+ schema_mapping[database][schema] = {}
225
+
226
+ schema_mapping[database][schema][table] = column_map
227
+ except Exception as e:
228
+ logger.warning(f"Could not load schema for upstream node {upstream_node.get('unique_id')}: {e}")
229
+ continue
230
+
231
+ return schema_mapping
232
+
233
+
234
+ def _find_table_columns(schema_mapping: dict[str, Any], table_name: str) -> list[str]:
235
+ """Find column names for a table in the schema mapping.
236
+
237
+ Args:
238
+ schema_mapping: Schema mapping in format {db: {schema: {table: {column: type}}}}
239
+ table_name: Table name to look up
240
+
241
+ Returns:
242
+ List of column names for the table
243
+ """
244
+ table_lower = table_name.lower()
245
+ for database_mapping in schema_mapping.values():
246
+ for schema_mapping_for_db in database_mapping.values():
247
+ if table_lower in schema_mapping_for_db:
248
+ return list(schema_mapping_for_db[table_lower].keys())
249
+ return []
250
+
251
+
252
+ def _normalize_relation_name(value: str) -> str:
253
+ """Normalize relation names for matching."""
254
+ return value.replace('"', "").replace("`", "").replace("[", "").replace("]", "").strip().lower()
255
+
256
+
257
+ def _add_relation_keys(lookup: dict[str, str], database: str | None, schema: str | None, identifier: str | None, unique_id: str) -> None:
258
+ if not identifier:
259
+ return
260
+
261
+ if database and schema:
262
+ lookup[_normalize_relation_name(f"{database}.{schema}.{identifier}")] = unique_id
263
+
264
+ if schema:
265
+ lookup[_normalize_relation_name(f"{schema}.{identifier}")] = unique_id
266
+
267
+ lookup[_normalize_relation_name(identifier)] = unique_id
268
+
269
+
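+ # Illustrative example (hypothetical names): for database="analytics", schema="staging",
+ # identifier="stg_orders", the lookup gains the keys "analytics.staging.stg_orders",
+ # "staging.stg_orders", and "stg_orders", all mapping to the node's unique_id.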
270
+ def _build_relation_lookup(manifest: ManifestLoader) -> dict[str, str]:
271
+ """Build relation name -> unique_id lookup from the manifest."""
272
+ lookup: dict[str, str] = {}
273
+ manifest_dict = manifest.get_manifest_dict()
274
+
275
+ for node in manifest_dict.get("nodes", {}).values():
276
+ if not isinstance(node, dict):
277
+ continue
278
+
279
+ unique_id = node.get("unique_id")
280
+ if not unique_id:
281
+ continue
282
+
283
+ relation_name = node.get("relation_name")
284
+ if relation_name:
285
+ lookup[_normalize_relation_name(relation_name)] = unique_id
286
+
287
+ database = node.get("database")
288
+ schema = node.get("schema")
289
+ identifier = node.get("alias") or node.get("name")
290
+ _add_relation_keys(lookup, database, schema, identifier, unique_id)
291
+
292
+ for source in manifest_dict.get("sources", {}).values():
293
+ if not isinstance(source, dict):
294
+ continue
295
+
296
+ unique_id = source.get("unique_id")
297
+ if not unique_id:
298
+ continue
299
+
300
+ relation_name = source.get("relation_name")
301
+ if relation_name:
302
+ lookup[_normalize_relation_name(relation_name)] = unique_id
303
+
304
+ database = source.get("database")
305
+ schema = source.get("schema")
306
+ identifier = source.get("identifier") or source.get("name")
307
+ _add_relation_keys(lookup, database, schema, identifier, unique_id)
308
+
309
+ return lookup
310
+
311
+
312
+ def _get_output_columns_from_sql(compiled_sql: str, schema_mapping: dict[str, Any], dialect: str = "databricks") -> list[str]:
313
+ """Extract output column names from compiled SQL.
314
+
315
+ Handles SELECT * from a single CTE or table by expanding from schema mapping
316
+ or the CTE's projections.
317
+
318
+ Args:
319
+ compiled_sql: Compiled SQL string
320
+ schema_mapping: Schema mapping for table expansion
321
+ dialect: SQL dialect for parsing (default: databricks)
322
+
323
+ Returns:
324
+ List of output column names
325
+ """
326
+ try:
327
+ ast = parse_one(compiled_sql, dialect=dialect)
328
+ except Exception:
329
+ return []
330
+
331
+ root_scope = build_scope(ast)
332
+ if not root_scope:
333
+ return []
334
+
335
+ select = root_scope.expression if isinstance(root_scope.expression, exp.Select) else root_scope.expression.find(exp.Select)
336
+ if not select:
337
+ return []
338
+
339
+ projections = list(select.expressions)
340
+ if projections and all(isinstance(p, exp.Star) for p in projections):
341
+ if len(root_scope.selected_sources) == 1:
342
+ _, (_, source) = next(iter(root_scope.selected_sources.items()))
343
+ if isinstance(source, exp.Table):
344
+ return _find_table_columns(schema_mapping, source.name)
345
+
346
+ # CTE or subquery scope
347
+ if hasattr(source, "expression"):
348
+ cte_select = source.expression if isinstance(source.expression, exp.Select) else source.expression.find(exp.Select)
349
+ if cte_select:
350
+ return [proj.alias_or_name for proj in cte_select.expressions if proj.alias_or_name]
351
+ return []
352
+
353
+ return [proj.alias_or_name for proj in projections if proj.alias_or_name]
354
+
355
+
356
+ def _resolve_output_columns(
357
+ compiled_sql: str,
358
+ schema_mapping: dict[str, Any],
359
+ resource_info: dict[str, Any],
360
+ dialect: str = "databricks",
361
+ ) -> tuple[dict[str, Any], str]:
362
+ """Resolve output columns for a model using resource-specific strategies.
363
+
364
+ For Sources & Seeds (Database Tables):
365
+ 1) Warehouse columns (database_columns) - database is truth
366
+ 2) Manifest columns (schema.yml) - fallback
367
+ 3) Wildcard "*" - table not built yet, but table reference known
368
+
369
+ For Models (SQL Transformations):
370
+ 1) SQL-derived projections ONLY - compiled SQL is truth
371
+ (No fallbacks - if SQL parsing fails, no columns available)
372
+
373
+ Args:
374
+ compiled_sql: Compiled SQL string
375
+ schema_mapping: Schema mapping for table expansion
376
+ resource_info: Resource information dict
377
+ dialect: SQL dialect for parsing (default: databricks)
378
+
379
+ Returns:
380
+ (output_columns_dict, source_label)
381
+ """
382
+ output_columns_dict: dict[str, Any] = {}
383
+
384
+ # For sources AND seeds: database-first with wildcard fallback
385
+ if resource_info.get("resource_type") in ["source", "seed"]:
386
+ database_columns = resource_info.get("database_columns", [])
387
+ if isinstance(database_columns, list) and database_columns:
388
+ output_columns_dict = {col.get("col_name", ""): {} for col in database_columns if col.get("col_name")}
389
+ if output_columns_dict:
390
+ return output_columns_dict, "warehouse"
391
+ elif isinstance(database_columns, dict) and database_columns:
392
+ return {col_name: {} for col_name in database_columns.keys()}, "warehouse"
393
+
394
+ # Fallback to manifest columns
395
+ manifest_columns = resource_info.get("columns", {})
396
+ output_columns_dict = {col_name: {} for col_name in manifest_columns.keys()}
397
+ if output_columns_dict:
398
+ return output_columns_dict, "manifest"
399
+
400
+ # Final fallback: table not built yet, return wildcard
401
+ return {"*": {}}, "wildcard"
402
+
403
+ # For models: SQL parsing ONLY (compiled SQL is truth)
404
+ output_columns = _get_output_columns_from_sql(compiled_sql, schema_mapping, dialect)
405
+ if output_columns:
406
+ return {col: {} for col in output_columns}, "sql"
407
+
408
+ return {}, "none"
409
+
410
+
411
+ def _resolve_wildcard_column_in_table(table_name: str, schema_mapping: dict[str, Any]) -> str | None:
412
+ """Resolve wildcard (*) column to specific column name within a known table.
413
+
414
+ When we have a resolved table but wildcard column (e.g., "column": "*", "table": "stg_customers"),
415
+ this function attempts to find the specific column being traced in that table.
416
+
417
+ Args:
418
+ table_name: Name of the table (without quotes)
419
+ schema_mapping: Schema mapping with table and column information
420
+
421
+ Returns:
422
+ Resolved column name if found, None otherwise
423
+ """
424
+ if table_name not in schema_mapping:
425
+ return None
426
+
427
+ table_info = schema_mapping[table_name]
428
+ resource_type = table_info.get("resource_type", "")
429
+
430
+ # For sources and seeds: use database-first approach
431
+ if resource_type in ["source", "seed"]:
432
+ # Check database columns first
433
+ database_columns = table_info.get("database_columns", [])
434
+ if isinstance(database_columns, list) and database_columns:
435
+ # Return first database column as representative
436
+ first_col = database_columns[0]
437
+ if isinstance(first_col, dict):
438
+ return first_col.get("col_name")
439
+ elif isinstance(database_columns, dict) and database_columns:
440
+ # Return first key from database columns
441
+ return next(iter(database_columns.keys()))
442
+
443
+ # Fallback to manifest columns
444
+ manifest_columns = table_info.get("columns", {})
445
+ if manifest_columns:
446
+ return next(iter(manifest_columns.keys()))
447
+
448
+ # For models: would use SQL parsing (handled elsewhere)
449
+ return None
450
+
451
+
452
+ def _resolve_unresolved_table_reference(col_name: str, schema_mapping: dict[str, Any]) -> dict[str, Any] | None:
453
+ """Resolve unresolved table references using schema mapping and database-first approach.
454
+
455
+ When sqlglot can't resolve a table reference (returns None), this function
456
+ attempts to find which table in the schema mapping contains the requested column,
457
+ using database-first resolution for sources and seeds.
458
+
459
+ Args:
460
+ col_name: Column name to search for
461
+ schema_mapping: Schema mapping with table and column information
462
+
463
+ Returns:
464
+ Dictionary with resolved table info if found, None otherwise:
465
+ {
466
+ "table_name": str,
467
+ "database": str | None,
468
+ "schema": str | None,
469
+ "resolved_column": str | None # If resolved from wildcard
470
+ }
471
+ """
472
+ for table_name, table_info in schema_mapping.items():
473
+ resource_type = table_info.get("resource_type", "")
474
+
475
+ # For sources and seeds: use database-first approach
476
+ if resource_type in ["source", "seed"]:
477
+ # Check database columns first
478
+ database_columns = table_info.get("database_columns", [])
479
+ if isinstance(database_columns, list):
480
+ for db_col in database_columns:
481
+ if isinstance(db_col, dict) and db_col.get("col_name", "").lower() == col_name.lower():
482
+ return {"table_name": f'"{table_name}"', "database": table_info.get("database"), "schema": table_info.get("schema"), "resolved_column": db_col.get("col_name")}
483
+ elif isinstance(database_columns, dict):
484
+ if col_name.lower() in [k.lower() for k in database_columns.keys()]:
485
+ # Find the actual key with proper case
486
+ actual_col = next(k for k in database_columns.keys() if k.lower() == col_name.lower())
487
+ return {"table_name": f'"{table_name}"', "database": table_info.get("database"), "schema": table_info.get("schema"), "resolved_column": actual_col}
488
+
489
+ # Fallback to manifest columns for sources/seeds
490
+ manifest_columns = table_info.get("columns", {})
491
+ if col_name.lower() in [k.lower() for k in manifest_columns.keys()]:
492
+ actual_col = next(k for k in manifest_columns.keys() if k.lower() == col_name.lower())
493
+ return {"table_name": f'"{table_name}"', "database": table_info.get("database"), "schema": table_info.get("schema"), "resolved_column": actual_col}
494
+
495
+ return None
496
+
497
+
498
+ def _extract_transformations_with_sources(lineage_node: Any) -> list[dict[str, Any]]:
499
+ """Extract transformations with namespaced IDs, types, and source references.
500
+
501
+ Uses a two-pass approach:
502
+ 1. First pass: Create all transformations and build lookup map
503
+ 2. Second pass: Extract sources by parsing expressions for CTE/table references
504
+
505
+ Args:
506
+ lineage_node: Result from sqlglot.lineage()
507
+
508
+ Returns:
509
+ List of transformation dicts with structure:
510
+ {
511
+ "id": "cte:name" or "table:name", # Namespaced identifier
512
+ "type": "cte" or "table", # Type of transformation
513
+ "column": "column_name", # Column being transformed
514
+ "expression": "...", # Optional: how it's computed (if not pass-through)
515
+ "sources": ["cte:other", ...] # Upstream dependencies
516
+ }
517
+ """
518
+ transform_map: dict[str, dict[str, Any]] = {} # id -> transform dict
519
+ cte_to_id: dict[str, str] = {} # cte_name -> transform id (for lookup)
520
+ nodes_with_data: list[tuple[str, Any, str]] = [] # (transform_id, node, expression)
521
+ outer_query_sources: set[str] = set() # Track what the outer query references
522
+ union_branches: dict[str, list[dict[str, Any]]] = {} # reference_node_name -> list of branch info
523
+
524
+ # FIRST PASS: Create all transformations and build lookup map
525
+ for node in lineage_node.walk():
526
+ if not hasattr(node, "name") or not node.name:
527
+ continue
528
+
529
+ # Handle numeric nodes (UNION branches)
530
+ if "." not in node.name:
531
+ # Check if this is a UNION branch node (has reference_node_name attribute)
532
+ if hasattr(node, "reference_node_name") and node.reference_node_name:
533
+ ref_cte = node.reference_node_name
534
+
535
+ # Extract expression and sources from this branch
536
+ branch_info: dict[str, Any] = {}
537
+ if hasattr(node, "expression") and node.expression:
538
+ expr_str = str(node.expression)
539
+ branch_info["expression"] = expr_str if len(expr_str) <= 200 else expr_str[:197] + "..."
540
+
541
+ # Extract column name (will be used for the union transformation)
542
+ # Expression is like "all_segments_union.customer_id AS customer_id"
543
+ if " AS " in expr_str:
544
+ branch_info["column"] = expr_str.split(" AS ")[-1].strip()
545
+
546
+ # Store full expression for sources extraction later
547
+ branch_info["_full_expr"] = expr_str
548
+
549
+ # Group branches by their parent CTE
550
+ union_branches.setdefault(ref_cte, []).append(branch_info)
551
+ continue
552
+
553
+ parts = node.name.split(".", 1)
554
+ cte_or_table = parts[0]
555
+ col_name = parts[1] if len(parts) > 1 else node.name
556
+
557
+ # Special handling for wrapper CTE - extract what the outer query references
558
+ if cte_or_table == "__lineage_final__":
559
+ # Extract what __lineage_final__ references (could be CTE or table)
560
+ if hasattr(node, "expression") and node.expression:
561
+ expr_str = str(node.expression)
562
+ # Expression will be like "final.customer_id" (CTE) or "si.gold_itemkey" (table alias)
563
+ # Look for references in the expression
564
+ for potential_ref in expr_str.split():
565
+ if "." in potential_ref:
566
+ # Could be "final.customer_id" or "si.gold_itemkey"
567
+ ref_name = potential_ref.split(".")[0].strip("(),")
568
+ if ref_name and not ref_name.isdigit():
569
+ # Try as CTE first, then as table
570
+ # We'll resolve the actual type in the second pass
571
+ outer_query_sources.add(ref_name)
572
+ continue
573
+
574
+ source_type = type(node.source).__name__ if hasattr(node, "source") else None
575
+
576
+ # Get the actual CTE/table name (not the alias)
577
+ # node.name might be "alias.column" but node.source.this has the real name
578
+ actual_cte_or_table = cte_or_table # default to parsed name
579
+ if hasattr(node.source, "this") and node.source.this:
580
+ actual_cte_or_table = str(node.source.this)
581
+
582
+ # Create namespaced ID using the actual CTE/table name
583
+ if source_type == "Table":
584
+ transform_id = f"table:{actual_cte_or_table}"
585
+ transform_type = "table"
586
+ else:
587
+ transform_id = f"cte:{actual_cte_or_table}"
588
+ transform_type = "cte"
589
+
590
+ # Track CTE name -> ID for lookup in second pass (use actual name, not alias)
591
+ cte_to_id[actual_cte_or_table] = transform_id
592
+
593
+ # Skip if we've already processed this transformation
594
+ if transform_id in transform_map:
595
+ # But save node data for sources extraction in second pass
596
+ if hasattr(node, "expression") and node.expression:
597
+ nodes_with_data.append((transform_id, node, str(node.expression)))
598
+ continue
599
+
600
+ # Build transformation
601
+ transform: dict[str, Any] = {
602
+ "id": transform_id,
603
+ "type": transform_type,
604
+ "column": col_name,
605
+ }
606
+
607
+ # Add expression if present and meaningful
608
+ if hasattr(node, "expression") and node.expression:
609
+ expr_sql = str(node.expression)
610
+ if expr_sql and expr_sql.strip() != col_name and len(expr_sql) <= 200:
611
+ transform["expression"] = expr_sql
612
+ elif len(expr_sql) > 200:
613
+ transform["expression"] = expr_sql[:197] + "..."
614
+
615
+ transform_map[transform_id] = transform
616
+
617
+ # Save for second pass
618
+ if hasattr(node, "expression") and node.expression:
619
+ expr_sql_str = str(node.expression)
620
+ nodes_with_data.append((transform_id, node, expr_sql_str if len(expr_sql_str) <= 200 else expr_sql_str[:200]))
621
+
622
+ # Process UNION branches in two passes
623
+ # FIRST: Register all union CTEs in cte_to_id
624
+ for ref_cte, branches in union_branches.items():
625
+ if not branches:
626
+ continue
627
+ transform_id = f"cte:{ref_cte}"
628
+ cte_to_id[ref_cte] = transform_id
629
+
630
+ # SECOND: Create union transformations with sources
631
+ for ref_cte, branches in union_branches.items():
632
+ if not branches:
633
+ continue
634
+
635
+ # Create union transformation
636
+ transform_id = f"cte:{ref_cte}"
637
+ column = branches[0].get("column", "") # All branches have same column
638
+
639
+ # Build branches array with sources
640
+ formatted_branches: list[dict[str, Any]] = []
641
+ for branch_info in branches:
642
+ branch_entry: dict[str, Any] = {}
643
+ if "expression" in branch_info:
644
+ branch_entry["expression"] = branch_info["expression"]
645
+
646
+ # Extract sources for this branch (now all unions are in cte_to_id)
647
+ source_ids: set[str] = set()
648
+ full_expr = branch_info.get("_full_expr", "")
649
+ if full_expr:
650
+ for cte_name in cte_to_id.keys():
651
+ if f"{cte_name}." in full_expr:
652
+ source_ids.add(cte_to_id[cte_name])
653
+
654
+ branch_entry["sources"] = sorted(list(source_ids))
655
+ formatted_branches.append(branch_entry)
656
+
657
+ # Create the union transformation
658
+ union_transform: dict[str, Any] = {
659
+ "id": transform_id,
660
+ "type": "union",
661
+ "column": column,
662
+ "branches": formatted_branches,
663
+ }
664
+
665
+ transform_map[transform_id] = union_transform
666
+
667
+ # SECOND PASS: Extract sources by looking for CTE references in expressions
668
+ sources_map: dict[str, set[str]] = {} # transform_id -> set of source ids
669
+
670
+ for transform_id, node, expr_str in nodes_with_data:
671
+ source_ids: set[str] = set()
672
+
673
+ # If it's a wildcard (SELECT *), get the table from the FROM clause
674
+ if expr_str.strip() == "*" and hasattr(node.source, "find"):
675
+ table_node = node.source.find(exp.Table)
676
+ if table_node and hasattr(table_node, "this"):
677
+ table_name = str(table_node.this)
678
+ source_ids.add(f"table:{table_name}")
679
+
680
+ # Look for CTE/table references in the expression (e.g., "orders.order_id")
681
+ for cte_name, cte_id in cte_to_id.items():
682
+ # Don't mark self-reference as a source
683
+ if cte_id != transform_id and f"{cte_name}." in expr_str:
684
+ source_ids.add(cte_id)
685
+
686
+ # Also check if expression uses an alias - resolve to actual CTE/table
687
+ # e.g., expression "seg.customer_id" where seg is alias for all_segments
688
+ if "." in expr_str and hasattr(node, "source"):
689
+ # Extract the prefix (potential alias)
690
+ alias_candidate = expr_str.split(".")[0].strip()
691
+
692
+ # Try to resolve it from the source
693
+ if hasattr(node.source, "find"):
694
+ table_node = node.source.find(exp.Table)
695
+ if table_node:
696
+ actual_name = None
697
+ # Get the actual table/CTE name
698
+ if hasattr(table_node, "this") and table_node.this:
699
+ actual_name = str(table_node.this)
700
+
701
+ # Get the alias
702
+ if hasattr(table_node, "alias_or_name"):
703
+ alias = str(table_node.alias_or_name)
704
+
705
+ # If expression uses this alias, resolve to actual name
706
+ # Include alias in output as cte:name[alias] or table:name[alias]
707
+ if alias == alias_candidate and actual_name and actual_name in cte_to_id:
708
+ base_id = cte_to_id[actual_name]
709
+ # Add alias notation if alias differs from actual name
710
+ if alias != actual_name:
711
+ source_ids.add(f"{base_id}[{alias}]")
712
+ else:
713
+ source_ids.add(base_id)
714
+
715
+ # Merge sources for this transformation (handle duplicates from multiple nodes)
716
+ sources_map.setdefault(transform_id, set()).update(source_ids)
717
+
718
+ # Build final list with sources
719
+ transformations: list[dict[str, Any]] = []
720
+ for trans in transform_map.values():
721
+ # Union types already have sources in branches, don't add top-level sources
722
+ if trans.get("type") != "union":
723
+ trans["sources"] = sorted(list(sources_map.get(trans["id"], set())))
724
+ transformations.append(trans)
725
+
726
+ # Add outer query transformation if we detected it
727
+ if outer_query_sources:
728
+ # Extract column name from any transformation (they all have the same column)
729
+ column_for_query = next((t.get("column") for t in transformations), "")
730
+
731
+ # Resolve outer query source references to namespaced IDs
732
+ resolved_sources: list[str] = []
733
+ for ref_name in outer_query_sources:
734
+ # Check if it's a known CTE or table
735
+ if ref_name in cte_to_id:
736
+ resolved_sources.append(cte_to_id[ref_name])
737
+ else:
738
+ # Not in cte_to_id, so it must be a table reference
739
+ # Use the same logic as in FIRST PASS to determine if it's a table
740
+ resolved_sources.append(f"table:{ref_name}")
741
+
742
+ outer_query: dict[str, Any] = {
743
+ "id": "query",
744
+ "type": "outer_query",
745
+ "column": column_for_query,
746
+ "sources": sorted(resolved_sources),
747
+ }
748
+ # Insert at the beginning (outer query is first in the chain)
749
+ transformations.insert(0, outer_query)
750
+
751
+ return transformations
752
+
753
+
754
+ def _extract_dependencies_from_lineage(
755
+ lineage_node: Any,
756
+ manifest: ManifestLoader | None,
757
+ schema_mapping: dict[str, Any],
758
+ depth: int | None,
759
+ column_name: str = "",
760
+ ) -> list[dict[str, Any]]:
761
+ """Extract column dependencies from sqlglot lineage node.
762
+
763
+ Args:
764
+ lineage_node: Result from sqlglot.lineage()
765
+ manifest: Optional ManifestLoader for dbt resource lookup
766
+ schema_mapping: Schema mapping for table and column resolution
767
+ depth: Maximum depth to traverse
768
+ column_name: Original column being traced (used to replace wildcards)
769
+
770
+ Returns:
771
+ List of dependency dicts with column, table, optional dbt_resource,
772
+ and internal CTE transformation path (via_ctes, transformations)
773
+ """
774
+ dependencies: list[dict[str, Any]] = []
775
+ last_cte_column: dict[str, str] = {} # Track last CTE column before each table
776
+
777
+ def walk_dependencies(node: Any, depth_current: int = 0) -> None:
778
+ """Recursively walk lineage tree."""
779
+ if depth is not None and depth_current >= depth:
780
+ return
781
+
782
+ for dep in node.walk():
783
+ # Check if this is a table reference (external dependency)
784
+ if hasattr(dep, "source") and hasattr(dep.source, "this"):
785
+ table_name = str(dep.source.this)
786
+ col_name = dep.name
787
+
788
+ # Skip our artificial wrapper CTE
789
+ if table_name == "__lineage_final__":
790
+ continue
791
+
792
+ # Track CTE columns (non-Table sources)
793
+ source_type = type(dep.source).__name__
794
+ if source_type != "Table":
795
+ # This is a CTE/Select - track the column name for the next Table dependency
796
+ # Extract column name from CTE-qualified references like "orders.order_id"
797
+ if "." in col_name and col_name != "*":
798
+ parts = col_name.split(".")
799
+ cte_col = parts[-1]
800
+ # Store the most recent non-wildcard column for the next table we'll encounter
801
+ last_cte_column["__next_table__"] = cte_col
802
+
803
+ # Extract from_table if the Select has a FROM clause with a table
804
+ if hasattr(dep.source, "find") and callable(dep.source.find):
805
+ from_table_node = dep.source.find(exp.Table)
806
+ if from_table_node and hasattr(from_table_node, "this") and from_table_node.this: # type: ignore[reportAttributeAccessIssue]
807
+ from_table = str(from_table_node.this) # type: ignore[reportAttributeAccessIssue]
808
+ last_cte_column["__next_table_source__"] = from_table
809
+
810
+ continue # Don't add to dependencies yet
811
+
812
+ dependency_info: dict[str, Any] = {
813
+ "column": col_name,
814
+ "table": table_name,
815
+ }
816
+
817
+ # Always extract db/schema metadata if available
818
+ db = getattr(dep.source, "catalog", None)
819
+ schema_name = getattr(dep.source, "db", None)
820
+
821
+ if db:
822
+ dependency_info["database"] = str(db)
823
+ if schema_name:
824
+ dependency_info["schema"] = str(schema_name)
825
+
826
+ # This is a Table reference - add to dependencies
827
+ # Extract just the column name from potentially qualified references
828
+ # e.g., "customers.first_name" -> "first_name", "id" -> "id"
829
+ final_col_name = col_name.split(".")[-1] if col_name else col_name
830
+
831
+ # Replace wildcard with the actual column being traced
832
+ if final_col_name == "*" and column_name:
833
+ final_col_name = column_name
834
+
835
+ # Check if we tracked a CTE column that feeds this table reference
836
+ if "__next_table__" in last_cte_column:
837
+ source_column = last_cte_column["__next_table__"]
838
+ # Use the source column as the actual column in this dependency
839
+ final_col_name = source_column
840
+ # Clear it after use
841
+ del last_cte_column["__next_table__"]
842
+
843
+ # Update dependency with clean column name
844
+ dependency_info["column"] = final_col_name
845
+
846
+ # Try to find the corresponding dbt resource
847
+ if manifest:
848
+ try:
849
+ matching_node = manifest.get_resource_node(table_name)
850
+ if not matching_node.get("multiple_matches"):
851
+ dependency_info["dbt_resource"] = matching_node.get("unique_id", "")
852
+ except Exception:
853
+ pass # Resource lookup failed, continue with table name
854
+
855
+ dependencies.append(dependency_info)
856
+
857
+ walk_dependencies(lineage_node)
858
+
859
+ # Deduplicate dependencies while preserving order
860
+ seen = set()
861
+ deduplicated = []
862
+ for dep in dependencies:
863
+ # Create a key for deduplication
864
+ key = (dep.get("column"), dep.get("table"), dep.get("database"), dep.get("schema"))
865
+ if key not in seen:
866
+ seen.add(key)
867
+ deduplicated.append(dep)
868
+
869
+ return deduplicated
870
+
871
+
872
+ def _extract_cte_path(lineage_node: Any) -> dict[str, Any]:
873
+ """Extract the CTE transformation path from a lineage node.
874
+
875
+ Sqlglot lineage nodes have names like:
876
+ - "final.tier" (CTE reference)
877
+ - "aggregated.order_count" (CTE reference)
878
+ - "customer_id" (direct column, no CTE)
879
+ - "base.customer_id" (CTE reference)
880
+
881
+ The part before the dot is the CTE name. We walk up the lineage to collect
882
+ all CTEs in the transformation chain.
883
+
884
+ Args:
885
+ lineage_node: A single dependency node from sqlglot lineage walk
886
+
887
+ Returns:
888
+ Dictionary with:
889
+ - via_ctes: List of CTE names in transformation order (root to leaf)
890
+ - transformations: List of transformation details per step
891
+ """
892
+ via_ctes: list[str] = []
893
+ transformations: list[dict[str, str]] = []
894
+ visited: set[int] = set() # Track visited nodes by id to prevent cycles
895
+
896
+ # Walk up the lineage chain (using downstream references)
897
+ current = lineage_node
898
+ max_iterations = 100 # Hard limit to prevent infinite loops
899
+ iteration = 0
900
+
901
+ while current is not None and iteration < max_iterations:
902
+ iteration += 1
903
+
904
+ # Prevent cycles by tracking visited nodes
905
+ node_id = id(current)
906
+ if node_id in visited:
907
+ break
908
+ visited.add(node_id)
909
+
910
+ # Extract CTE name from node name (format: "cte_name.column_name" or "column_name")
911
+ if hasattr(current, "name") and current.name:
912
+ node_name = current.name
913
+
914
+ # Check if this is a CTE reference (has a dot separator)
915
+ if "." in node_name:
916
+ parts = node_name.split(".", 1)
917
+ cte_name = parts[0]
918
+ column_name = parts[1] if len(parts) > 1 else node_name
919
+
920
+ # Only add if it's not already in the list (dedup) and looks like a CTE
921
+ # (CTEs typically don't have schema qualifiers like "database.schema.table")
922
+ if cte_name and cte_name not in via_ctes:
923
+ # Skip if it looks like a database or schema qualifier
924
+ # (these typically have catalog/db attributes on source)
925
+ is_table_ref = False
926
+ if hasattr(current, "source"):
927
+ source = current.source
928
+ is_table_ref = hasattr(source, "catalog") or getattr(source, "db", None) is not None
929
+
930
+ if not is_table_ref:
931
+ via_ctes.append(cte_name)
932
+
933
+ # Capture transformation info
934
+ transform_info: dict[str, str] = {
935
+ "cte": cte_name,
936
+ "column": column_name,
937
+ }
938
+
939
+ # Try to get expression if available
940
+ if hasattr(current, "expression") and current.expression is not None:
941
+ expr_sql = str(current.expression)
942
+ # Only include if it's not just a simple column reference
943
+ if expr_sql and expr_sql.strip() != column_name:
944
+ # Limit expression length to avoid huge outputs
945
+ if len(expr_sql) > 200:
946
+ expr_sql = expr_sql[:197] + "..."
947
+ transform_info["expression"] = expr_sql
948
+
949
+ transformations.append(transform_info)
950
+
951
+ # Move up the lineage chain (downstream attribute points to parent)
952
+ current = getattr(current, "downstream", None) if hasattr(current, "downstream") else None
953
+
954
+ return {
955
+ "via_ctes": via_ctes,
956
+ "transformations": transformations,
957
+ }
958
+
959
+
960
+ def _resolve_dependency_resource(
961
+ dependency: dict[str, Any],
962
+ relation_lookup: dict[str, str],
963
+ ) -> str | None:
964
+ table_name = dependency.get("table")
965
+ if not table_name:
966
+ return None
967
+
968
+ database = dependency.get("database")
969
+ schema = dependency.get("schema")
970
+
971
+ if database and schema:
972
+ key = _normalize_relation_name(f"{database}.{schema}.{table_name}")
973
+ if key in relation_lookup:
974
+ return relation_lookup[key]
975
+
976
+ if schema:
977
+ key = _normalize_relation_name(f"{schema}.{table_name}")
978
+ if key in relation_lookup:
979
+ return relation_lookup[key]
980
+
981
+ key = _normalize_relation_name(table_name)
982
+ return relation_lookup.get(key)
983
+
984
+
985
+ def _check_column_in_lineage(lineage_node: Any, source_model: str, source_column: str) -> bool:
986
+ """Check if a source column appears in the lineage tree.
987
+
988
+ Args:
989
+ lineage_node: Result from sqlglot.lineage()
990
+ source_model: Model name to look for
991
+ source_column: Column name to look for
992
+
993
+ Returns:
994
+ True if the column appears in the lineage
995
+ """
996
+ logger.debug(f"[LINEAGE_CHECK] Looking for {source_model}.{source_column}")
997
+
998
+ for dep in lineage_node.walk():
999
+ if hasattr(dep, "name"):
1000
+ # dep.name can be either "column_name" or "table.column_name"
1001
+ name_parts = dep.name.lower().split(".")
1002
+ col_name = name_parts[-1] # Get the last part (column name)
1003
+ table_name = name_parts[0] if len(name_parts) > 1 else None
1004
+
1005
+ logger.debug(f"[LINEAGE_CHECK] Checking dep.name='{dep.name}' -> col='{col_name}', table='{table_name}'")
1006
+
1007
+ # Check if column matches
1008
+ if col_name == source_column.lower():
1009
+ logger.debug("[LINEAGE_CHECK] Column matches! Checking table...")
1010
+ # If table name in dep.name, check it matches source_model
1011
+ if table_name and source_model.lower() in table_name:
1012
+ logger.debug(f"[LINEAGE_CHECK] ✓ Table in dep.name matches: {table_name} contains {source_model}")
1013
+ return True
1014
+ # If no table name in dep.name, check source attribute
1015
+ elif hasattr(dep, "source") and hasattr(dep.source, "this"):
1016
+ source_table = str(dep.source.this).strip('"').lower()
1017
+ logger.debug(f"[LINEAGE_CHECK] Checking dep.source.this='{source_table}'")
1018
+ if source_model.lower() in source_table:
1019
+ logger.debug(f"[LINEAGE_CHECK] ✓ Source attribute matches: {source_table} contains {source_model}")
1020
+ return True
1021
+ else:
1022
+ logger.debug("[LINEAGE_CHECK] ✗ Column matches but no table info found")
1023
+
1024
+ logger.debug(f"[LINEAGE_CHECK] ✗ No match found for {source_model}.{source_column}")
1025
+ return False
1026
+
1027
+
1028
+ async def _trace_downstream_column(
1029
+ manifest: ManifestLoader,
1030
+ model_name: str,
1031
+ column_name: str,
1032
+ depth: int | None,
1033
+ dialect: str = "databricks",
1034
+ current_depth: int = 0,
1035
+ ) -> list[dict[str, Any]]:
1036
+ """Recursively trace where a column is used downstream.
1037
+
1038
+ Args:
1039
+ manifest: ManifestLoader instance
1040
+ model_name: Model name to start from
1041
+ column_name: Column name to trace
1042
+ depth: Maximum depth to traverse (None for unlimited)
1043
+ dialect: SQL dialect for parsing (default: databricks)
1044
+ current_depth: Current recursion depth
1045
+
1046
+ Returns:
1047
+ List of downstream usage dictionaries
1048
+ """
1049
+ logger.info(f"[DOWNSTREAM] Tracing {model_name}.{column_name} at depth {current_depth}")
1050
+
1051
+ if depth is not None and current_depth >= depth:
1052
+ return []
1053
+
1054
+ results: list[dict[str, Any]] = []
1055
+
1056
+ # Get downstream models (distance 1)
1057
+ try:
1058
+ lineage_data = manifest.get_lineage(model_name, resource_type="model", direction="downstream", depth=1)
1059
+ except Exception as e:
1060
+ logger.warning(f"Could not get downstream lineage for {model_name}: {e}")
1061
+ return []
1062
+
1063
+ downstream_models = lineage_data.get("downstream", [])
1064
+ logger.info(f"[DOWNSTREAM] Found {len(downstream_models)} downstream models: {[m.get('name') for m in downstream_models]}")
1065
+
1066
+ for downstream_model in downstream_models:
1067
+ # Only process models (skip tests, snapshots, etc.)
1068
+ if not downstream_model.get("unique_id", "").startswith("model."):
1069
+ continue
1070
+
1071
+ # Skip CTE unit test helper models (auto-generated by CTE test generator)
1072
+ # Pattern: ends with __<6-char-hash> like customers_enriched__customer_agg__259035
1073
+ model_name_check = downstream_model.get("name", "")
1074
+ if model_name_check and len(model_name_check) > 8:
1075
+ # Check if ends with __<6 hex chars>
1076
+ suffix = model_name_check[-8:]
1077
+ if suffix[:2] == "__" and all(c in "0123456789abcdef" for c in suffix[2:]):
1078
+ logger.debug(f"Skipping CTE test model: {model_name_check}")
1079
+ continue
1080
+
1081
+ try:
1082
+ # Get downstream model info (schema + SQL) using unified helper
1083
+ model_name_downstream = downstream_model["name"]
1084
+ downstream_info, compiled_sql, schema_mapping = _prepare_model_analysis(
1085
+ manifest,
1086
+ model_name_downstream,
1087
+ )
1088
+
1089
+ logger.info(f"[DOWNSTREAM] {model_name_downstream}: has_sql={compiled_sql is not None}")
1090
+
1091
+ # Resolve output columns using centralized logic
1092
+ output_columns_dict, output_source = _resolve_output_columns(compiled_sql, schema_mapping, downstream_info, dialect)
1093
+ if output_columns_dict:
1094
+ logger.info(f"[DOWNSTREAM] {model_name_downstream}: using {len(output_columns_dict)} columns from {output_source}")
1095
+ else:
1096
+ # Fallback: use string search if no column metadata available
1097
+ logger.debug(f"[DOWNSTREAM] {model_name_downstream}: No column metadata, using string search")
1098
+ source_column_ref_patterns = [
1099
+ f"{model_name}.{column_name}",
1100
+ f".{column_name}",
1101
+ f" {column_name} ",
1102
+ f" {column_name},",
1103
+ f"({column_name}",
1104
+ ]
1105
+ sql_lower = compiled_sql.lower()
1106
+ found_reference = any(pattern.lower() in sql_lower for pattern in source_column_ref_patterns)
1107
+
1108
+ if found_reference:
1109
+ logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream} references {model_name}.{column_name} (string search)")
1110
+ results.append({"model": model_name_downstream, "column": column_name, "distance": current_depth + 1})
1111
+ further_downstream = await _trace_downstream_column(manifest, model_name_downstream, column_name, depth, dialect, current_depth + 1)
1112
+ results.extend(further_downstream)
1113
+ continue
1114
+
1115
+ # Check each output column to see if it uses our source column
1116
+ logger.debug(f"[DOWNSTREAM] {model_name_downstream}: checking {len(output_columns_dict)} output columns")
1117
+ sql_lower = compiled_sql.lower()
1118
+ source_column_ref_patterns = [
1119
+ f"{model_name}.{column_name}",
1120
+ f".{column_name}",
1121
+ f" {column_name} ",
1122
+ f" {column_name},",
1123
+ f"({column_name}",
1124
+ ]
1125
+
1126
+ for output_col_name in output_columns_dict.keys():
1127
+ try:
1128
+ # Trace this output column's lineage using unified helper
1129
+ column_lineage_result = _analyze_column_lineage(
1130
+ output_col_name,
1131
+ compiled_sql,
1132
+ schema_mapping,
1133
+ model_name_downstream,
1134
+ dialect,
1135
+ )
1136
+
1137
+ logger.debug(f"[DOWNSTREAM] Check if {output_col_name} uses {model_name}.{column_name}")
1138
+
1139
+ # Check if our source column appears in the dependencies
1140
+ if _check_column_in_lineage(column_lineage_result, model_name, column_name):
1141
+ logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream}.{output_col_name} USES {model_name}.{column_name}")
1142
+ # This output column uses our source column!
1143
+ results.append({"model": model_name_downstream, "column": output_col_name, "distance": current_depth + 1})
1144
+
1145
+ # Recurse: trace this column further downstream
1146
+ further_downstream = await _trace_downstream_column(manifest, model_name_downstream, output_col_name, depth, dialect, current_depth + 1)
1147
+ results.extend(further_downstream)
1148
+ else:
1149
+ # Heuristic fallback when lineage doesn't resolve dependencies
1150
+ if output_col_name.lower() == column_name.lower() and any(pattern.lower() in sql_lower for pattern in source_column_ref_patterns):
1151
+ logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream}.{output_col_name} USES {model_name}.{column_name} (heuristic)")
1152
+ results.append({"model": model_name_downstream, "column": output_col_name, "distance": current_depth + 1})
1153
+
1154
+ further_downstream = await _trace_downstream_column(manifest, model_name_downstream, output_col_name, depth, dialect, current_depth + 1)
1155
+ results.extend(further_downstream)
1156
+ else:
1157
+ logger.debug(f"[DOWNSTREAM] ✗ {model_name_downstream}.{output_col_name} does NOT use {model_name}.{column_name}")
1158
+
1159
+ except SqlglotError as e:
1160
+ # Heuristic fallback when sqlglot fails
1161
+ if output_col_name.lower() == column_name.lower() and any(pattern.lower() in sql_lower for pattern in source_column_ref_patterns):
1162
+ logger.info(f"[DOWNSTREAM] ✓ {model_name_downstream}.{output_col_name} USES {model_name}.{column_name} (heuristic)")
1163
+ results.append({"model": model_name_downstream, "column": output_col_name, "distance": current_depth + 1})
1164
+
1165
+ further_downstream = await _trace_downstream_column(manifest, model_name_downstream, output_col_name, depth, dialect, current_depth + 1)
1166
+ results.extend(further_downstream)
1167
+ else:
1168
+ logger.warning(f"Could not trace {model_name_downstream}.{output_col_name}: {e}")
1169
+ continue
1170
+
1171
+ except Exception as e:
1172
+ # Use get() to avoid UnboundLocalError if exception occurs before model_name_downstream is set
1173
+ model_name_for_log = downstream_model.get("name", "unknown")
1174
+ logger.warning(f"Error analyzing downstream model {model_name_for_log}: {e}")
1175
+ continue
1176
+
1177
+ return results
1178
+
1179
+
1180
+ def _prepare_model_analysis(
1181
+ manifest: ManifestLoader,
1182
+ model_name: str,
1183
+ ) -> tuple[dict[str, Any], str, dict[str, Any]]:
1184
+ """Prepare model for column lineage analysis (unified helper).
1185
+
1186
+ Gets all necessary data in one call:
1187
+ - Resource info (metadata, columns)
1188
+ - Compiled SQL
1189
+ - Schema mapping for sqlglot context
1190
+
1191
+ Args:
1192
+ manifest: ManifestLoader instance
1193
+ model_name: Model name to analyze
1194
+
1195
+ Returns:
1196
+ Tuple of (resource_info, compiled_sql, schema_mapping)
1197
+
1198
+ Raises:
1199
+ ValueError: If model not found or has no compiled SQL
1200
+ """
1201
+ # Get resource info with compiled SQL and database schema
1202
+ resource_info = manifest.get_resource_info(
1203
+ model_name,
1204
+ resource_type="model",
1205
+ include_compiled_sql=True,
1206
+ include_database_schema=True,
1207
+ )
1208
+
1209
+ compiled_sql = resource_info.get("compiled_sql")
1210
+ if not compiled_sql:
1211
+ raise ValueError(f"No compiled SQL found for model '{model_name}'")
1212
+
1213
+ # Build schema mapping from upstream models
1214
+ try:
1215
+ upstream_lineage = manifest.get_lineage(
1216
+ model_name,
1217
+ resource_type="model",
1218
+ direction="upstream",
1219
+ depth=1,
1220
+ )
1221
+ schema_mapping = _build_schema_mapping(manifest, upstream_lineage)
1222
+ except (ValueError, KeyError, AttributeError) as e:
1223
+ logger.warning(f"Could not build schema mapping for {model_name}: {e}")
1224
+ schema_mapping = {}
1225
+
1226
+ return resource_info, compiled_sql, schema_mapping
1227
+
1228
+
1229
+ def _clean_static_union_branches(ast: exp.Expression, column_name: str, dialect: str) -> exp.Expression:
1230
+ """Remove UNION branches that contain only static literal values.
1231
+
1232
+ When a UNION fails during lineage tracing due to index errors, it's often
1233
+ because one branch has all static values (e.g., '-1' as gold_itemkey) while
1234
+ another has dynamic column references. This function removes all-static branches
1235
+ to allow lineage tracing to proceed.
1236
+
1237
+ Args:
1238
+ ast: The SQL AST to clean
1239
+ column_name: Column being traced (for logging)
1240
+ dialect: SQL dialect
1241
+
1242
+ Returns:
1243
+ Cleaned AST with static UNION branches removed
1244
+ """
1245
+ # Find all UNION nodes (they appear as Union expressions)
1246
+ unions_found = list(ast.find_all(exp.Union))
1247
+
1248
+ if not unions_found:
1249
+ return ast # No UNIONs to clean
1250
+
1251
+ logger.debug(f"Found {len(unions_found)} UNION nodes to analyze")
1252
+
1253
+ for union_node in unions_found:
1254
+ # UNION in sqlglot is binary: has 'this' (left) and 'expression' (right)
1255
+ left_branch = union_node.this
1256
+ right_branch = union_node.expression
1257
+
1258
+ # Check if each branch is all-static (contains only literals)
1259
+ left_is_static = _is_all_static_branch(left_branch)
1260
+ right_is_static = _is_all_static_branch(right_branch)
1261
+
1262
+ logger.debug(f"UNION branch analysis: left_static={left_is_static}, right_static={right_is_static}")
1263
+
1264
+ if left_is_static and not right_is_static:
1265
+ # Keep only right branch
1266
+ logger.info(f"Removing static left UNION branch for column {column_name}")
1267
+ union_node.replace(right_branch)
1268
+ elif right_is_static and not left_is_static:
1269
+ # Keep only left branch
1270
+ logger.info(f"Removing static right UNION branch for column {column_name}")
1271
+ union_node.replace(left_branch)
1272
+ elif left_is_static and right_is_static:
1273
+ # Both static - this is unusual but keep left by convention
1274
+ logger.warning(f"Both UNION branches are static for column {column_name}, keeping left")
1275
+ union_node.replace(left_branch)
1276
+ # else: both dynamic, keep the UNION as-is
1277
+
1278
+ return ast
1279
+
1280
+
1281
+ def _is_all_static_branch(branch: exp.Expression) -> bool:
1282
+ """Check if a SELECT branch contains only static literal expressions.
1283
+
1284
+ Args:
1285
+ branch: A SELECT expression to analyze
1286
+
1287
+ Returns:
1288
+ True if all SELECT expressions are literals, False otherwise
1289
+ """
1290
+ # Find the SELECT node in this branch
1291
+ if isinstance(branch, exp.Select):
1292
+ select_node = branch
1293
+ else:
1294
+ select_node = branch.find(exp.Select)
1295
+
1296
+ if not select_node:
1297
+ return False
1298
+
1299
+ # Check all expressions in the SELECT clause
1300
+ for select_expr in select_node.expressions:
1301
+ # Get the actual expression (might be wrapped in Alias)
1302
+ if isinstance(select_expr, exp.Alias):
1303
+ actual_expr = select_expr.this
1304
+ else:
1305
+ actual_expr = select_expr
1306
+
1307
+ # If we find any non-literal, this branch is dynamic
1308
+ if not isinstance(actual_expr, exp.Literal):
1309
+ # Also check for NULL which is represented differently
1310
+ if not (isinstance(actual_expr, exp.Null)):
1311
+ return False
1312
+
1313
+ # All expressions are literals
1314
+ return True
1315
+
1316
+
1317
+ def _analyze_column_lineage(
1318
+ column_name: str,
1319
+ compiled_sql: str,
1320
+ schema_mapping: dict[str, Any],
1321
+ model_name: str,
1322
+ dialect: str = "databricks",
1323
+ ) -> Any:
1324
+ """Run sqlglot column lineage analysis with automatic UNION cleanup.
1325
+
1326
+ Tries lineage analysis, and if it fails due to UNION structure issues (IndexError),
1327
+ automatically cleans static UNION branches and retries. This handles the common
1328
+ data warehouse pattern of UNION ALL with hardcoded default rows.
1329
+
1330
+ Args:
1331
+ column_name: Column to analyze
1332
+ compiled_sql: Compiled SQL
1333
+ schema_mapping: Schema context for sqlglot
1334
+ model_name: Model name (for error messages)
1335
+ dialect: SQL dialect for parsing (default: databricks)
1336
+
1337
+ Returns:
1338
+ sqlglot lineage node
1339
+
1340
+ Raises:
1341
+ ValueError: If SQL parsing fails or UNION cleanup doesn't resolve issues
1342
+ """
1343
+ try:
1344
+ # Wrap the SQL to enable lineage tracing through SELECT *
1345
+ wrapped_ast = _wrap_final_select(compiled_sql, column_name, dialect)
1346
+
1347
+ max_attempts = 3
1348
+ for attempt in range(max_attempts):
1349
+ try:
1350
+ result = lineage(
1351
+ column=column_name,
1352
+ sql=wrapped_ast,
1353
+ schema=schema_mapping,
1354
+ dialect=dialect,
1355
+ )
1356
+
1357
+ # Test the result by walking the tree to trigger any IndexError
1358
+ for _ in result.walk():
1359
+ pass
1360
+
1361
+ if attempt > 0:
1362
+ logger.info(f"Lineage succeeded after {attempt} UNION cleanup attempt(s) for {model_name}.{column_name}")
1363
+
1364
+ return result # Success!
1365
+
1366
+ except IndexError as e:
1367
+ if "index out of range" in str(e).lower():
1368
+ if attempt == max_attempts - 1:
1369
+ raise ValueError(f"UNION structure issue in {model_name}.{column_name} after {max_attempts} cleanup attempts: {e}")
1370
+
1371
+ logger.info(f"UNION index error detected for {model_name}.{column_name}, cleaning static branches (attempt {attempt + 1})")
1372
+ # Clean the AST and retry
1373
+ wrapped_ast = _clean_static_union_branches(wrapped_ast, column_name, dialect)
1374
+ else:
1375
+ # Different kind of IndexError, re-raise
1376
+ raise
1377
+
1378
+ except SqlglotError as e:
1379
+ raise ValueError(f"Failed to parse SQL for column lineage: {e}\nModel: {model_name}, Column: {column_name}")
1380
+ except Exception as e:
1381
+ logger.exception("Unexpected error in column lineage analysis")
1382
+ raise ValueError(f"Column lineage analysis failed: {e}")
1383
+
1384
+
1385
+ def _trace_upstream_recursive(
1386
+ manifest: ManifestLoader,
1387
+ model_name: str,
1388
+ column_name: str,
1389
+ depth: int | None,
1390
+ dialect: str = "databricks",
1391
+ current_depth: int = 0,
1392
+ relation_lookup: dict[str, str] | None = None,
1393
+ visited: set[str] | None = None,
1394
+ ) -> list[dict[str, Any]]:
1395
+ """Recursively trace upstream dependencies for a column.
1396
+
1397
+ Uses unified SQL parsing approach (same as downstream):
1398
+ 1. Prepare model (get resource_info + compiled_sql + schema_mapping)
1399
+ 2. Resolve output columns (SQL → warehouse → manifest)
1400
+ 3. Run lineage per resolved column
1401
+ 4. Recurse on dependencies
1402
+
1403
+ Args:
1404
+ manifest: ManifestLoader instance
1405
+ model_name: Model name to analyze
1406
+ column_name: Column name to trace
1407
+ depth: Maximum depth (None for unlimited)
1408
+ dialect: SQL dialect for parsing (default: databricks)
1409
+ current_depth: Current recursion depth
1410
+ relation_lookup: FQN → unique_id mapping (built once, reused)
1411
+ visited: Set of visited unique_ids (prevent cycles)
1412
+
1413
+ Returns:
1414
+ List of upstream dependencies with dbt resource mapping
1415
+ """
1416
+ # Skip wildcards
1417
+ if column_name.strip() == "*":
1418
+ return []
1419
+
1420
+ # Check depth limit
1421
+ if depth is not None and current_depth >= depth:
1422
+ return []
1423
+
1424
+ # Initialize tracking on first call
1425
+ if relation_lookup is None:
1426
+ relation_lookup = _build_relation_lookup(manifest)
1427
+
1428
+ if visited is None:
1429
+ visited = set()
1430
+
1431
+ # Prepare model for analysis (unified helper - Phase 1)
1432
+ try:
1433
+ _, compiled_sql, schema_mapping = _prepare_model_analysis(
1434
+ manifest,
1435
+ model_name,
1436
+ )
1437
+ except ValueError as e:
1438
+ logger.warning(f"Could not prepare model {model_name}: {e}")
1439
+ return []
1440
+
1441
+ # Run lineage analysis (unified helper - Phase 1)
1442
+ try:
1443
+ column_lineage_result = _analyze_column_lineage(
1444
+ column_name,
1445
+ compiled_sql,
1446
+ schema_mapping,
1447
+ model_name,
1448
+ dialect,
1449
+ )
1450
+ except ValueError as e:
1451
+ logger.warning(f"Could not analyze column {model_name}.{column_name}: {e}")
1452
+ return []
1453
+
1454
+ # Extract root model transformations using new format (id, type, sources)
1455
+ root_transformations = _extract_transformations_with_sources(column_lineage_result)
1456
+
1457
+ # Extract dependencies and resolve to dbt resources
1458
+ dependencies = _extract_dependencies_from_lineage(
1459
+ column_lineage_result,
1460
+ manifest,
1461
+ schema_mapping,
1462
+ depth,
1463
+ column_name,
1464
+ )
1465
+
1466
+ # Resolve dependency FQNs to dbt resources
1467
+ for dependency in dependencies:
1468
+ if dependency.get("dbt_resource"):
1469
+ continue
1470
+
1471
+ resolved = _resolve_dependency_resource(dependency, relation_lookup)
1472
+ if resolved:
1473
+ dependency["dbt_resource"] = resolved
1474
+
1475
+ # Enrich model dependencies with their internal transformations
1476
+ for dependency in dependencies:
1477
+ dbt_resource = dependency.get("dbt_resource")
1478
+ if not dbt_resource:
1479
+ continue
1480
+
1481
+ node = manifest.get_node_by_unique_id(dbt_resource)
1482
+ if not node or node.get("resource_type") != "model":
1483
+ continue
1484
+
1485
+ # Get the dependency model's internal transformations
1486
+ dep_model_name = node.get("name")
1487
+ dep_column = dependency.get("column", "")
1488
+
1489
+ if dep_model_name and dep_column:
1490
+ try:
1491
+ # Analyze the dependency model to get its internal CTEs
1492
+ _, dep_compiled_sql, dep_schema_mapping = _prepare_model_analysis(manifest, dep_model_name)
1493
+ dep_lineage = _analyze_column_lineage(dep_column, dep_compiled_sql, dep_schema_mapping, dep_model_name, dialect)
1494
+
1495
+ # Extract internal transformations using new format (id, type, sources)
1496
+ internal_transformations = _extract_transformations_with_sources(dep_lineage)
1497
+
1498
+ # Replace the accumulated transformations with internal ones
1499
+ if internal_transformations:
1500
+ dependency["transformations"] = internal_transformations
1501
+ # Update via_ctes to match (extract CTE ids)
1502
+ via_ctes = []
1503
+ for t in internal_transformations:
1504
+ if t.get("type") == "cte":
1505
+ # Remove namespace prefix for via_ctes backward compatibility
1506
+ cte_id = t.get("id", "")
1507
+ if cte_id.startswith("cte:"):
1508
+ cte_name = cte_id[4:] # Strip "cte:" prefix
1509
+ if cte_name not in via_ctes:
1510
+ via_ctes.append(cte_name)
1511
+ dependency["via_ctes"] = via_ctes
1512
+
1513
+ except Exception as e:
1514
+ logger.warning(f"Could not extract internal transformations for {dep_model_name}.{dep_column}: {e}")
1515
+ # Keep existing transformations if extraction fails
1516
+
1517
+ # Store results with root model transformations
1518
+ results = list(dependencies)
1519
+
1520
+ # Add root transformations metadata to each result (will be extracted by formatter)
1521
+ if root_transformations and results:
1522
+ for result in results:
1523
+ result["__root_transformations__"] = root_transformations
1524
+ elif root_transformations:
1525
+ # No dependencies found, but we have root transformations - return them anyway
1526
+ results = [{"__root_transformations__": root_transformations}]
1527
+
1528
+ # Recurse on each dependency
1529
+ for i, dependency in enumerate(dependencies):
1530
+ dbt_resource = dependency.get("dbt_resource")
1531
+ if not dbt_resource or dbt_resource in visited:
1532
+ continue
1533
+
1534
+ node = manifest.get_node_by_unique_id(dbt_resource)
1535
+ if not node:
1536
+ continue
1537
+
1538
+ if node.get("resource_type") != "model":
1539
+ # This is a source/seed - already in results, don't recurse
1540
+ continue
1541
+
1542
+ next_model = node.get("name")
1543
+ if not next_model:
1544
+ continue
1545
+
1546
+ visited.add(dbt_resource)
1547
+
1548
+ # Extract column name from dependency
1549
+ dep_column = dependency.get("column", "")
1550
+
1551
+ # Handle wildcards: resolve upstream model's output columns
1552
+ if dep_column.endswith(".*") or dep_column.strip() == "*":
1553
+ # Use unified helper to prepare upstream model
1554
+ try:
1555
+ upstream_info, upstream_sql, upstream_schema_mapping = _prepare_model_analysis(
1556
+ manifest,
1557
+ next_model,
1558
+ )
1559
+ except ValueError:
1560
+ continue
1561
+
1562
+ # Resolve output columns using unified logic
1563
+ output_columns_dict, _ = _resolve_output_columns(
1564
+ upstream_sql,
1565
+ upstream_schema_mapping,
1566
+ upstream_info,
1567
+ dialect,
1568
+ )
1569
+
1570
+ # Determine which column to trace through the wildcard
1571
+ # Check previous dependency for the actual column name
1572
+ if i > 0:
1573
+ prev_dep = dependencies[i - 1]
1574
+ # Try transformations first, then fall back to column field
1575
+ if prev_dep.get("transformations"):
1576
+ column_to_trace = prev_dep["transformations"][0].get("column", column_name)
1577
+ else:
1578
+ # Strip CTE prefix if present (e.g., "customers.customer_id" -> "customer_id")
1579
+ prev_column = prev_dep.get("column", "")
1580
+ column_to_trace = prev_column.split(".")[-1] if prev_column else column_name
1581
+ else:
1582
+ # No previous dependency - use original target column
1583
+ column_to_trace = column_name
1584
+
1585
+ # Only trace the specific column we're looking for, not all columns
1586
+ if column_to_trace in output_columns_dict:
1587
+ results.extend(
1588
+ _trace_upstream_recursive(
1589
+ manifest,
1590
+ next_model,
1591
+ column_to_trace,
1592
+ depth,
1593
+ dialect,
1594
+ current_depth + 1,
1595
+ relation_lookup,
1596
+ visited,
1597
+ )
1598
+ )
1599
+ continue
1600
+
1601
+ # Regular column: extract name and recurse
1602
+ next_column = dep_column.split(".")[-1] if dep_column else column_name
1603
+
1604
+ results.extend(
1605
+ _trace_upstream_recursive(
1606
+ manifest,
1607
+ next_model,
1608
+ next_column,
1609
+ depth,
1610
+ dialect,
1611
+ current_depth + 1,
1612
+ relation_lookup,
1613
+ visited,
1614
+ )
1615
+ )
1616
+
1617
1618
+
1619
+ return results
1620
+
1621
+
1622
+ async def _build_downstream_usages(
1623
+ manifest: ManifestLoader,
1624
+ model_name: str,
1625
+ column_name: str,
1626
+ depth: int | None,
1627
+ dialect: str,
1628
+ current_depth: int = 0,
1629
+ ) -> list[dict[str, Any]]:
1630
+ """Build downstream usages with transformations (reversed order).
1631
+
1632
+ For each downstream model using this column, trace how they transform it
1633
+ from input to output. Transformations are in reversed order (bottom-to-top:
1634
+ table input → CTEs → outer query).
1635
+
1636
+ Args:
1637
+ manifest: ManifestLoader instance
1638
+ model_name: Source model name
1639
+ column_name: Source column name
1640
+ depth: Maximum depth to traverse
1641
+ dialect: SQL dialect for parsing
1642
+ current_depth: Current recursion depth
1643
+
1644
+ Returns:
1645
+ List of usage dicts with transformations
1646
+ """
1647
+ if depth is not None and current_depth >= depth:
1648
+ return []
1649
+
1650
+ try:
1651
+ lineage_data = manifest.get_lineage(
1652
+ model_name,
1653
+ resource_type="model",
1654
+ direction="downstream",
1655
+ depth=1,
1656
+ )
1657
+ except (ValueError, KeyError, AttributeError) as e:
1658
+ logger.warning(f"Could not get downstream lineage for {model_name}: {e}")
1659
+ return []
1660
+
1661
+ downstream_models = lineage_data.get("downstream", [])
1662
+ usages = []
1663
+
1664
+ for downstream_model in downstream_models:
1665
+ if not downstream_model.get("unique_id", "").startswith("model."):
1666
+ continue
1667
+
1668
+ # Skip CTE unit test helper models
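+ # e.g. a generated helper named like "stg_orders__a1b2c3" (name ending in "__"
+ # plus six hex characters) is assumed to be a CTE unit-test artifact and skipped.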
1669
+ model_name_downstream = downstream_model.get("name", "")
1670
+ if model_name_downstream and len(model_name_downstream) > 8:
1671
+ suffix = model_name_downstream[-8:]
1672
+ if suffix[:2] == "__" and all(c in "0123456789abcdef" for c in suffix[2:]):
1673
+ continue
1674
+
1675
+ try:
1676
+ # Get downstream model's SQL and trace transformations
1677
+ downstream_info, compiled_sql, schema_mapping = _prepare_model_analysis(
1678
+ manifest,
1679
+ model_name_downstream,
1680
+ )
1681
+
1682
+ # Resolve output columns to find which column(s) use our source column
1683
+ output_columns_dict, _ = _resolve_output_columns(compiled_sql, schema_mapping, downstream_info, dialect)
1684
+
1685
+ # Find output columns that reference our source column
1686
+ for output_col_name in output_columns_dict:
1687
+ try:
1688
+ # Trace transformations for this output column
1689
+ column_lineage_result = _analyze_column_lineage(
1690
+ column_name=output_col_name,
1691
+ compiled_sql=compiled_sql,
1692
+ schema_mapping=schema_mapping,
1693
+ model_name=model_name_downstream,
1694
+ dialect=dialect,
1695
+ )
1696
+
1697
+ # Extract transformations (shows how downstream model builds the column)
1698
+ transformations = _extract_transformations_with_sources(column_lineage_result)
1699
+
1700
+ # Reverse transformations for downstream (bottom-to-top: input → output)
1701
+ # Upstream shows top-to-bottom (query → CTEs → table)
1702
+ # Downstream shows bottom-to-top (table → CTEs → query)
1703
+ transformations = list(reversed(transformations))
1704
+
1705
+ # Extract dependencies (to find if they reference our source column)
1706
+ deps = _extract_dependencies_from_lineage(
1707
+ column_lineage_result,
1708
+ manifest,
1709
+ schema_mapping,
1710
+ depth=1, # Only immediate deps
1711
+ column_name=output_col_name,
1712
+ )
1713
+
1714
+ # Check if any dependency references our source model+column
1715
+ uses_source_column = False
1716
+ for dep in deps:
1717
+ dep_table = dep.get("table", "")
1718
+ dep_column = dep.get("column", "")
1719
+ # Match table name (handle schema-qualified names)
1720
+ if model_name.lower() in dep_table.lower() and dep_column.lower() == column_name.lower():
1721
+ uses_source_column = True
1722
+ break
1723
+
1724
+ if uses_source_column:
1725
+ # Already extracted transformations above
1726
+ # Recursively get usages of this downstream column
1727
+ nested_usages = await _build_downstream_usages(
1728
+ manifest,
1729
+ model_name_downstream,
1730
+ output_col_name,
1731
+ depth,
1732
+ dialect,
1733
+ current_depth + 1,
1734
+ )
1735
+
1736
+ usage = {
1737
+ "model": model_name_downstream,
1738
+ "column": output_col_name,
1739
+ "distance": current_depth + 1,
1740
+ "transformations": transformations,
1741
+ }
1742
+
1743
+ if nested_usages:
1744
+ usage["usages"] = nested_usages
1745
+
1746
+ usages.append(usage)
1747
+
1748
+ except Exception as e:
1749
+ logger.debug(f"Could not trace {model_name_downstream}.{output_col_name}: {e}")
1750
+ continue
1751
+
1752
+ except Exception as e:
1753
+ logger.warning(f"Error analyzing downstream model {model_name_downstream}: {e}")
1754
+ continue
1755
+
1756
+ return usages
1757
+
1758
+
1759
+ def _format_lineage_response(model_name: str, column_name: str, direction: str, dependencies: list[dict[str, Any]], downstream_usage: list[dict[str, Any]] | None = None) -> dict[str, Any]:
1760
+ """Format the final lineage response.
1761
+
1762
+ Args:
1763
+ model_name: Model name
1764
+ column_name: Column name
1765
+ direction: Direction of lineage
1766
+ dependencies: Upstream dependencies
1767
+ downstream_usage: Optional downstream usage info (legacy format or new usages format)
1768
+
1769
+ Returns:
1770
+ Formatted response dict
1771
+ """
1772
+ result: dict[str, Any] = {
1773
+ "model": model_name,
1774
+ "column": column_name,
1775
+ "direction": direction,
1776
+ }
1777
+
1778
+ # Handle downstream-only format (new unified structure with transformations)
1779
+ if direction == "downstream" and downstream_usage is not None:
1780
+ # New format: usages with transformations (no root transformations, we're the source)
1781
+ result["usages"] = downstream_usage
1782
+ return result
1783
+
1784
+ # Extract root transformations from dependencies (if present)
1785
+ root_transformations = None
1786
+ if dependencies and "__root_transformations__" in dependencies[0]:
1787
+ root_transformations = dependencies[0]["__root_transformations__"]
1788
+ # Remove from all dependencies (cleanup)
1789
+ for dep in dependencies:
1790
+ dep.pop("__root_transformations__", None)
1791
+
1792
+ # Add root transformations if present (upstream or both)
1793
+ if root_transformations:
1794
+ result["transformations"] = root_transformations
1795
+ # Extract via_ctes from transformations (new format with id/type)
1796
+ via_ctes = []
1797
+ for t in root_transformations:
1798
+ if t.get("type") == "cte":
1799
+ # Extract CTE name from namespaced id (e.g., "cte:orders" -> "orders")
1800
+ cte_id = t.get("id", "")
1801
+ if cte_id.startswith("cte:"):
1802
+ cte_name = cte_id[4:] # Strip "cte:" prefix
1803
+ if cte_name not in via_ctes:
1804
+ via_ctes.append(cte_name)
1805
+ if via_ctes:
1806
+ result["via_ctes"] = via_ctes
1807
+
1808
+ result["dependencies"] = dependencies
1809
+ result["dependency_count"] = len(dependencies)
1810
+
1811
+ if downstream_usage is not None and direction != "downstream":
1812
+ # For "both" direction: use new usages format
1813
+ result["usages"] = downstream_usage
1814
+
1815
+ return result
1816
+
1817
+
1818
+ async def implementation(
1819
+ ctx: Context | None,
1820
+ model_name: str,
1821
+ column_name: str,
1822
+ direction: str,
1823
+ depth: int | None,
1824
+ state: DbtCoreServerContext,
1825
+ force_parse: bool = True,
1826
+ ) -> dict[str, Any]:
1827
+ """Implementation function for get_column_lineage tool.
1828
+
1829
+ Separated for testing purposes - tests call this directly with explicit state.
1830
+ The @tool() decorated get_column_lineage() function calls this with injected dependencies.
1831
+ """
1832
+ # Initialize state if needed
1833
+ await state.ensure_initialized(ctx, force_parse)
1834
+
1835
+ # Verify manifest is available
1836
+ if state.manifest is None:
1837
+ raise RuntimeError("Manifest not initialized")
1838
+
1839
+ # FIRST: Check if we have compiled code, compile ALL if needed
1840
+ # Test source model to see if compilation is needed
1841
+ resource_info = state.manifest.get_resource_info(model_name, resource_type="model", include_compiled_sql=True)
1842
+ compiled_sql = resource_info.get("compiled_sql")
1843
+ if not compiled_sql:
1844
+ logger.info("No compiled SQL found - compiling entire project")
1845
+ runner = await state.get_runner()
1846
+ # Compile ALL models (dbt compile with no selector)
1847
+ compile_result = await runner.invoke(["compile"])
1848
+
1849
+ if compile_result.success:
1850
+ # Reload manifest to get compiled code
1851
+ await state.manifest.load()
1852
+ # Re-fetch the resource to get updated compiled_code
1853
+ resource_info = state.manifest.get_resource_info(
1854
+ model_name,
1855
+ resource_type="model",
1856
+ include_database_schema=False,
1857
+ include_compiled_sql=True,
1858
+ )
1859
+ compiled_sql = resource_info.get("compiled_sql")
1860
+ else:
1861
+ raise RuntimeError(f"Failed to compile project: {compile_result}")
1862
+
1863
+ logger.info("Project compiled successfully")
1864
+
1865
+ # We always need compiled SQL from this point.
1866
+ if not compiled_sql:
1867
+ raise ValueError(f"No compiled SQL found for model '{model_name}'. Model may not contain SQL code.")
1868
+
1869
+ # Handle multiple matches (check AFTER compilation like get_resource_info does)
1870
+ if resource_info.get("multiple_matches"):
1871
+ raise ValueError(f"Multiple models found matching '{model_name}'. Please use unique_id: {[m['unique_id'] for m in resouce_info['matches']]}")
1872
+
1873
+ # Get SQL dialect from manifest metadata
1874
+ project_info = state.manifest.get_project_info()
1875
+ adapter_type = project_info.get("adapter_type", "databricks")
1876
+ dialect = _map_dbt_adapter_to_sqlglot_dialect(adapter_type)
1877
+ logger.info(f"Using SQL dialect '{dialect}' for adapter type '{adapter_type}'")
1878
+
1879
+ # Validate that the requested column exists in the model's output
1880
+ # Build schema mapping for column resolution
1881
+ upstream_lineage = state.manifest.get_lineage(
1882
+ model_name,
1883
+ resource_type="model",
1884
+ direction="upstream",
1885
+ depth=1,
1886
+ )
1887
+ schema_mapping = _build_schema_mapping(state.manifest, upstream_lineage)
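+ # schema_mapping is assumed to be the nested table -> column mapping sqlglot expects,
+ # e.g. {"analytics": {"stg_orders": {"order_id": "bigint", "amount": "decimal"}}}
+ # (shape inferred from how it is passed to sqlglot's lineage(); values are illustrative).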
1888
+
1889
+ # Resolve output columns
1890
+ output_columns, _ = _resolve_output_columns(
1891
+ compiled_sql=compiled_sql,
1892
+ schema_mapping=schema_mapping,
1893
+ resource_info=resource_info,
1894
+ dialect=dialect,
1895
+ )
1896
+
1897
+ # Check if requested column exists
1898
+ if column_name not in output_columns:
1899
+ available_columns = ", ".join(f"'{col}'" for col in sorted(output_columns.keys()))
1900
+ raise ValueError(f"Column '{column_name}' not found in output of model '{model_name}'. Available columns: {available_columns}")
1901
+
1902
+ if direction == "downstream":
1903
+ # Downstream only - new format with transformations
1904
+ usages = await _build_downstream_usages(state.manifest, model_name, column_name, depth, dialect)
1905
+ return _format_lineage_response(model_name, column_name, direction, [], usages)
1906
+
1907
+ if direction == "upstream":
1908
+ # Upstream only
1909
+ dependencies = _trace_upstream_recursive(state.manifest, model_name, column_name, depth, dialect)
1910
+ return _format_lineage_response(model_name, column_name, direction, dependencies, None)
1911
+
1912
+ else: # direction == "both"
1913
+ # Both upstream and downstream
1914
+ dependencies = _trace_upstream_recursive(state.manifest, model_name, column_name, depth, dialect)
1915
+ usages = await _build_downstream_usages(state.manifest, model_name, column_name, depth, dialect)
1916
+ return _format_lineage_response(model_name, column_name, direction, dependencies, usages)
1917
+
1918
+
1919
+ @dbtTool()
1920
+ async def get_column_lineage(
1921
+ ctx: Context,
1922
+ model_name: str,
1923
+ column_name: str,
1924
+ direction: str = "upstream",
1925
+ depth: int | None = None,
1926
+ state: DbtCoreServerContext = Depends(get_state),
1927
+ ) -> dict[str, Any]:
1928
+ """Trace column-level lineage through SQL transformations.
1929
+
1930
+ Uses sqlglot to parse compiled SQL and track how columns flow through:
1931
+ - CTEs and subqueries
1932
+ - JOINs and aggregations
1933
+ - Transformations (calculations, CASE statements, etc.)
1934
+ - Window functions
1935
+
1936
+ This provides detailed column-to-column dependencies that model-level
1937
+ lineage cannot capture.
1938
+
1939
+ Args:
1940
+ model_name: Name or unique_id of the dbt model to analyze
1941
+ column_name: Name of the column to trace
1942
+ direction: Direction to trace lineage:
1943
+ - "upstream": Which source columns feed into this column
1944
+ - "downstream": Which downstream columns use this column
1945
+ - "both": Full bidirectional column lineage
1946
+ depth: Maximum levels to traverse (None for unlimited)
1947
+ - depth=1: Immediate column dependencies only
1948
+ - depth=2: Dependencies + their dependencies
1949
+ - None: Full dependency tree
1950
+
1951
+ Returns:
1952
+ Column lineage information including:
1953
+ - Source columns this column depends on (upstream)
1954
+ - Downstream columns that depend on this column
1955
+ - Transformations and derivations
1956
+ - CTE transformation paths (via_ctes, transformations)
1957
+ - dbt resource mapping where available
1958
+
1959
+ Each dependency includes:
1960
+ - column: Column name
1961
+ - table: Source table name
1962
+ - schema: Source schema (if available)
1963
+ - database: Source database (if available)
1964
+ - via_ctes: List of CTE names in transformation order
1965
+ - transformations: Transformation details per CTE step
1966
+ - cte: CTE name
1967
+ - column: Column name at this step
1968
+ - expression: SQL expression (truncated to 200 chars)
1969
+
1970
+ Raises:
1971
+ ValueError: If model not found, column not found, or SQL parse fails
1972
+ RuntimeError: If the manifest is not initialized or project compilation fails
1973
+
1974
+ Examples:
1975
+ # Find which source columns feed into revenue
1976
+ get_column_lineage("fct_sales", "revenue", "upstream")
1977
+
1978
+ # See what downstream models use customer_id
1979
+ get_column_lineage("dim_customers", "customer_id", "downstream")
1980
+
1981
+ # Full bidirectional lineage for a column
1982
+ get_column_lineage("fct_orders", "order_total", "both")
1983
+
1984
+ Note:
1985
+ Requires sqlglot package. Install with: pip install sqlglot
1986
+ The model must be compiled; if no compiled SQL is found, the entire project is compiled automatically.
1987
+ """
1988
+ return await implementation(ctx, model_name, column_name, direction, depth, state)