InfoTracker 0.1.0__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff shows the content of publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
infotracker/engine.py CHANGED
@@ -5,7 +5,7 @@ import json
 import logging
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set
 from fnmatch import fnmatch
 
 import yaml
@@ -46,10 +46,10 @@ class ImpactRequest:
 
 @dataclass
 class DiffRequest:
+    base: str   # git ref for base
+    head: str   # git ref for head
     sql_dir: Path
     adapter: str
-    base: Path
-    head: Optional[Path] = None
     severity_threshold: str = "BREAKING"  # NON_BREAKING | POTENTIALLY_BREAKING | BREAKING
 
 
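
A DiffRequest is now built from git refs rather than artifact paths. A minimal sketch of a caller, with illustrative ref and path values:

    req = DiffRequest(
        base="origin/main",   # git ref for base
        head="HEAD",          # git ref for head
        sql_dir=Path("sql"),
        adapter="mssql",
    )
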
@@ -76,7 +76,7 @@ class Engine:
         4) count warnings based on outputs[0].facets (schema/columnLineage)
         5) build the column graph for the later impact step
         """
-        adapter = get_adapter(req.adapter)
+        adapter = get_adapter(req.adapter, self.config)
         parser = adapter.parser
 
         warnings = 0
@@ -95,7 +95,7 @@ class Engine:
         cols: List[ColumnSchema] = [
             ColumnSchema(
                 name=c["name"],
-                type=c.get("type"),
+                data_type=c.get("type"),
                 nullable=bool(c.get("nullable", True)),
                 ordinal=int(c.get("ordinal", 0)),
             )
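
Since the keyword changed from type to data_type, any direct constructor call has to follow suit; a one-line sketch with assumed values:

    col = ColumnSchema(name="order_id", data_type="int", nullable=False, ordinal=0)
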
@@ -136,39 +136,72 @@ class Engine:
             if match_any(p, includes) and not match_any(p, excludes)
         ]
 
-        # 3) Parsing and OL generation
+        # 3) Parse all files first to build dependency graph
         out_dir = Path(req.out_dir)
         out_dir.mkdir(parents=True, exist_ok=True)
 
         outputs: List[List[str]] = []
         parsed_objects: List[ObjectInfo] = []
+        sql_file_map: Dict[str, Path] = {}  # object_name -> file_path
 
         ignore_patterns: List[str] = list(getattr(self.config, "ignore", []) or [])
 
+        # Phase 1: Parse all SQL files and collect objects
         for sql_path in sql_files:
             try:
                 sql_text = sql_path.read_text(encoding="utf-8")
-
-                # Parse into ObjectInfo (for ignores and the graph)
                 obj_info: ObjectInfo = parser.parse_sql_file(sql_text, object_hint=sql_path.stem)
-                parsed_objects.append(obj_info)
-
-                # ignore by object name (a string), not by ObjectInfo
+
+                # Store mapping for later processing
                 obj_name = getattr(getattr(obj_info, "schema", None), "name", None) or getattr(obj_info, "name", None)
-                if obj_name and ignore_patterns and any(fnmatch(obj_name, pat) for pat in ignore_patterns):
-                    continue
+                if obj_name:
+                    sql_file_map[obj_name] = sql_path
+
+                    # Skip ignored objects
+                    if ignore_patterns and any(fnmatch(obj_name, pat) for pat in ignore_patterns):
+                        continue
+
+                parsed_objects.append(obj_info)
+
+            except Exception as e:
+                warnings += 1
+                logger.warning("failed to parse %s: %s", sql_path, e)
 
-                # Adapter payload (str or dict), normalized to dict
+        # Phase 2: Build dependency graph and resolve schemas in topological order
+        dependency_graph = self._build_dependency_graph(parsed_objects)
+        processing_order = self._topological_sort(dependency_graph)
+
+        # Phase 3: Process objects in dependency order, building up schema registry
+        resolved_objects: List[ObjectInfo] = []
+        for obj_name in processing_order:
+            if obj_name not in sql_file_map:
+                continue
+
+            sql_path = sql_file_map[obj_name]
+            try:
+                sql_text = sql_path.read_text(encoding="utf-8")
+
+                # Parse with updated schema registry (now has dependencies resolved)
+                obj_info: ObjectInfo = parser.parse_sql_file(sql_text, object_hint=sql_path.stem)
+                resolved_objects.append(obj_info)
+
+                # Register this object's schema for future dependencies
+                if obj_info.schema:
+                    parser.schema_registry.register(obj_info.schema)
+                    # Also register in adapter's parser for lineage generation
+                    adapter.parser.schema_registry.register(obj_info.schema)
+
+                # Generate OpenLineage with resolved schema context
                 ol_raw = adapter.extract_lineage(sql_text, object_hint=sql_path.stem)
                 ol_payload: Dict[str, Any] = json.loads(ol_raw) if isinstance(ol_raw, str) else ol_raw
 
-                # Save to file (deterministic)
+                # Save to file
                 target = out_dir / f"{sql_path.stem}.json"
                 target.write_text(json.dumps(ol_payload, indent=2, ensure_ascii=False, sort_keys=True), encoding="utf-8")
 
                 outputs.append([str(sql_path), str(target)])
 
-                # Warnings heuristic – look at outputs[0].facets
+                # Check for warnings
                 out0 = (ol_payload.get("outputs") or [])
                 out0 = out0[0] if out0 else {}
                 facets = out0.get("facets", {})
@@ -182,19 +215,18 @@ class Engine:
                 warnings += 1
                 logger.warning("failed to process %s: %s", sql_path, e)
 
-        # 5) Build the column graph from all parsed objects
-        # 5) Build the column graph from all parsed objects
-        if parsed_objects:
+        # 4) Build column graph from resolved objects (second pass)
+        if resolved_objects:
             try:
                 graph = ColumnGraph()
-                graph.build_from_object_lineage(parsed_objects)  # use this method from models.py
+                graph.build_from_object_lineage(resolved_objects)  # Use resolved objects with expanded schemas
                 self._column_graph = graph
 
-                # (optionally) save the graph to disk so impact can load it in a separate process
+                # Save graph to disk for impact analysis
                 graph_path = Path(req.out_dir) / "column_graph.json"
                 edges_dump = []
                 seen = set()
-                for edges_list in graph._downstream_edges.values():  # simple edge export
+                for edges_list in graph._downstream_edges.values():
                     for e in edges_list:
                         key = (str(e.from_column), str(e.to_column),
                                getattr(e.transformation_type, "value", str(e.transformation_type)),
@@ -219,6 +251,51 @@ class Engine:
             "warnings": warnings,
         }
 
+    def _build_dependency_graph(self, objects: List[ObjectInfo]) -> Dict[str, Set[str]]:
+        """Build dependency graph: object_name -> set of dependencies."""
+        dependencies = {}
+
+        for obj in objects:
+            obj_name = obj.schema.name if obj.schema else obj.name
+
+            # Use ObjectInfo.dependencies first
+            if obj.dependencies:
+                dependencies[obj_name] = set(obj.dependencies)
+            else:
+                # Fall back to extracting dependencies from lineage.input_fields
+                dependencies[obj_name] = set()
+                for lineage in obj.lineage:
+                    for input_field in lineage.input_fields:
+                        dep_name = input_field.table_name
+                        if dep_name != obj_name:  # Don't depend on self
+                            dependencies[obj_name].add(dep_name)
+
+        return dependencies
+
+    def _topological_sort(self, dependencies: Dict[str, Set[str]]) -> List[str]:
+        """Sort objects in dependency order (dependencies first)."""
+        result = []
+        remaining = dependencies.copy()
+
+        while remaining:
+            # Find nodes with no dependencies (or whose dependencies are already processed)
+            ready = []
+            for node, deps in remaining.items():
+                if not deps or all(dep in result for dep in deps):
+                    ready.append(node)
+
+            if not ready:
+                # Circular or missing dependency - process the remainder arbitrarily
+                ready = [next(iter(remaining.keys()))]
+                logger.warning("Circular or missing dependencies detected, processing: %s", ready[0])
+
+            # Process ready nodes
+            for node in ready:
+                result.append(node)
+                del remaining[node]
+
+        return result
+
     # ------------------ IMPACT (simple variant; keep your own if you have a richer one) ------------------
 
     def run_impact(self, req: ImpactRequest) -> Dict[str, Any]:
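
The sort is a Kahn-style pass: each iteration emits every node whose dependencies were already emitted, and falls back to an arbitrary node when a cycle or a missing dependency would stall progress. A minimal sketch of the resulting contract, assuming an Engine instance named engine and illustrative object names:

    deps = {
        "dbo.vw_orders": {"dbo.orders"},   # view reads from the table
        "dbo.orders": set(),               # base table, no dependencies
    }
    order = engine._topological_sort(deps)
    # ["dbo.orders", "dbo.vw_orders"]: the table is parsed and its schema
    # registered first, so the view can resolve its columns against it.
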
@@ -238,8 +315,8 @@ class Engine:
         data = json.loads(graph_path.read_text(encoding="utf-8"))
         graph = ColumnGraph()
         for edge in data.get("edges", []):
-            from_ns, from_tbl, from_col = edge["from"].split(".", 2)
-            to_ns, to_tbl, to_col = edge["to"].split(".", 2)
+            from_ns, from_tbl, from_col = edge["from"].rsplit(".", 2)
+            to_ns, to_tbl, to_col = edge["to"].rsplit(".", 2)
             graph.add_edge(ColumnEdge(
                 from_column=ColumnNode(from_ns, from_tbl, from_col),
                 to_column=ColumnNode(to_ns, to_tbl, to_col),
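
The switch from split to rsplit parses the key from the right, so the last two segments are always table and column and any dots earlier in the namespace stay put. An illustrative key with a dotted host:

    key = "mssql://server.example.com/DW.orders.order_id"
    key.split(".", 2)    # ['mssql://server', 'example', 'com/DW.orders.order_id'] -- host dots break the parse
    key.rsplit(".", 2)   # ['mssql://server.example.com/DW', 'orders', 'order_id'] -- namespace stays intact
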
@@ -267,44 +344,84 @@ class Engine:
             direction_upstream = True
             sel = sel[1:-1]  # remove both + symbols
         elif sel.startswith('+'):
-            # +column → downstream only
-            direction_downstream = True
+            # +column → upstream only
+            direction_upstream = True
             sel = sel[1:]  # remove + from start
         elif sel.endswith('+'):
-            # column+ → upstream only
-            direction_upstream = True
+            # column+ → downstream only
+            direction_downstream = True
             sel = sel[:-1]  # remove + from end
         else:
             # column → default (downstream)
             direction_downstream = True
 
         # Selector normalization - handle various formats:
-        # 1. table.column -> dbo.table.column
-        # 2. schema.table.column -> namespace/schema.table.column (if there is no protocol)
-        # 3. full URI -> use as is
+        # 1. table.column -> dbo.table.column (legacy)
+        # 2. schema.table.column -> schema.table.column (legacy)
+        # 3. database.schema.table.column -> namespace/database.schema.table.column
+        # 4. database.schema.table.* -> namespace/database.schema.table.* (table wildcard)
+        # 5. ..column -> ..column (column wildcard)
+        # 6. full URI -> use as is
         if "://" in sel:
             # full URI, use as is
             pass
+        elif sel.startswith('.') and not sel.startswith('..'):
+            # Alias: .column -> ..column (column wildcard in default namespace)
+            sel = f"mssql://localhost/InfoTrackerDW..{sel[1:]}"
+        elif sel.startswith('..'):
+            # Column wildcard pattern - leave as is, will be handled specially
+            sel = f"mssql://localhost/InfoTrackerDW{sel}"
+        elif sel.endswith('.*'):
+            # Table wildcard pattern
+            base_sel = sel[:-2]  # Remove .*
+            parts = [p for p in base_sel.split(".") if p]
+            if len(parts) == 2:
+                # schema.table.* -> namespace/schema.table.*
+                sel = f"mssql://localhost/InfoTrackerDW.{base_sel}.*"
+            elif len(parts) == 3:
+                # database.schema.table.* -> namespace/database.schema.table.*
+                sel = f"mssql://localhost/InfoTrackerDW.{base_sel}.*"
+            else:
+                return {
+                    "columns": ["message"],
+                    "rows": [[f"Unsupported wildcard selector format: '{req.selector}'. Use 'schema.table.*' or 'database.schema.table.*'."]],
+                }
         else:
             parts = [p for p in sel.split(".") if p]
             if len(parts) == 2:
-                # table.column -> dbo.table.column
-                sel = f"dbo.{parts[0]}.{parts[1]}"
+                # table.column -> namespace/dbo.table.column
+                sel = f"mssql://localhost/InfoTrackerDW.dbo.{parts[0]}.{parts[1]}"
             elif len(parts) == 3:
-                # schema.table.column -> namespace.schema.table.column
+                # schema.table.column -> namespace/schema.table.column
+                sel = f"mssql://localhost/InfoTrackerDW.{sel}"
+            elif len(parts) == 4:
+                # database.schema.table.column -> namespace/database.schema.table.column
                 sel = f"mssql://localhost/InfoTrackerDW.{sel}"
             else:
                 return {
                     "columns": ["message"],
-                    "rows": [[f"Unsupported selector format: '{req.selector}'. Use 'table.column', 'schema.table.column', or full URI."]],
+                    "rows": [[f"Unsupported selector format: '{req.selector}'. Use 'table.column', 'schema.table.column', 'database.schema.table.column', 'database.schema.table.*' (table wildcard), '..columnname' (column wildcard), '.columnname' (alias), or full URI."]],
                 }
 
         target = self._column_graph.find_column(sel)
-        if not target:
-            return {
-                "columns": ["message"],
-                "rows": [[f"Column '{sel}' not found in graph."]],
-            }
+        targets = []
+
+        # Check if this is a wildcard selector
+        if '*' in sel or '..' in sel or sel.endswith('.*'):
+            targets = self._column_graph.find_columns_wildcard(sel)
+            if not targets:
+                return {
+                    "columns": ["message"],
+                    "rows": [[f"No columns found matching pattern '{sel}'."]],
+                }
+        else:
+            # Single column selector
+            if not target:
+                return {
+                    "columns": ["message"],
+                    "rows": [[f"Column '{sel}' not found in graph."]],
+                }
+            targets = [target]
 
         rows: List[List[str]] = []
 
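
Taken together, the normalization and wildcard branches accept selector shapes like the following (rewrites shown against the built-in default namespace; names illustrative):

    # "orders.total"       -> "mssql://localhost/InfoTrackerDW.dbo.orders.total"  (table.column)
    # "dbo.orders.total"   -> "mssql://localhost/InfoTrackerDW.dbo.orders.total"  (schema.table.column)
    # "DW.dbo.orders.*"    -> "mssql://localhost/InfoTrackerDW.DW.dbo.orders.*"   (table wildcard)
    # "..customer_id"      -> "mssql://localhost/InfoTrackerDW..customer_id"      (column wildcard)
    # "+dbo.orders.total"  -> same column, upstream-only traversal
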
@@ -317,24 +434,122 @@ class Engine:
             e.transformation_description or "",
         ]
 
-        if direction_upstream:
-            for e in self._column_graph.get_upstream(target, req.max_depth):
-                rows.append(edge_row("upstream", e))
-        if direction_downstream:
-            for e in self._column_graph.get_downstream(target, req.max_depth):
-                rows.append(edge_row("downstream", e))
+        # Process all target columns
+        for target in targets:
+            if direction_upstream:
+                for e in self._column_graph.get_upstream(target, req.max_depth):
+                    rows.append(edge_row("upstream", e))
+            if direction_downstream:
+                for e in self._column_graph.get_downstream(target, req.max_depth):
+                    rows.append(edge_row("downstream", e))
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_rows = []
+        for row in rows:
+            row_tuple = tuple(row)
+            if row_tuple not in seen:
+                seen.add(row_tuple)
+                unique_rows.append(row)
+
+        if not unique_rows:
+            # Show info about the matched columns
+            if len(targets) == 1:
+                unique_rows = [[str(targets[0]), str(targets[0]), "info", "", "No relationships found"]]
+            else:
+                unique_rows = [[f"Matched {len(targets)} columns", "", "info", "", f"Pattern: {req.selector}"]]
 
         return {
             "columns": ["from", "to", "direction", "transformation", "description"],
-            "rows": rows or [[str(target), str(target), "info", "", "No relationships found"]],
+            "rows": unique_rows,
         }
 
 
-    # ------------------ DIFF (stub – keep your own version if you have one) ------------------
+    # ------------------ DIFF (updated implementation) ------------------
 
-    def run_diff(self, req: DiffRequest) -> Dict[str, Any]:
+    def run_diff(self, base_dir: Path, head_dir: Path, format: str, **kwargs) -> Dict[str, Any]:
         """
-        Placeholder: if you have a full comparison implementation, keep it.
-        Here we only return code 0 so the CLI isn't blocked.
+        Compare base and head OpenLineage artifacts to detect breaking changes.
+
+        Args:
+            base_dir: Directory containing base OpenLineage JSON artifacts
+            head_dir: Directory containing head OpenLineage JSON artifacts
+            format: Output format (text|json)
+
+        Returns:
+            Dict with results including exit_code (1 if breaking changes, 0 otherwise)
         """
-        return {"columns": ["message"], "rows": [["Diff not implemented in this stub"]], "exit_code": 0}
+        from .openlineage_utils import OpenLineageLoader, OLMapper
+        from .diff import BreakingChangeDetector, Severity
+
+        try:
+            # Load OpenLineage artifacts from both directories
+            base_artifacts = OpenLineageLoader.load_dir(base_dir)
+            head_artifacts = OpenLineageLoader.load_dir(head_dir)
+
+            # Convert to ObjectInfo instances
+            base_objects = OLMapper.to_object_infos(base_artifacts)
+            head_objects = OLMapper.to_object_infos(head_artifacts)
+
+            # Detect changes
+            detector = BreakingChangeDetector()
+            report = detector.compare(base_objects, head_objects)
+
+            # Filter changes based on severity threshold from config
+            threshold = self.config.severity_threshold.upper()
+            filtered_changes = []
+
+            if threshold == "BREAKING":
+                # Only show BREAKING changes
+                filtered_changes = [c for c in report.changes if c.severity == Severity.BREAKING]
+            elif threshold == "POTENTIALLY_BREAKING":
+                # Show BREAKING and POTENTIALLY_BREAKING changes
+                filtered_changes = [c for c in report.changes if c.severity in [Severity.BREAKING, Severity.POTENTIALLY_BREAKING]]
+            else:  # NON_BREAKING
+                # Show all changes
+                filtered_changes = report.changes
+
+            # Determine exit code based on threshold
+            exit_code = 0
+            if threshold == "BREAKING":
+                exit_code = 1 if any(c.severity == Severity.BREAKING for c in report.changes) else 0
+            elif threshold == "POTENTIALLY_BREAKING":
+                exit_code = 1 if any(c.severity in [Severity.BREAKING, Severity.POTENTIALLY_BREAKING] for c in report.changes) else 0
+            else:  # NON_BREAKING
+                exit_code = 1 if len(report.changes) > 0 else 0
+
+            # Build filtered report
+            if filtered_changes:
+                filtered_rows = []
+                for change in filtered_changes:
+                    filtered_rows.append([
+                        change.object_name,
+                        change.column_name or "",
+                        change.change_type.value,
+                        change.severity.value,
+                        change.description
+                    ])
+            else:
+                filtered_rows = []
+
+            return {
+                "columns": ["object", "column", "change_type", "severity", "description"],
+                "rows": filtered_rows,
+                "exit_code": exit_code,
+                "summary": {
+                    "total_changes": len(filtered_changes),
+                    "breaking_changes": len([c for c in filtered_changes if c.severity.value == "BREAKING"]),
+                    "potentially_breaking": len([c for c in filtered_changes if c.severity.value == "POTENTIALLY_BREAKING"]),
+                    "non_breaking": len([c for c in filtered_changes if c.severity.value == "NON_BREAKING"])
+                }
+            }
+
+        except Exception as e:
+            logger.error(f"Error running diff: {e}")
+            return {
+                "error": str(e),
+                "columns": ["message"],
+                "rows": [["Error running diff: " + str(e)]],
+                "exit_code": 1
+            }
+
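
A sketch of driving the new signature from calling code; the engine instance and directory layout are assumed for illustration:

    result = engine.run_diff(Path("artifacts/base"), Path("artifacts/head"), format="text")
    for row in result["rows"]:
        print("\t".join(row))
    raise SystemExit(result["exit_code"])  # non-zero once changes meet the configured threshold
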
infotracker/lineage.py CHANGED
@@ -63,14 +63,17 @@ class OpenLineageGenerator:
 
     def _build_outputs(self, obj_info: ObjectInfo) -> List[Dict[str, Any]]:
         """Build outputs array with schema and lineage facets."""
+        # Use schema's namespace if available, otherwise the default namespace
+        output_namespace = obj_info.schema.namespace if obj_info.schema.namespace else self.namespace
+
         output = {
-            "namespace": self.namespace,
+            "namespace": output_namespace,
             "name": obj_info.schema.name,
             "facets": {}
         }
 
-        # Add schema facet only for tables (not views)
-        if obj_info.object_type == "table" and obj_info.schema.columns:
+        # Add schema facet for all objects with known columns (tables, views, functions, procedures)
+        if obj_info.schema and obj_info.schema.columns:
             output["facets"]["schema"] = self._build_schema_facet(obj_info)
 
         # Add column lineage facet only if we have lineage (views, not tables)
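
With both changes, a view's outputs entry now carries a schema facet and keeps a parser-resolved namespace when one is set. The rough shape, with assumed values and a simplified facet body:

    output = {
        "namespace": "mssql://localhost/InfoTrackerDW",  # obj_info.schema.namespace when set
        "name": "dbo.vw_orders",
        "facets": {"schema": {"fields": [{"name": "order_id", "type": "int"}]}},
    }
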
infotracker/models.py CHANGED
@@ -15,6 +15,8 @@ class TransformationType(Enum):
     CASE = "CASE"
     AGGREGATE = "AGGREGATE"
     AGGREGATION = "AGGREGATION"
+    ARITHMETIC_AGGREGATION = "ARITHMETIC_AGGREGATION"
+    COMPLEX_AGGREGATION = "COMPLEX_AGGREGATION"
     EXPRESSION = "EXPRESSION"
     CONCAT = "CONCAT"
     ARITHMETIC = "ARITHMETIC"
@@ -23,6 +25,10 @@ class TransformationType(Enum):
     STRING_PARSE = "STRING_PARSE"
     WINDOW_FUNCTION = "WINDOW_FUNCTION"
     WINDOW = "WINDOW"
+    DATE_FUNCTION = "DATE_FUNCTION"
+    DATE_FUNCTION_AGGREGATION = "DATE_FUNCTION_AGGREGATION"
+    CASE_AGGREGATION = "CASE_AGGREGATION"
+    EXEC = "EXEC"
 
 
 @dataclass
@@ -161,7 +167,7 @@ class ColumnNode:
     def __str__(self) -> str:
         return f"{self.namespace}.{self.table_name}.{self.column_name}"
 
-    def __hash__(self) -> str:
+    def __hash__(self) -> int:
         return hash((self.namespace.lower(), self.table_name.lower(), self.column_name.lower()))
 
     def __eq__(self, other) -> bool:
@@ -184,10 +190,18 @@ class ColumnEdge:
 class ColumnGraph:
     """Bidirectional graph of column-level lineage relationships."""
 
-    def __init__(self):
+    def __init__(self, max_upstream_depth: int = 10, max_downstream_depth: int = 10):
+        """Initialize the column graph with configurable depth limits.
+
+        Args:
+            max_upstream_depth: Maximum depth for upstream traversal (default: 10)
+            max_downstream_depth: Maximum depth for downstream traversal (default: 10)
+        """
         self._nodes: Dict[str, ColumnNode] = {}
         self._upstream_edges: Dict[str, List[ColumnEdge]] = {}  # node -> edges coming into it
         self._downstream_edges: Dict[str, List[ColumnEdge]] = {}  # node -> edges going out of it
+        self.max_upstream_depth = max_upstream_depth
+        self.max_downstream_depth = max_downstream_depth
 
     def add_node(self, column_node: ColumnNode) -> None:
         """Add a column node to the graph."""
@@ -212,16 +226,28 @@ class ColumnGraph:
         self._upstream_edges[to_key].append(edge)
 
     def get_upstream(self, column: ColumnNode, max_depth: Optional[int] = None) -> List[ColumnEdge]:
-        """Get all upstream dependencies for a column."""
-        return self._traverse_upstream(column, max_depth or 10, set())
+        """Get all upstream dependencies for a column.
+
+        Args:
+            column: The column to find upstream dependencies for
+            max_depth: Override the default max_upstream_depth for this query
+        """
+        effective_depth = max_depth if max_depth is not None else self.max_upstream_depth
+        return self._traverse_upstream(column, effective_depth, set())
 
     def get_downstream(self, column: ColumnNode, max_depth: Optional[int] = None) -> List[ColumnEdge]:
-        """Get all downstream dependencies for a column."""
-        return self._traverse_downstream(column, max_depth or 10, set())
+        """Get all downstream dependencies for a column.
+
+        Args:
+            column: The column to find downstream dependencies for
+            max_depth: Override the default max_downstream_depth for this query
+        """
+        effective_depth = max_depth if max_depth is not None else self.max_downstream_depth
+        return self._traverse_downstream(column, effective_depth, set())
 
-    def _traverse_upstream(self, column: ColumnNode, max_depth: int, visited: Set[str]) -> List[ColumnEdge]:
+    def _traverse_upstream(self, column: ColumnNode, max_depth: int, visited: Set[str], current_depth: int = 0) -> List[ColumnEdge]:
         """Recursively traverse upstream dependencies."""
-        if max_depth <= 0:
+        if max_depth <= 0 or current_depth >= max_depth:
             return []
 
         column_key = str(column).lower()
@@ -235,14 +261,14 @@ class ColumnGraph:
         for edge in self._upstream_edges.get(column_key, []):
             edges.append(edge)
             # Recursively get upstream of the source column
-            upstream_edges = self._traverse_upstream(edge.from_column, max_depth - 1, visited.copy())
+            upstream_edges = self._traverse_upstream(edge.from_column, max_depth, visited.copy(), current_depth + 1)
             edges.extend(upstream_edges)
 
         return edges
 
-    def _traverse_downstream(self, column: ColumnNode, max_depth: int, visited: Set[str]) -> List[ColumnEdge]:
+    def _traverse_downstream(self, column: ColumnNode, max_depth: int, visited: Set[str], current_depth: int = 0) -> List[ColumnEdge]:
         """Recursively traverse downstream dependencies."""
-        if max_depth <= 0:
+        if max_depth <= 0 or current_depth >= max_depth:
             return []
 
         column_key = str(column).lower()
@@ -256,11 +282,30 @@ class ColumnGraph:
         for edge in self._downstream_edges.get(column_key, []):
             edges.append(edge)
             # Recursively get downstream of the target column
-            downstream_edges = self._traverse_downstream(edge.to_column, max_depth - 1, visited.copy())
+            downstream_edges = self._traverse_downstream(edge.to_column, max_depth, visited.copy(), current_depth + 1)
             edges.extend(downstream_edges)
 
         return edges
 
+    def get_traversal_stats(self, column: ColumnNode) -> Dict[str, Any]:
+        """Get traversal statistics for a column, including depth information.
+
+        Returns:
+            Dictionary with upstream/downstream counts and depth information
+        """
+        upstream_edges = self.get_upstream(column)
+        downstream_edges = self.get_downstream(column)
+
+        return {
+            "column": str(column),
+            "upstream_count": len(upstream_edges),
+            "downstream_count": len(downstream_edges),
+            "max_upstream_depth": self.max_upstream_depth,
+            "max_downstream_depth": self.max_downstream_depth,
+            "upstream_tables": len(set(str(edge.from_column).rsplit('.', 1)[0] for edge in upstream_edges)),
+            "downstream_tables": len(set(str(edge.to_column).rsplit('.', 1)[0] for edge in downstream_edges))
+        }
+
     def build_from_object_lineage(self, objects: List[ObjectInfo]) -> None:
         """Build column graph from object lineage information."""
         for obj in objects:
@@ -297,6 +342,52 @@ class ColumnGraph:
         selector_key = selector.lower()
         return self._nodes.get(selector_key)
 
-    def get_all_nodes(self) -> List[ColumnNode]:
-        """Get all column nodes in the graph."""
-        return list(self._nodes.values())
+
+    def find_columns_wildcard(self, selector: str) -> List[ColumnNode]:
+        """
+        Find columns matching a wildcard pattern.
+
+        Supports:
+        - Table wildcard: <ns>.<schema>.<table>.* → all columns of that table
+        - Column wildcard: <optional_ns>..<pattern> → match by COLUMN NAME only:
+          * if pattern contains any of [*?[]] → fnmatch on the column name
+          * otherwise → default to case-insensitive "contains"
+        - Fallback: fnmatch on the full identifier "ns.schema.table.column"
+        """
+        import fnmatch as _fn
+
+        sel = (selector or "").strip().lower()
+
+        # 1) Table wildcard: "...schema.table.*"
+        if sel.endswith(".*"):
+            table_sel = sel[:-1]  # remove trailing '*', keep the final dot
+            # simple prefix match on the full key
+            return [node for key, node in self._nodes.items() if key.startswith(table_sel)]
+
+        # 2) Column wildcard: "<optional_ns>..<pattern>"
+        if ".." in sel:
+            ns_part, col_pat = sel.split("..", 1)
+            ns_part = ns_part.strip(".")
+            col_pat = col_pat.strip()
+
+            # if no explicit wildcard metacharacter, treat as "contains"
+            has_meta = any(ch in col_pat for ch in "*?[]")
+
+            def col_name_matches(name: str) -> bool:
+                name = (name or "").lower()
+                if has_meta:
+                    return _fn.fnmatch(name, col_pat)
+                return col_pat in name  # default: contains (case-insensitive)
+
+            if ns_part:
+                ns_prefix = ns_part + "."
+                return [
+                    node
+                    for key, node in self._nodes.items()
+                    if key.startswith(ns_prefix) and col_name_matches(getattr(node, "column_name", ""))
+                ]
+            else:
+                return [node for node in self._nodes.values() if col_name_matches(getattr(node, "column_name", ""))]
+
+        # 3) Fallback: fnmatch on the full identifier
+        return [node for key, node in self._nodes.items() if _fn.fnmatch(key, sel)]
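
A short sketch of the three matching modes on a populated graph (keys follow the ns.schema.table.column form used above; names illustrative):

    graph.find_columns_wildcard("mssql://localhost/InfoTrackerDW.dbo.orders.*")  # table wildcard: every column of dbo.orders
    graph.find_columns_wildcard("..customer")   # column wildcard, no metacharacters: name contains "customer"
    graph.find_columns_wildcard("..order_??")   # column wildcard with metacharacters: fnmatch on the name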