InfoTracker 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
infotracker/parser.py CHANGED
@@ -24,7 +24,10 @@ class SqlParser:
24
24
  def __init__(self, dialect: str = "tsql"):
25
25
  self.dialect = dialect
26
26
  self.schema_registry = SchemaRegistry()
27
+ self.cte_registry: Dict[str, List[str]] = {} # CTE name -> column list
28
+ self.temp_registry: Dict[str, List[str]] = {} # Temp table name -> column list
27
29
  self.default_database: Optional[str] = None # Will be set from config
30
+ self.current_database: Optional[str] = None # Track current database context
28
31
 
29
32
  def _clean_proc_name(self, s: str) -> str:
30
33
  """Clean procedure name by removing semicolons and parameters."""
@@ -32,18 +35,296 @@ class SqlParser:
32
35
 
33
36
  def _normalize_table_ident(self, s: str) -> str:
34
37
  """Remove brackets and normalize table identifier."""
35
- import re
36
- return re.sub(r'[\[\]]', '', s)
38
+ # Remove brackets, trailing semicolons and whitespace
39
+ normalized = re.sub(r'[\[\]]', '', s)
40
+ return normalized.strip().rstrip(';')
41
+
42
+ def _normalize_tsql(self, text: str) -> str:
43
+ """Normalize T-SQL to improve sqlglot parsing compatibility."""
44
+ t = text.replace("\r\n", "\n")
45
+
46
+ # Strip technical banners and settings
47
+ t = re.sub(r"^\s*SET\s+(ANSI_NULLS|QUOTED_IDENTIFIER)\s+(ON|OFF)\s*;?\s*$", "", t, flags=re.I|re.M)
48
+ t = re.sub(r"^\s*GO\s*;?\s*$", "", t, flags=re.I|re.M)
49
+
50
+ # Remove column-level COLLATE clauses (keeps DDL simple)
51
+ t = re.sub(r"\s+COLLATE\s+[A-Za-z0-9_]+", "", t, flags=re.I)
52
+
53
+ # Normalize T-SQL specific functions to standard equivalents
54
+ t = re.sub(r"\bISNULL\s*\(", "COALESCE(", t, flags=re.I)
55
+
56
+ # Convert IIF to CASE WHEN (basic conversion for simple cases)
57
+ t = re.sub(r"\bIIF\s*\(", "CASE WHEN ", t, flags=re.I)
58
+
59
+ return t
60
+
61
+ def _rewrite_ast(self, root: exp.Expression) -> exp.Expression:
62
+ """Rewrite AST nodes for better T-SQL compatibility."""
63
+ for node in list(root.walk()):
64
+ # Convert CONVERT(T, x [, style]) to CAST(x AS T)
65
+ if isinstance(node, exp.Convert):
66
+ target_type = node.args.get("to")
67
+ source_expr = node.args.get("expression")
68
+ if target_type and source_expr:
69
+ cast_node = exp.Cast(this=source_expr, to=target_type)
70
+ node.replace(cast_node)
71
+
72
+ # Mark HASHBYTES(...) nodes for special handling
73
+ if isinstance(node, exp.Anonymous) and (node.name or "").upper() == "HASHBYTES":
74
+ node.set("is_hashbytes", True)
75
+
76
+ return root
77
+
78
+ def _split_fqn(self, fqn: str):
79
+ """Split fully qualified name into database, schema, table components."""
80
+ parts = (fqn or "").split(".")
81
+ if len(parts) >= 3:
82
+ return parts[0], parts[1], ".".join(parts[2:])
83
+ if len(parts) == 2:
84
+ return (self.current_database or self.default_database), parts[0], parts[1]
85
+ return (self.current_database or self.default_database), "dbo", (parts[0] if parts else None)
86
+
87
+ def _qualify_table(self, tbl: exp.Table) -> str:
88
+ """Get fully qualified table name from Table expression."""
89
+ name = tbl.name
90
+ sch = getattr(tbl, "db", None) or "dbo"
91
+ db = getattr(tbl, "catalog", None) or self.current_database or self.default_database
92
+ return ".".join([p for p in [db, sch, name] if p])
93
+
94
+ def _build_alias_maps(self, select_exp: exp.Select):
95
+ """Build maps for table aliases and derived table columns."""
96
+ alias_map = {} # alias_lower -> DB.sch.tbl
97
+ derived_cols = {} # (alias_lower, out_col_lower) -> list[exp.Column] (base cols of subquery projection)
98
+
99
+ # Plain tables
100
+ for t in select_exp.find_all(exp.Table):
101
+ a = getattr(t, "alias", None) or t.args.get("alias")
102
+ alias = None
103
+ if a:
104
+ # Handle both string aliases and alias objects
105
+ if hasattr(a, "name"):
106
+ alias = a.name.lower()
107
+ else:
108
+ alias = str(a).lower()
109
+ fqn = self._qualify_table(t)
110
+ if alias:
111
+ alias_map[alias] = fqn
112
+ alias_map[t.name.lower()] = fqn
113
+
114
+ # Derived tables (subqueries with alias)
115
+ for sq in select_exp.find_all(exp.Subquery):
116
+ a = getattr(sq, "alias", None) or sq.args.get("alias")
117
+ if not a:
118
+ continue
119
+ # Handle both string aliases and alias objects
120
+ if hasattr(a, "name"):
121
+ alias = a.name.lower()
122
+ else:
123
+ alias = str(a).lower()
124
+ inner = sq.this if isinstance(sq.this, exp.Select) else None
125
+ if not inner:
126
+ continue
127
+ idx = 0
128
+ for proj in (inner.expressions or []):
129
+ if isinstance(proj, exp.Alias):
130
+ out_name = (proj.alias or proj.alias_or_name)
131
+ target = proj.this
132
+ else:
133
+ out_name = f"col_{idx+1}"
134
+ target = proj
135
+ key = (alias, (out_name or "").lower())
136
+ derived_cols[key] = list(target.find_all(exp.Column))
137
+ idx += 1
138
+
139
+ return alias_map, derived_cols
140
+
141
+ def _append_column_ref(self, out_list, col_exp: exp.Column, alias_map: dict):
142
+ """Append a column reference to the output list after resolving aliases."""
143
+ qual = (col_exp.table or "").lower()
144
+ table_fqn = alias_map.get(qual)
145
+ if not table_fqn:
146
+ return
147
+ db, sch, tbl = self._split_fqn(table_fqn)
148
+ out_list.append(ColumnReference(
149
+ namespace=f"mssql://localhost/{db}" if db else "mssql://localhost",
150
+ table_name=table_fqn, # Use full qualified name for consistency
151
+ column_name=col_exp.name
152
+ ))
153
+
154
+ def _collect_inputs_for_expr(self, expr: exp.Expression, alias_map: dict, derived_cols: dict):
155
+ """Collect input column references for an expression, resolving derived table aliases."""
156
+ inputs = []
157
+ for col in expr.find_all(exp.Column):
158
+ qual = (col.table or "").lower()
159
+ key = (qual, col.name.lower())
160
+ base_cols = derived_cols.get(key)
161
+ if base_cols:
162
+ # This column comes from a derived table - use its base columns
163
+ for b in base_cols:
164
+ self._append_column_ref(inputs, b, alias_map)
165
+ continue
166
+ # Regular table column
167
+ self._append_column_ref(inputs, col, alias_map)
168
+ return inputs
169
+
170
+ def _get_schema(self, db: str, sch: str, tbl: str):
171
+ """Get schema information for a table."""
172
+ ns = f"mssql://localhost/{db}" if db else None
173
+ key = f"{sch}.{tbl}"
174
+ if hasattr(self.schema_registry, "get"):
175
+ return self.schema_registry.get(ns, key)
176
+ # Fallback for different registry implementations
177
+ return self.schema_registry.get((ns, key))
178
+
179
+ def _type_of_column(self, col_exp, alias_map):
180
+ """Get the data type of a column from schema registry."""
181
+ qual = (getattr(col_exp, "table", None) or "").lower()
182
+ fqn = alias_map.get(qual)
183
+ if not fqn:
184
+ return None
185
+ db, sch, tbl = self._split_fqn(fqn)
186
+ schema = self._get_schema(db, sch, tbl)
187
+ if not schema:
188
+ return None
189
+ c = schema.get_column(col_exp.name)
190
+ return c.data_type if c else None
191
+
192
+ def _infer_type(self, expr, alias_map) -> str:
193
+ """Infer data type for an expression."""
194
+ if isinstance(expr, exp.Cast):
195
+ t = expr.args.get("to")
196
+ return str(t) if t else "unknown"
197
+ if isinstance(expr, exp.Convert):
198
+ t = expr.args.get("to")
199
+ return str(t) if t else "unknown"
200
+ if isinstance(expr, (exp.Trim, exp.Upper, exp.Lower)):
201
+ base = expr.find(exp.Column)
202
+ return self._type_of_column(base, alias_map) or "nvarchar"
203
+ if isinstance(expr, exp.Coalesce):
204
+ types = []
205
+ for a in (expr.args.get("expressions") or []):
206
+ if isinstance(a, exp.Column):
207
+ types.append(self._type_of_column(a, alias_map))
208
+ elif isinstance(a, exp.Literal):
209
+ types.append("nvarchar" if a.is_string else "numeric")
210
+ tset = [t for t in types if t]
211
+ if any(t and "nvarchar" in t.lower() for t in tset):
212
+ return "nvarchar"
213
+ if any(t and "varchar" in t.lower() for t in tset):
214
+ return "varchar"
215
+ return tset[0] if tset else "unknown"
216
+ s = str(expr).upper()
217
+ if "HASHBYTES(" in s or "MD5(" in s:
218
+ return "binary(16)"
219
+ if isinstance(expr, exp.Column):
220
+ return self._type_of_column(expr, alias_map) or "unknown"
221
+ return "unknown"
222
+
223
+ def _short_desc(self, expr) -> str:
224
+ """Generate a short transformation description."""
225
+ return " ".join(str(expr).split())[:250]
226
+
227
+ def _extract_view_header_cols(self, create_exp) -> list[str]:
228
+ """Extract column names from CREATE VIEW (col1, col2, ...) AS pattern."""
229
+ exprs = getattr(create_exp, "expressions", None) or create_exp.args.get("expressions") or []
230
+ cols = []
231
+ for e in exprs:
232
+ n = getattr(e, "name", None)
233
+ if n:
234
+ cols.append(str(n).strip("[]"))
235
+ else:
236
+ cols.append(str(e).strip().strip("[]"))
237
+ return cols
238
+
239
+ def _apply_view_header_names(self, create_exp, select_exp, obj: ObjectInfo):
240
+ """Apply header column names to view schema and lineage by position."""
241
+ header = self._extract_view_header_cols(create_exp)
242
+ if not header:
243
+ return
244
+ projs = list(select_exp.expressions or [])
245
+ for i, _ in enumerate(projs):
246
+ out_name = header[i] if i < len(header) else f"col_{i+1}"
247
+ # Update schema
248
+ if i < len(obj.schema.columns):
249
+ obj.schema.columns[i].name = out_name
250
+ obj.schema.columns[i].ordinal = i
251
+ else:
252
+ obj.schema.columns.append(ColumnSchema(
253
+ name=out_name,
254
+ data_type="unknown",
255
+ nullable=True,
256
+ ordinal=i
257
+ ))
258
+ # Update lineage
259
+ if i < len(obj.lineage):
260
+ obj.lineage[i].output_column = out_name
261
+ else:
262
+ obj.lineage.append(ColumnLineage(
263
+ output_column=out_name,
264
+ input_fields=[],
265
+ transformation_type=TransformationType.EXPRESSION,
266
+ transformation_description=""
267
+ ))
37
268
 
38
269
  def set_default_database(self, default_database: Optional[str]):
39
270
  """Set the default database for qualification."""
40
271
  self.default_database = default_database
41
272
 
273
+ def _extract_database_from_use_statement(self, content: str) -> Optional[str]:
274
+ """Extract database name from USE statement at the beginning of file."""
275
+ lines = content.strip().split('\n')
276
+ for line in lines[:10]: # Check first 10 lines
277
+ line = line.strip()
278
+ if not line or line.startswith('--'):
279
+ continue
280
+
281
+ # Match USE :DBNAME: or USE [database] or USE database
282
+ use_match = re.match(r'USE\s+(?::([^:]+):|(?:\[([^\]]+)\]|(\w+)))', line, re.IGNORECASE)
283
+ if use_match:
284
+ db_name = use_match.group(1) or use_match.group(2) or use_match.group(3)
285
+ logger.debug(f"Found USE statement, setting database to: {db_name}")
286
+ return db_name
287
+
288
+ # If we hit a non-comment, non-USE statement, stop looking
289
+ if not line.startswith(('USE', 'DECLARE', 'SET', 'PRINT')):
290
+ break
291
+
292
+ return None
293
+
294
+ def _get_full_table_name(self, table_name: str) -> str:
295
+ """Get full table name with database prefix using current or default database."""
296
+ # Use current database from USE statement or fall back to default
297
+ db_to_use = self.current_database or self.default_database or "InfoTrackerDW"
298
+
299
+ if '.' not in table_name:
300
+ # Just table name - use database and default schema
301
+ return f"{db_to_use}.dbo.{table_name}"
302
+
303
+ parts = table_name.split('.')
304
+ if len(parts) == 2:
305
+ # schema.table - add database
306
+ return f"{db_to_use}.{table_name}"
307
+ elif len(parts) == 3:
308
+ # database.schema.table - use as is
309
+ return table_name
310
+ else:
311
+ # Fallback
312
+ return f"{db_to_use}.dbo.{table_name}"
313
+
42
314
  def _preprocess_sql(self, sql: str) -> str:
43
315
  """
44
316
  Preprocess SQL to remove control lines and join INSERT INTO #temp EXEC patterns.
317
+ Also extracts database context from USE statements.
45
318
  """
46
- import re
319
+
320
+
321
+ # Extract database from USE statement first
322
+ db_from_use = self._extract_database_from_use_statement(sql)
323
+ if db_from_use:
324
+ self.current_database = db_from_use
325
+ else:
326
+ # Ensure current_database is set to default if no USE statement found
327
+ self.current_database = self.default_database
47
328
 
48
329
  lines = sql.split('\n')
49
330
  processed_lines = []
@@ -56,14 +337,20 @@ class SqlParser:
56
337
  continue
57
338
 
58
339
  # Skip IF OBJECT_ID('tempdb..#...') patterns and DROP TABLE #temp patterns
340
+ # Also skip complete IF OBJECT_ID ... DROP TABLE sequences
59
341
  if (re.match(r"(?i)^IF\s+OBJECT_ID\('tempdb\.\.#", stripped_line) or
60
- re.match(r'(?i)^DROP\s+TABLE\s+#\w+', stripped_line)):
342
+ re.match(r'(?i)^DROP\s+TABLE\s+#\w+', stripped_line) or
343
+ re.match(r'(?i)^IF\s+OBJECT_ID.*IS\s+NOT\s+NULL\s+DROP\s+TABLE', stripped_line)):
61
344
  continue
62
345
 
63
346
  # Skip GO statements (SQL Server batch separator)
64
347
  if re.match(r'(?im)^\s*GO\s*$', stripped_line):
65
348
  continue
66
349
 
350
+ # Skip USE <db> lines (we already extracted DB context)
351
+ if re.match(r'(?i)^\s*USE\b', stripped_line):
352
+ continue
353
+
67
354
  processed_lines.append(line)
68
355
 
69
356
  # Join the lines back together
@@ -76,53 +363,129 @@ class SqlParser:
76
363
  processed_sql
77
364
  )
78
365
 
366
+ # Cut to first significant statement
367
+ processed_sql = self._cut_to_first_statement(processed_sql)
368
+
79
369
  return processed_sql
80
370
 
371
+ def _cut_to_first_statement(self, sql: str) -> str:
372
+ """
373
+ Cut SQL content to start from the first significant statement.
374
+ Looks for: CREATE [OR ALTER] VIEW|TABLE|FUNCTION|PROCEDURE, ALTER, SELECT...INTO, INSERT...EXEC
375
+ """
376
+
377
+
378
+ pattern = re.compile(
379
+ r'(?is)' # DOTALL + IGNORECASE
380
+ r'(?:'
381
+ r'CREATE\s+(?:OR\s+ALTER\s+)?(?:VIEW|TABLE|FUNCTION|PROCEDURE)\b'
382
+ r'|ALTER\s+(?:VIEW|TABLE|FUNCTION|PROCEDURE)\b'
383
+ r'|SELECT\b.*?\bINTO\b' # SELECT ... INTO (może być w wielu liniach)
384
+ r'|INSERT\s+INTO\b.*?\bEXEC\b'
385
+ r')'
386
+ )
387
+ m = pattern.search(sql)
388
+ return sql[m.start():] if m else sql
389
+
81
390
  def _try_insert_exec_fallback(self, sql_content: str, object_hint: Optional[str] = None) -> Optional[ObjectInfo]:
82
391
  """
83
- Fallback parser for INSERT INTO ... EXEC pattern when SQLGlot fails.
84
- Handles both temp tables and regular tables.
392
+ Enhanced fallback parser for complex SQL files when SQLGlot fails.
393
+ Handles INSERT INTO ... EXEC pattern plus additional dependency extraction.
394
+ Also handles INSERT INTO persistent tables.
85
395
  """
86
- import re
396
+ from .openlineage_utils import sanitize_name
87
397
 
88
398
  # Get preprocessed SQL
89
399
  sql_pre = self._preprocess_sql(sql_content)
90
400
 
91
401
  # Look for INSERT INTO ... EXEC pattern (both temp and regular tables)
92
- pattern = r'(?is)INSERT\s+INTO\s+([#\[\]\w.]+)\s+EXEC\s+([^\s(;]+)'
93
- match = re.search(pattern, sql_pre)
402
+ insert_exec_pattern = r'(?is)INSERT\s+INTO\s+([#\[\]\w.]+)\s+EXEC\s+([^\s(;]+)'
403
+ exec_match = re.search(insert_exec_pattern, sql_pre)
94
404
 
95
- if not match:
96
- return None
97
-
98
- raw_table = match.group(1)
99
- raw_proc = match.group(2)
100
-
101
- # Clean and normalize names
102
- table_name = self._normalize_table_ident(raw_table)
103
- proc_name = self._clean_proc_name(raw_proc)
405
+ # Look for INSERT INTO persistent tables (not temp tables)
406
+ insert_table_pattern = r'(?is)INSERT\s+INTO\s+([^\s#][#\[\]\w.]+)\s*\(([^)]+)\)\s+SELECT'
407
+ table_match = re.search(insert_table_pattern, sql_pre)
104
408
 
105
- # Determine if it's a temp table
106
- is_temp = table_name.startswith('#')
107
- namespace = "tempdb" if is_temp else "mssql://localhost/InfoTrackerDW"
108
- object_type = "temp_table" if is_temp else "table"
409
+ # Always extract all dependencies from the file
410
+ all_dependencies = self._extract_basic_dependencies(sql_pre)
109
411
 
110
- # Create placeholder columns
412
+ # Default placeholder columns
111
413
  placeholder_columns = [
112
414
  ColumnSchema(
113
415
  name="output_col_1",
114
416
  data_type="unknown",
115
417
  nullable=True,
116
418
  ordinal=0
117
- ),
118
- ColumnSchema(
119
- name="output_col_2",
120
- data_type="unknown",
121
- nullable=True,
122
- ordinal=1
123
419
  )
124
420
  ]
125
421
 
422
+ # Prioritize persistent table INSERT over INSERT EXEC
423
+ if table_match and not table_match.group(1).startswith('#'):
424
+ # Found INSERT INTO persistent table with explicit column list
425
+ raw_table = table_match.group(1)
426
+ raw_columns = table_match.group(2)
427
+
428
+ table_name = self._normalize_table_ident(raw_table)
429
+ # For output tables, use simple schema.table format without database prefix
430
+ if '.' not in table_name:
431
+ table_name = f"dbo.{table_name}"
432
+ elif len(table_name.split('.')) == 3:
433
+ # Remove database prefix for output tables
434
+ parts = table_name.split('.')
435
+ table_name = f"{parts[1]}.{parts[2]}"
436
+ namespace = "mssql://localhost/InfoTrackerDW"
437
+ object_type = "table"
438
+
439
+ # Parse column list from INSERT INTO
440
+ column_names = [col.strip() for col in raw_columns.split(',')]
441
+ placeholder_columns = []
442
+ for i, col_name in enumerate(column_names):
443
+ placeholder_columns.append(ColumnSchema(
444
+ name=col_name,
445
+ data_type="unknown",
446
+ nullable=True,
447
+ ordinal=i
448
+ ))
449
+
450
+ elif exec_match:
451
+ # Found INSERT INTO ... EXEC - use that as pattern
452
+ raw_table = exec_match.group(1)
453
+ raw_proc = exec_match.group(2)
454
+
455
+ # Clean and normalize names
456
+ table_name = self._normalize_table_ident(raw_table)
457
+ proc_name = self._clean_proc_name(raw_proc)
458
+
459
+ # Apply consistent temp table namespace handling
460
+ if table_name.startswith('#'):
461
+ # Temp table - use consistent naming and namespace
462
+ temp_name = table_name.lstrip('#')
463
+ table_name = f"tempdb..#{temp_name}"
464
+ namespace = "mssql://localhost/tempdb"
465
+ object_type = "temp_table"
466
+ else:
467
+ # Regular table - use full qualified name with database context
468
+ table_name = self._get_full_table_name(table_name)
469
+ namespace = "mssql://localhost/InfoTrackerDW"
470
+ object_type = "table"
471
+
472
+ # Get full procedure name for dependencies and lineage
473
+ proc_full_name = self._get_full_table_name(proc_name)
474
+ proc_full_name = sanitize_name(proc_full_name)
475
+
476
+ # Add the procedure to dependencies
477
+ all_dependencies.add(proc_full_name)
478
+
479
+ else:
480
+ # No INSERT pattern found - create a generic script object
481
+ if all_dependencies:
482
+ table_name = sanitize_name(object_hint or "script_output")
483
+ namespace = "mssql://localhost/InfoTrackerDW"
484
+ object_type = "script"
485
+ else:
486
+ # No dependencies found at all
487
+ return None
488
+
126
489
  # Create schema
127
490
  schema = TableSchema(
128
491
  namespace=namespace,
@@ -130,43 +493,88 @@ class SqlParser:
130
493
  columns=placeholder_columns
131
494
  )
132
495
 
133
- # Create lineage for each placeholder column
496
+ # Create lineage using all dependencies
134
497
  lineage = []
135
- for col in placeholder_columns:
136
- lineage.append(ColumnLineage(
137
- output_column=col.name,
138
- input_fields=[
139
- ColumnReference(
140
- namespace="mssql://localhost/InfoTrackerDW",
141
- table_name=proc_name, # Clean procedure name without semicolons
142
- column_name="*"
143
- )
144
- ],
145
- transformation_type=TransformationType.EXEC,
146
- transformation_description=f"INSERT INTO {table_name} EXEC {proc_name}"
147
- ))
148
-
149
- # Set dependencies to the clean procedure name
150
- dependencies = {proc_name}
498
+ if table_match and not table_match.group(1).startswith('#') and placeholder_columns:
499
+ # For INSERT INTO table with columns, create intelligent lineage mapping
500
+ # Look for EXEC pattern in the same file to map columns to procedure output
501
+ proc_pattern = r'(?is)INSERT\s+INTO\s+#\w+\s+EXEC\s+([^\s(;]+)'
502
+ proc_match = re.search(proc_pattern, sql_pre)
503
+
504
+ if proc_match:
505
+ proc_name = self._clean_proc_name(proc_match.group(1))
506
+ proc_full_name = self._get_full_table_name(proc_name)
507
+ proc_full_name = sanitize_name(proc_full_name)
508
+
509
+ for i, col in enumerate(placeholder_columns):
510
+ if col.name.lower() in ['archivedate', 'createdate', 'insertdate'] and 'getdate' in sql_pre.lower():
511
+ # CONSTANT for date columns that use GETDATE()
512
+ lineage.append(ColumnLineage(
513
+ output_column=col.name,
514
+ input_fields=[],
515
+ transformation_type=TransformationType.CONSTANT,
516
+ transformation_description=f"GETDATE() constant value for archiving"
517
+ ))
518
+ else:
519
+ # IDENTITY mapping from procedure output
520
+ lineage.append(ColumnLineage(
521
+ output_column=col.name,
522
+ input_fields=[
523
+ ColumnReference(
524
+ namespace="mssql://localhost/InfoTrackerDW",
525
+ table_name=proc_full_name,
526
+ column_name=col.name
527
+ )
528
+ ],
529
+ transformation_type=TransformationType.IDENTITY,
530
+ transformation_description=f"{col.name} from procedure output via temp table"
531
+ ))
532
+ else:
533
+ # Fallback to generic mapping
534
+ for col in placeholder_columns:
535
+ lineage.append(ColumnLineage(
536
+ output_column=col.name,
537
+ input_fields=[],
538
+ transformation_type=TransformationType.UNKNOWN,
539
+ transformation_description=f"Column {col.name} from complex transformation"
540
+ ))
541
+ elif exec_match:
542
+ # For INSERT EXEC, create specific lineage
543
+ proc_full_name = self._get_full_table_name(self._clean_proc_name(exec_match.group(2)))
544
+ proc_full_name = sanitize_name(proc_full_name)
545
+ for col in placeholder_columns:
546
+ lineage.append(ColumnLineage(
547
+ output_column=col.name,
548
+ input_fields=[
549
+ ColumnReference(
550
+ namespace="mssql://localhost/InfoTrackerDW",
551
+ table_name=proc_full_name,
552
+ column_name="*"
553
+ )
554
+ ],
555
+ transformation_type=TransformationType.EXEC,
556
+ transformation_description=f"INSERT INTO {table_name} EXEC {proc_full_name}"
557
+ ))
151
558
 
152
559
  # Register schema in registry
153
560
  self.schema_registry.register(schema)
154
561
 
155
- # Create and return ObjectInfo with table_name as name (not object_hint)
562
+ # Create and return ObjectInfo with enhanced dependencies
156
563
  return ObjectInfo(
157
564
  name=table_name,
158
565
  object_type=object_type,
159
566
  schema=schema,
160
567
  lineage=lineage,
161
- dependencies=dependencies
568
+ dependencies=all_dependencies, # Use all extracted dependencies
569
+ is_fallback=True
162
570
  )
163
571
 
164
572
  def _find_last_select_string(self, sql_content: str, dialect: str = "tsql") -> str | None:
165
573
  """Find the last SELECT statement in SQL content using SQLGlot AST."""
166
- import sqlglot
167
- from sqlglot import exp
168
574
  try:
169
- parsed = sqlglot.parse(sql_content, read=dialect)
575
+ normalized = self._normalize_tsql(sql_content)
576
+ preprocessed = self._preprocess_sql(normalized)
577
+ parsed = sqlglot.parse(preprocessed, read=self.dialect)
170
578
  selects = []
171
579
  for stmt in parsed:
172
580
  selects.extend(list(stmt.find_all(exp.Select)))
@@ -174,37 +582,241 @@ class SqlParser:
174
582
  return None
175
583
  return str(selects[-1])
176
584
  except Exception:
177
- return None
585
+ # Fallback to string-based SELECT extraction for procedures
586
+ return self._find_last_select_string_fallback(sql_content)
587
+
588
+ def _find_last_select_string_fallback(self, sql_content: str) -> str | None:
589
+ """Fallback method to find last SELECT using string parsing."""
590
+ try:
591
+ # For procedures, find the last SELECT statement that goes to the end of the procedure
592
+ # Look for the last occurrence of SELECT and take everything until END
593
+
594
+ # First, find all SELECT positions
595
+ select_positions = []
596
+ for match in re.finditer(r'\bSELECT\b', sql_content, re.IGNORECASE):
597
+ select_positions.append(match.start())
598
+
599
+ if not select_positions:
600
+ return None
601
+
602
+ # Take the last SELECT position
603
+ last_select_pos = select_positions[-1]
604
+
605
+ # Get everything from the last SELECT to the end, but stop at END
606
+ remaining_content = sql_content[last_select_pos:]
607
+
608
+ # Find the procedure END (but not CASE END)
609
+ # Look for END at the start of a line or END followed by semicolon
610
+ end_pattern = r'(?i)(?:^|\n)\s*END\s*(?:;|\s*$)'
611
+ end_match = re.search(end_pattern, remaining_content)
612
+
613
+ if end_match:
614
+ last_select = remaining_content[:end_match.start()].strip()
615
+ else:
616
+ last_select = remaining_content.strip()
617
+
618
+ # Clean up any trailing semicolons
619
+ last_select = re.sub(r';\s*$', '', last_select)
620
+
621
+ return last_select
622
+
623
+ except Exception as e:
624
+ logger.debug(f"Fallback SELECT extraction failed: {e}")
625
+
626
+ return None
178
627
 
179
628
  def parse_sql_file(self, sql_content: str, object_hint: Optional[str] = None) -> ObjectInfo:
180
629
  """Parse a SQL file and extract object information."""
630
+ from .openlineage_utils import sanitize_name
631
+
632
+ # Reset current database to default for each file
633
+ self.current_database = self.default_database
634
+
635
+ # Reset registries for each file to avoid contamination
636
+ self.cte_registry.clear()
637
+ self.temp_registry.clear()
638
+
181
639
  try:
182
- # First check if this is a function or procedure using string matching
640
+ # Check if this file contains multiple objects and handle accordingly
183
641
  sql_upper = sql_content.upper()
184
- if "CREATE FUNCTION" in sql_upper or "CREATE OR ALTER FUNCTION" in sql_upper:
642
+
643
+ # Count how many CREATE statements we have
644
+ create_function_count = sql_upper.count('CREATE FUNCTION') + sql_upper.count('CREATE OR ALTER FUNCTION')
645
+ create_procedure_count = sql_upper.count('CREATE PROCEDURE') + sql_upper.count('CREATE OR ALTER PROCEDURE')
646
+ create_table_count = sql_upper.count('CREATE TABLE') + sql_upper.count('CREATE OR ALTER TABLE')
647
+
648
+ if create_table_count == 1 and all(x == 0 for x in [create_function_count, create_procedure_count]):
649
+ # spróbuj najpierw AST; jeśli SQLGlot zwróci Command albo None — fallback stringowy
650
+ try:
651
+ normalized_sql = self._normalize_tsql(sql_content)
652
+ statements = sqlglot.parse(self._preprocess_sql(normalized_sql), read=self.dialect) or []
653
+ st = statements[0] if statements else None
654
+ if st and isinstance(st, exp.Create) and (getattr(st, "kind", "") or "").upper() == "TABLE":
655
+ return self._parse_create_table(st, object_hint)
656
+ except Exception:
657
+ pass
658
+ return self._parse_create_table_string(sql_content, object_hint)
659
+ # If it's a single function or procedure, use string-based approach
660
+ if create_function_count == 1 and create_procedure_count == 0:
185
661
  return self._parse_function_string(sql_content, object_hint)
186
- elif "CREATE PROCEDURE" in sql_upper or "CREATE OR ALTER PROCEDURE" in sql_upper:
662
+ elif create_procedure_count == 1 and create_function_count == 0:
187
663
  return self._parse_procedure_string(sql_content, object_hint)
188
664
 
665
+ # If it's multiple functions but no procedures, process the first function as primary
666
+ # This handles files like 94_fn_customer_orders_tvf.sql with multiple function variants
667
+ elif create_function_count > 1 and create_procedure_count == 0:
668
+ # Extract and process the first function only for detailed lineage
669
+ first_function_sql = self._extract_first_create_statement(sql_content, 'FUNCTION')
670
+ if first_function_sql:
671
+ return self._parse_function_string(first_function_sql, object_hint)
672
+
673
+ # If multiple objects or mixed content, use multi-statement processing
674
+ # This handles demo scripts with multiple functions/procedures/statements
675
+
189
676
  # Preprocess the SQL content to handle demo script patterns
190
- preprocessed_sql = self._preprocess_sql(sql_content)
677
+ # This will also extract and set current_database from USE statements
678
+ normalized_sql = self._normalize_tsql(sql_content)
679
+ preprocessed_sql = self._preprocess_sql(normalized_sql)
191
680
 
192
- # Parse the SQL statement with SQLGlot
681
+ # For files with complex IF/ELSE blocks, also try string-based extraction
682
+ # This is needed for demo scripts like 96_demo_usage_tvf_and_proc.sql
683
+ string_deps = set()
684
+ # Parse all SQL statements with SQLGlot
193
685
  statements = sqlglot.parse(preprocessed_sql, read=self.dialect)
686
+
687
+ # Apply AST rewrites to improve parsing
688
+ if statements:
689
+ statements = [self._rewrite_ast(s) for s in statements]
194
690
  if not statements:
195
- raise ValueError("No valid SQL statements found")
691
+ # If SQLGlot parsing fails completely, try to extract dependencies with string parsing
692
+ dependencies = self._extract_basic_dependencies(preprocessed_sql)
693
+ return ObjectInfo(
694
+ name=object_hint or self._get_fallback_name(sql_content),
695
+ object_type="script",
696
+ schema=[],
697
+ dependencies=dependencies,
698
+ lineage=[]
699
+ )
700
+
701
+ # Process the entire script - aggregate across all statements
702
+ all_inputs = set()
703
+ all_outputs = []
704
+ main_object = None
705
+ last_persistent_output = None
706
+
707
+ # Process all statements in order
708
+ for statement in statements:
709
+ if isinstance(statement, exp.Create):
710
+ # This is the main object being created
711
+ obj = self._parse_create_statement(statement, object_hint)
712
+ if obj.object_type in ["table", "view", "function", "procedure"]:
713
+ last_persistent_output = obj
714
+ # Add inputs from DDL statements
715
+ all_inputs.update(obj.dependencies)
716
+
717
+ elif isinstance(statement, exp.Select) and self._is_select_into(statement):
718
+ # SELECT ... INTO creates a table/temp table
719
+ obj = self._parse_select_into(statement, object_hint)
720
+ all_outputs.append(obj)
721
+ # Check if it's persistent (not temp)
722
+ if not obj.name.startswith("#") and "tempdb" not in obj.name:
723
+ last_persistent_output = obj
724
+ all_inputs.update(obj.dependencies)
725
+
726
+ elif isinstance(statement, exp.Select):
727
+ # Loose SELECT statement - extract dependencies but no output
728
+ self._process_ctes(statement)
729
+ stmt_deps = self._extract_dependencies(statement)
730
+
731
+ # Expand CTEs and temp tables to base tables
732
+ for dep in stmt_deps:
733
+ expanded_deps = self._expand_dependency_to_base_tables(dep, statement)
734
+ all_inputs.update(expanded_deps)
735
+
736
+ elif isinstance(statement, exp.Insert):
737
+ if self._is_insert_exec(statement):
738
+ # INSERT INTO ... EXEC
739
+ obj = self._parse_insert_exec(statement, object_hint)
740
+ all_outputs.append(obj)
741
+ if not obj.name.startswith("#") and "tempdb" not in obj.name:
742
+ last_persistent_output = obj
743
+ all_inputs.update(obj.dependencies)
744
+ else:
745
+ # INSERT INTO ... SELECT - this handles persistent tables
746
+ obj = self._parse_insert_select(statement, object_hint)
747
+ if obj:
748
+ all_outputs.append(obj)
749
+ # Check if this is a persistent table (main output)
750
+ if not obj.name.startswith("#") and "tempdb" not in obj.name.lower():
751
+ last_persistent_output = obj
752
+ all_inputs.update(obj.dependencies)
753
+
754
+ # Extra: guard for INSERT variants parsed oddly by SQLGlot (Command inside expression)
755
+ elif hasattr(statement, "this") and isinstance(statement, exp.Table) and "INSERT" in str(statement).upper():
756
+ # Best-effort: try _parse_insert_select fallback if AST is quirky
757
+ try:
758
+ obj = self._parse_insert_select(statement, object_hint)
759
+ if obj:
760
+ all_outputs.append(obj)
761
+ if not obj.name.startswith("#") and "tempdb" not in obj.name.lower():
762
+ last_persistent_output = obj
763
+ all_inputs.update(obj.dependencies)
764
+ except Exception:
765
+ pass
766
+
767
+ elif isinstance(statement, exp.With):
768
+ # Process WITH statements (CTEs)
769
+ if hasattr(statement, 'this') and isinstance(statement.this, exp.Select):
770
+ self._process_ctes(statement.this)
771
+ stmt_deps = self._extract_dependencies(statement.this)
772
+ for dep in stmt_deps:
773
+ expanded_deps = self._expand_dependency_to_base_tables(dep, statement.this)
774
+ all_inputs.update(expanded_deps)
775
+
776
+ # Remove CTE references from final inputs
777
+ all_inputs = {dep for dep in all_inputs if not self._is_cte_reference(dep)}
196
778
 
197
- # For now, handle single statement per file
198
- statement = statements[0]
779
+ # Sanitize all input names
780
+ all_inputs = {sanitize_name(dep) for dep in all_inputs if dep}
781
+ def _strip_db(name: str) -> str:
782
+ parts = (name or "").split(".")
783
+ return ".".join(parts[-2:]) if len(parts) >= 2 else (name or "")
784
+
785
+ out_key = _strip_db(sanitize_name(last_persistent_output.schema.name or last_persistent_output.name))
786
+ all_inputs = {d for d in all_inputs if _strip_db(sanitize_name(d)) != out_key}
787
+ # Determine the main object
788
+ if last_persistent_output:
789
+ # Use the last persistent output as the main object
790
+ main_object = last_persistent_output
791
+ # Update its dependencies with all collected inputs
792
+ main_object.dependencies = all_inputs
793
+ elif all_outputs:
794
+ # Use the last output if no persistent one found
795
+ main_object = all_outputs[-1]
796
+ main_object.dependencies = all_inputs
797
+ elif all_inputs:
798
+ # Create a file-level object with aggregated inputs (for demo scripts)
799
+ main_object = ObjectInfo(
800
+ name=sanitize_name(object_hint or "loose_statements"),
801
+ object_type="script",
802
+ schema=TableSchema(
803
+ namespace="mssql://localhost/InfoTrackerDW",
804
+ name=sanitize_name(object_hint or "loose_statements"),
805
+ columns=[]
806
+ ),
807
+ lineage=[],
808
+ dependencies=all_inputs
809
+ )
810
+ # Add no-output reason for diagnostics
811
+ if not self.current_database and not self.default_database:
812
+ main_object.no_output_reason = "UNKNOWN_DB_CONTEXT"
813
+ else:
814
+ main_object.no_output_reason = "NO_PERSISTENT_OUTPUT_DETECTED"
199
815
 
200
- if isinstance(statement, exp.Create):
201
- return self._parse_create_statement(statement, object_hint)
202
- elif isinstance(statement, exp.Select) and self._is_select_into(statement):
203
- return self._parse_select_into(statement, object_hint)
204
- elif isinstance(statement, exp.Insert) and self._is_insert_exec(statement):
205
- return self._parse_insert_exec(statement, object_hint)
816
+ if main_object:
817
+ return main_object
206
818
  else:
207
- raise ValueError(f"Unsupported statement type: {type(statement)}")
819
+ raise ValueError("No valid statements found to process")
208
820
 
209
821
  except Exception as e:
210
822
  # Try fallback for INSERT INTO #temp EXEC pattern
@@ -215,11 +827,11 @@ class SqlParser:
215
827
  logger.warning("parse failed: %s", e)
216
828
  # Return an object with error information
217
829
  return ObjectInfo(
218
- name=object_hint or "unknown",
830
+ name=sanitize_name(object_hint or "unknown"),
219
831
  object_type="unknown",
220
832
  schema=TableSchema(
221
833
  namespace="mssql://localhost/InfoTrackerDW",
222
- name=object_hint or "unknown",
834
+ name=sanitize_name(object_hint or "unknown"),
223
835
  columns=[]
224
836
  ),
225
837
  lineage=[],
@@ -253,7 +865,7 @@ class SqlParser:
253
865
 
254
866
  # Normalize temp table names
255
867
  if table_name.startswith('#'):
256
- namespace = "tempdb"
868
+ namespace = "mssql://localhost/tempdb"
257
869
 
258
870
  # Extract dependencies (tables referenced in FROM/JOIN)
259
871
  dependencies = self._extract_dependencies(statement)
@@ -261,6 +873,11 @@ class SqlParser:
261
873
  # Extract column lineage
262
874
  lineage, output_columns = self._extract_column_lineage(statement, table_name)
263
875
 
876
+ # Register temp table columns if this is a temp table
877
+ if table_name.startswith('#'):
878
+ temp_cols = [col.name for col in output_columns]
879
+ self.temp_registry[table_name] = temp_cols
880
+
264
881
  schema = TableSchema(
265
882
  namespace=namespace,
266
883
  name=table_name,
@@ -286,7 +903,7 @@ class SqlParser:
286
903
 
287
904
  # Normalize temp table names
288
905
  if table_name.startswith('#'):
289
- namespace = "tempdb"
906
+ namespace = "mssql://localhost/tempdb"
290
907
 
291
908
  # Extract the EXEC command
292
909
  expression = statement.expression
@@ -303,7 +920,9 @@ class SqlParser:
303
920
  # Extract procedure name (first identifier after EXEC)
304
921
  parts = exec_text.split()
305
922
  if len(parts) > 1:
306
- procedure_name = self._clean_proc_name(parts[1])
923
+ raw_proc_name = self._clean_proc_name(parts[1])
924
+ # Ensure proper qualification for procedures
925
+ procedure_name = self._get_full_table_name(raw_proc_name)
307
926
  dependencies.add(procedure_name)
308
927
 
309
928
  # For EXEC temp tables, we create placeholder columns since we can't determine
@@ -359,6 +978,55 @@ class SqlParser:
359
978
  # Fallback if we can't parse the EXEC command
360
979
  raise ValueError("Could not parse INSERT INTO ... EXEC statement")
361
980
 
981
+ def _parse_insert_select(self, statement: exp.Insert, object_hint: Optional[str] = None) -> Optional[ObjectInfo]:
982
+ """Parse INSERT INTO ... SELECT statement."""
983
+ from .openlineage_utils import sanitize_name
984
+
985
+ # Get target table name from INSERT INTO clause
986
+ table_name = self._get_table_name(statement.this, object_hint)
987
+ namespace = "mssql://localhost/InfoTrackerDW"
988
+
989
+ # Normalize temp table names
990
+ if table_name.startswith('#') or 'tempdb' in table_name:
991
+ namespace = "mssql://localhost/tempdb"
992
+
993
+ # Extract the SELECT part
994
+ select_expr = statement.expression
995
+ if not isinstance(select_expr, exp.Select):
996
+ return None
997
+
998
+ # Extract dependencies (tables referenced in FROM/JOIN)
999
+ dependencies = self._extract_dependencies(select_expr)
1000
+
1001
+ # Extract column lineage
1002
+ lineage, output_columns = self._extract_column_lineage(select_expr, table_name)
1003
+
1004
+ # Sanitize table name
1005
+ table_name = sanitize_name(table_name)
1006
+
1007
+ # Register temp table columns if this is a temp table
1008
+ if table_name.startswith('#') or 'tempdb' in table_name:
1009
+ temp_cols = [col.name for col in output_columns]
1010
+ simple_name = table_name.split('.')[-1]
1011
+ self.temp_registry[simple_name] = temp_cols
1012
+
1013
+ schema = TableSchema(
1014
+ namespace=namespace,
1015
+ name=table_name,
1016
+ columns=output_columns
1017
+ )
1018
+
1019
+ # Register schema for future reference
1020
+ self.schema_registry.register(schema)
1021
+
1022
+ return ObjectInfo(
1023
+ name=table_name,
1024
+ object_type="temp_table" if (table_name.startswith('#') or 'tempdb' in table_name) else "table",
1025
+ schema=schema,
1026
+ lineage=lineage,
1027
+ dependencies=dependencies
1028
+ )
1029
+
362
1030
  def _parse_create_statement(self, statement: exp.Create, object_hint: Optional[str] = None) -> ObjectInfo:
363
1031
  """Parse CREATE TABLE, CREATE VIEW, CREATE FUNCTION, or CREATE PROCEDURE statement."""
364
1032
  if statement.kind == "TABLE":
@@ -412,6 +1080,81 @@ class SqlParser:
412
1080
  dependencies=set()
413
1081
  )
414
1082
 
1083
+ def _parse_create_table_string(self, sql: str, object_hint: Optional[str] = None) -> ObjectInfo:
1084
+ # 1) Wyciągnij nazwę tabeli
1085
+ m = re.search(r'(?is)CREATE\s+TABLE\s+([^\s(]+)', sql)
1086
+ table_name = self._get_full_table_name(self._normalize_table_ident(m.group(1))) if m else (object_hint or "dbo.unknown_table")
1087
+ namespace = "mssql://localhost/InfoTrackerDW"
1088
+
1089
+ # 2) Wyciągnij definicję kolumn (balansowane nawiasy od pierwszego '(' po nazwie)
1090
+ s = self._normalize_tsql(sql)
1091
+ m = re.search(r'(?is)\bCREATE\s+TABLE\s+([^\s(]+)', s)
1092
+ start = s.find('(', m.end()) if m else -1
1093
+ if start == -1:
1094
+ schema = TableSchema(namespace=namespace, name=table_name, columns=[])
1095
+ self.schema_registry.register(schema)
1096
+ return ObjectInfo(name=table_name, object_type="table", schema=schema, lineage=[], dependencies=set())
1097
+
1098
+ depth, i, end = 0, start, len(s)
1099
+ while i < len(s):
1100
+ ch = s[i]
1101
+ if ch == '(':
1102
+ depth += 1
1103
+ elif ch == ')':
1104
+ depth -= 1
1105
+ if depth == 0:
1106
+ end = i
1107
+ break
1108
+ i += 1
1109
+ body = s[start+1:end]
1110
+
1111
+ # 3) Podziel na wiersze definicji kolumn (odetnij constrainty tabelowe)
1112
+ lines = []
1113
+ depth = 0
1114
+ token = []
1115
+ for ch in body:
1116
+ if ch == '(':
1117
+ depth += 1
1118
+ elif ch == ')':
1119
+ depth -= 1
1120
+ if ch == ',' and depth == 0:
1121
+ lines.append(''.join(token).strip())
1122
+ token = []
1123
+ else:
1124
+ token.append(ch)
1125
+ if token:
1126
+ lines.append(''.join(token).strip())
1127
+ # odfiltruj klauzule constraintów tabelowych
1128
+ col_lines = [ln for ln in lines if not re.match(r'(?i)^(CONSTRAINT|PRIMARY\s+KEY|UNIQUE|FOREIGN\s+KEY|CHECK)\b', ln)]
1129
+
1130
+ cols: List[ColumnSchema] = []
1131
+ for i, ln in enumerate(col_lines):
1132
+ # nazwa kolumny w nawiasach/[] lub goła
1133
+ m = re.match(r'\s*(?:\[([^\]]+)\]|"([^"]+)"|([A-Za-z_][\w$#]*))\s+(.*)$', ln)
1134
+ if not m:
1135
+ continue
1136
+ col_name = next(g for g in m.groups()[:3] if g)
1137
+ rest = m.group(4)
1138
+
1139
+ # typ: pierwszy token (może być typu NVARCHAR(100) itp.) — bierz nazwę typu + (opcjonalnie) długość
1140
+ # typ: lub varchar(32)
1141
+ t = re.match(r'(?i)\s*(?:\[(?P<t1>[^\]]+)\]|(?P<t2>[A-Za-z_][\w$]*))\s*(?:\(\s*(?P<args>[^)]*?)\s*\))?', rest)
1142
+ if t:
1143
+ tname = (t.group('t1') or t.group('t2') or '').upper()
1144
+ targs = t.group('args')
1145
+ dtype = f"{tname}({targs})" if targs else tname
1146
+ else:
1147
+ dtype = "UNKNOWN"
1148
+
1149
+ # nullable / not null
1150
+ nullable = not re.search(r'(?i)\bNOT\s+NULL\b', rest)
1151
+
1152
+ cols.append(ColumnSchema(name=col_name, data_type=dtype, nullable=nullable, ordinal=i))
1153
+
1154
+ schema = TableSchema(namespace=namespace, name=table_name, columns=cols)
1155
+ self.schema_registry.register(schema)
1156
+ return ObjectInfo(name=table_name, object_type="table", schema=schema, lineage=[], dependencies=set())
1157
+
415
1158
  def _parse_create_view(self, statement: exp.Create, object_hint: Optional[str] = None) -> ObjectInfo:
416
1159
  """Parse CREATE VIEW statement."""
417
1160
  view_name = self._get_table_name(statement.this, object_hint)
@@ -449,13 +1192,20 @@ class SqlParser:
449
1192
  # Register schema for future reference
450
1193
  self.schema_registry.register(schema)
451
1194
 
452
- return ObjectInfo(
1195
+ # Create object
1196
+ obj = ObjectInfo(
453
1197
  name=view_name,
454
1198
  object_type="view",
455
1199
  schema=schema,
456
1200
  lineage=lineage,
457
1201
  dependencies=dependencies
458
1202
  )
1203
+
1204
+ # Apply header column names if CREATE VIEW (col1, col2, ...) AS pattern
1205
+ if isinstance(select_stmt, exp.Select):
1206
+ self._apply_view_header_names(statement, select_stmt, obj)
1207
+
1208
+ return obj
459
1209
 
460
1210
  def _parse_create_function(self, statement: exp.Create, object_hint: Optional[str] = None) -> ObjectInfo:
461
1211
  """Parse CREATE FUNCTION statement (table-valued functions only)."""
@@ -502,7 +1252,27 @@ class SqlParser:
502
1252
  procedure_name = self._get_table_name(statement.this, object_hint)
503
1253
  namespace = "mssql://localhost/InfoTrackerDW"
504
1254
 
505
- # Extract the procedure body and find the last SELECT statement
1255
+ # Extract the procedure body and find materialized outputs (SELECT INTO, INSERT INTO)
1256
+ materialized_outputs = self._extract_procedure_outputs(statement)
1257
+
1258
+ # If we have materialized outputs, return the last one instead of the procedure
1259
+ if materialized_outputs:
1260
+ last_output = materialized_outputs[-1]
1261
+ # Extract lineage for the materialized output
1262
+ lineage, output_columns, dependencies = self._extract_procedure_lineage(statement, procedure_name)
1263
+
1264
+ # Update the output object with proper lineage and dependencies
1265
+ last_output.lineage = lineage
1266
+ last_output.dependencies = dependencies
1267
+ if output_columns:
1268
+ last_output.schema = TableSchema(
1269
+ namespace=last_output.schema.namespace,
1270
+ name=last_output.name,
1271
+ columns=output_columns
1272
+ )
1273
+ return last_output
1274
+
1275
+ # Fall back to regular procedure parsing if no materialized outputs
506
1276
  lineage, output_columns, dependencies = self._extract_procedure_lineage(statement, procedure_name)
507
1277
 
508
1278
  schema = TableSchema(
@@ -514,33 +1284,119 @@ class SqlParser:
514
1284
  # Register schema for future reference
515
1285
  self.schema_registry.register(schema)
516
1286
 
517
- return ObjectInfo(
1287
+ # Add reason for procedure with no materialized output
1288
+ obj = ObjectInfo(
518
1289
  name=procedure_name,
519
1290
  object_type="procedure",
520
1291
  schema=schema,
521
1292
  lineage=lineage,
522
1293
  dependencies=dependencies
523
1294
  )
1295
+ obj.no_output_reason = "ONLY_PROCEDURE_RESULTSET"
1296
+ return obj
1297
+
1298
+ def _extract_procedure_outputs(self, statement: exp.Create) -> List[ObjectInfo]:
1299
+ """Extract materialized outputs (SELECT INTO, INSERT INTO) from procedure body."""
1300
+ outputs = []
1301
+ sql_text = str(statement)
1302
+
1303
+ # Look for SELECT ... INTO patterns
1304
+ select_into_pattern = r'(?i)SELECT\s+.*?\s+INTO\s+([^\s,]+)'
1305
+ select_into_matches = re.findall(select_into_pattern, sql_text, re.DOTALL)
1306
+
1307
+ for table_match in select_into_matches:
1308
+ table_name = table_match.strip()
1309
+ # Skip temp tables
1310
+ if not table_name.startswith('#') and 'tempdb' not in table_name.lower():
1311
+ # Normalize table name - remove database prefix for output
1312
+ normalized_name = self._normalize_table_name_for_output(table_name)
1313
+ outputs.append(ObjectInfo(
1314
+ name=normalized_name,
1315
+ object_type="table",
1316
+ schema=TableSchema(
1317
+ namespace="mssql://localhost/InfoTrackerDW",
1318
+ name=normalized_name,
1319
+ columns=[]
1320
+ ),
1321
+ lineage=[],
1322
+ dependencies=set()
1323
+ ))
1324
+
1325
+ # Look for INSERT INTO patterns (non-temp tables)
1326
+ insert_into_pattern = r'(?i)INSERT\s+INTO\s+([^\s,\(]+)'
1327
+ insert_into_matches = re.findall(insert_into_pattern, sql_text)
1328
+
1329
+ for table_match in insert_into_matches:
1330
+ table_name = table_match.strip()
1331
+ # Skip temp tables
1332
+ if not table_name.startswith('#') and 'tempdb' not in table_name.lower():
1333
+ normalized_name = self._normalize_table_name_for_output(table_name)
1334
+ # Check if we already have this table from SELECT INTO
1335
+ if not any(output.name == normalized_name for output in outputs):
1336
+ outputs.append(ObjectInfo(
1337
+ name=normalized_name,
1338
+ object_type="table",
1339
+ schema=TableSchema(
1340
+ namespace="mssql://localhost/InfoTrackerDW",
1341
+ name=normalized_name,
1342
+ columns=[]
1343
+ ),
1344
+ lineage=[],
1345
+ dependencies=set()
1346
+ ))
1347
+
1348
+ return outputs
1349
+
1350
+ def _normalize_table_name_for_output(self, table_name: str) -> str:
1351
+ """Normalize table name for output - remove database prefix, keep schema.table format."""
1352
+ from .openlineage_utils import sanitize_name
1353
+
1354
+ # Clean up the table name
1355
+ table_name = sanitize_name(table_name)
1356
+
1357
+ # Remove database prefix if present (keep only schema.table)
1358
+ parts = table_name.split('.')
1359
+ if len(parts) >= 3:
1360
+ # database.schema.table -> schema.table
1361
+ return f"{parts[-2]}.{parts[-1]}"
1362
+ elif len(parts) == 2:
1363
+ # schema.table -> keep as is
1364
+ return table_name
1365
+ else:
1366
+ # just table -> add dbo prefix
1367
+ return f"dbo.{table_name}"
524
1368
 
525
1369
  def _get_table_name(self, table_expr: exp.Expression, hint: Optional[str] = None) -> str:
526
- """Extract table name from expression and qualify with default database if needed."""
527
- from .openlineage_utils import qualify_identifier
1370
+ """Extract table name from expression and qualify with current or default database."""
1371
+ from .openlineage_utils import qualify_identifier, sanitize_name
1372
+
1373
+ # Use current database from USE statement or fall back to default
1374
+ database_to_use = self.current_database or self.default_database
528
1375
 
529
1376
  if isinstance(table_expr, exp.Table):
530
1377
  # Handle three-part names: database.schema.table
531
1378
  if table_expr.catalog and table_expr.db:
532
- return f"{table_expr.catalog}.{table_expr.db}.{table_expr.name}"
1379
+ full_name = f"{table_expr.catalog}.{table_expr.db}.{table_expr.name}"
533
1380
  # Handle two-part names like dbo.table_name (legacy format)
534
1381
  elif table_expr.db:
535
1382
  table_name = f"{table_expr.db}.{table_expr.name}"
536
- return qualify_identifier(table_name, self.default_database)
1383
+ full_name = qualify_identifier(table_name, database_to_use)
537
1384
  else:
538
1385
  table_name = str(table_expr.name)
539
- return qualify_identifier(table_name, self.default_database)
1386
+ full_name = qualify_identifier(table_name, database_to_use)
540
1387
  elif isinstance(table_expr, exp.Identifier):
541
1388
  table_name = str(table_expr.this)
542
- return qualify_identifier(table_name, self.default_database)
543
- return hint or "unknown"
1389
+ full_name = qualify_identifier(table_name, database_to_use)
1390
+ else:
1391
+ full_name = hint or "unknown"
1392
+
1393
+ # Apply consistent temp table namespace handling
1394
+ if full_name and full_name.startswith('#'):
1395
+ # Temp table - use consistent namespace and naming convention
1396
+ temp_name = full_name.lstrip('#')
1397
+ return f"tempdb..#{temp_name}"
1398
+
1399
+ return sanitize_name(full_name)
544
1400
 
545
1401
  def _extract_column_type(self, column_def: exp.ColumnDef) -> str:
546
1402
  """Extract column type from column definition."""
@@ -606,11 +1462,28 @@ class SqlParser:
606
1462
 
607
1463
  select_stmt = stmt
608
1464
 
1465
+ # Process CTEs first to build registry
1466
+ self._process_ctes(select_stmt)
1467
+
609
1468
  # Use find_all to get all table references (FROM, JOIN, etc.)
610
1469
  for table in select_stmt.find_all(exp.Table):
611
1470
  table_name = self._get_table_name(table)
612
1471
  if table_name != "unknown":
613
- dependencies.add(table_name)
1472
+ # Check if this is a CTE - if so, get its base dependencies instead
1473
+ simple_name = table_name.split('.')[-1]
1474
+ if simple_name in self.cte_registry:
1475
+ # This is a CTE reference - get dependencies from CTE definition
1476
+ with_clause = select_stmt.args.get('with')
1477
+ if with_clause and hasattr(with_clause, 'expressions'):
1478
+ for cte in with_clause.expressions:
1479
+ if hasattr(cte, 'alias') and str(cte.alias) == simple_name:
1480
+ if isinstance(cte.this, exp.Select):
1481
+ cte_deps = self._extract_dependencies(cte.this)
1482
+ dependencies.update(cte_deps)
1483
+ break
1484
+ else:
1485
+ # Regular table dependency
1486
+ dependencies.add(table_name)
614
1487
 
615
1488
  # Also check for subqueries and CTEs
616
1489
  for subquery in select_stmt.find_all(exp.Subquery):
@@ -621,7 +1494,7 @@ class SqlParser:
621
1494
  return dependencies
622
1495
 
623
1496
  def _extract_column_lineage(self, stmt: exp.Expression, view_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema]]:
624
- """Extract column lineage from SELECT or UNION statement."""
1497
+ """Extract column lineage from SELECT or UNION statement using enhanced alias mapping."""
625
1498
  lineage = []
626
1499
  output_columns = []
627
1500
 
@@ -635,6 +1508,9 @@ class SqlParser:
635
1508
 
636
1509
  select_stmt = stmt
637
1510
 
1511
+ # Build alias maps for proper resolution
1512
+ alias_map, derived_cols = self._build_alias_maps(select_stmt)
1513
+
638
1514
  # Try to get projections with fallback
639
1515
  projections = list(getattr(select_stmt, 'expressions', None) or [])
640
1516
  if not projections:
@@ -648,39 +1524,67 @@ class SqlParser:
648
1524
  if self._has_union(select_stmt):
649
1525
  return self._handle_union_lineage(select_stmt, view_name)
650
1526
 
651
- # Standard column-by-column processing
652
- for i, select_expr in enumerate(projections):
653
- if isinstance(select_expr, exp.Alias):
654
- # Aliased column: SELECT column AS alias
655
- output_name = str(select_expr.alias)
656
- source_expr = select_expr.this
1527
+ # Enhanced column-by-column processing
1528
+ ordinal = 0
1529
+ for proj in projections:
1530
+ # Decide output name using enhanced logic
1531
+ if isinstance(proj, exp.Alias):
1532
+ out_name = proj.alias or proj.alias_or_name
1533
+ inner = proj.this
657
1534
  else:
658
- # Direct column reference or expression
659
- # For direct column references, extract just the column name
660
- if isinstance(select_expr, exp.Column):
661
- output_name = str(select_expr.this) # Just the column name, not table.column
662
- else:
663
- output_name = str(select_expr)
664
- source_expr = select_expr
1535
+ # Generate smart fallback names based on expression type
1536
+ s = str(proj).upper()
1537
+ if "HASHBYTES(" in s or "MD5(" in s:
1538
+ out_name = "hash_expr"
1539
+ elif isinstance(proj, exp.Coalesce):
1540
+ out_name = "coalesce_expr"
1541
+ elif isinstance(proj, (exp.Trim, exp.Upper, exp.Lower)):
1542
+ col = proj.find(exp.Column)
1543
+ out_name = (col.name if col else "text_expr")
1544
+ elif isinstance(proj, (exp.Cast, exp.Convert)):
1545
+ out_name = "cast_expr"
1546
+ elif isinstance(proj, exp.Column):
1547
+ out_name = proj.name
1548
+ else:
1549
+ out_name = "calc_expr"
1550
+ inner = proj
1551
+
1552
+ # Collect input fields using enhanced resolution
1553
+ inputs = self._collect_inputs_for_expr(inner, alias_map, derived_cols)
665
1554
 
666
- # Determine data type for ColumnSchema
667
- data_type = "unknown"
668
- if isinstance(source_expr, exp.Cast):
669
- data_type = str(source_expr.to).upper()
1555
+ # Infer output type using enhanced type system
1556
+ out_type = self._infer_type(inner, alias_map)
670
1557
 
671
- # Create output column schema
1558
+ # Determine transformation type
1559
+ if isinstance(inner, (exp.Cast, exp.Convert)):
1560
+ ttype = TransformationType.CAST
1561
+ elif isinstance(inner, exp.Case):
1562
+ ttype = TransformationType.CASE
1563
+ elif isinstance(inner, exp.Column):
1564
+ ttype = TransformationType.IDENTITY
1565
+ else:
1566
+ # IIF(...) bywa mapowane przez sqlglot do CASE; na wszelki wypadek:
1567
+ s = str(inner).upper()
1568
+ if s.startswith("CASE ") or s.startswith("CASEWHEN ") or s.startswith("IIF("):
1569
+ ttype = TransformationType.CASE
1570
+ else:
1571
+ ttype = TransformationType.EXPRESSION
1572
+
1573
+
1574
+ # Create lineage and schema entries
1575
+ lineage.append(ColumnLineage(
1576
+ output_column=out_name,
1577
+ input_fields=inputs,
1578
+ transformation_type=ttype,
1579
+ transformation_description=self._short_desc(inner),
1580
+ ))
672
1581
  output_columns.append(ColumnSchema(
673
- name=output_name,
674
- data_type=data_type,
675
- nullable=True,
676
- ordinal=i
1582
+ name=out_name,
1583
+ data_type=out_type,
1584
+ nullable=True,
1585
+ ordinal=ordinal
677
1586
  ))
678
-
679
- # Extract lineage for this column
680
- col_lineage = self._analyze_expression_lineage(
681
- output_name, source_expr, select_stmt
682
- )
683
- lineage.append(col_lineage)
1587
+ ordinal += 1
684
1588
 
685
1589
  return lineage, output_columns
686
1590
 
@@ -993,18 +1897,36 @@ class SqlParser:
993
1897
  return alias # Fallback to alias as table name
994
1898
 
995
1899
  def _process_ctes(self, select_stmt: exp.Select) -> exp.Select:
996
- """Process Common Table Expressions and return the main SELECT."""
997
- # For now, we'll handle CTEs by treating them as additional dependencies
998
- # The main SELECT statement is typically the last one in the CTE chain
999
-
1900
+ """Process Common Table Expressions and register them properly."""
1000
1901
  with_clause = select_stmt.args.get('with')
1001
1902
  if with_clause and hasattr(with_clause, 'expressions'):
1002
- # Register CTE tables for alias resolution
1903
+ # Register CTE tables and their columns
1003
1904
  for cte in with_clause.expressions:
1004
1905
  if hasattr(cte, 'alias') and hasattr(cte, 'this'):
1005
1906
  cte_name = str(cte.alias)
1006
- # For dependency tracking, we could analyze the CTE definition
1007
- # but for now we'll just note it exists
1907
+
1908
+ # Extract columns from CTE definition
1909
+ cte_columns = []
1910
+ if isinstance(cte.this, exp.Select):
1911
+ # Get column names from SELECT projections
1912
+ for proj in cte.this.expressions:
1913
+ if isinstance(proj, exp.Alias):
1914
+ cte_columns.append(str(proj.alias))
1915
+ elif isinstance(proj, exp.Column):
1916
+ cte_columns.append(str(proj.this))
1917
+ elif isinstance(proj, exp.Star):
1918
+ # For star, try to infer from source tables
1919
+ source_deps = self._extract_dependencies(cte.this)
1920
+ for source_table in source_deps:
1921
+ source_cols = self._infer_table_columns(source_table)
1922
+ cte_columns.extend(source_cols)
1923
+ break
1924
+ else:
1925
+ # Generic expression - use ordinal
1926
+ cte_columns.append(f"col_{len(cte_columns) + 1}")
1927
+
1928
+ # Register CTE in registry
1929
+ self.cte_registry[cte_name] = cte_columns
1008
1930
 
1009
1931
  return select_stmt
1010
1932
 
@@ -1020,6 +1942,10 @@ class SqlParser:
1020
1942
  for expr in select_stmt.expressions:
1021
1943
  if isinstance(expr, exp.Star):
1022
1944
  return True
1945
+ # Also check for Column expressions that represent qualified stars like "o.*"
1946
+ if isinstance(expr, exp.Column):
1947
+ if str(expr.this) == "*" or str(expr).endswith(".*"):
1948
+ return True
1023
1949
  return False
1024
1950
 
1025
1951
  def _has_union(self, stmt: exp.Expression) -> bool:
@@ -1027,12 +1953,13 @@ class SqlParser:
1027
1953
  return isinstance(stmt, exp.Union) or len(list(stmt.find_all(exp.Union))) > 0
1028
1954
 
1029
1955
  def _handle_star_expansion(self, select_stmt: exp.Select, view_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema]]:
1030
- """Handle SELECT * expansion by inferring columns from source tables."""
1956
+ """Handle SELECT * expansion by inferring columns from source tables using unified registry approach."""
1031
1957
  lineage = []
1032
1958
  output_columns = []
1033
1959
 
1034
1960
  # Process all SELECT expressions, including both stars and explicit columns
1035
1961
  ordinal = 0
1962
+ seen_columns = set() # Track column names to avoid duplicates
1036
1963
 
1037
1964
  for select_expr in select_stmt.expressions:
1038
1965
  if isinstance(select_expr, exp.Star):
@@ -1041,29 +1968,32 @@ class SqlParser:
1041
1968
  alias = str(select_expr.table)
1042
1969
  table_name = self._resolve_table_from_alias(alias, select_stmt)
1043
1970
  if table_name != "unknown":
1044
- columns = self._infer_table_columns(table_name)
1971
+ columns = self._infer_table_columns_unified(table_name)
1045
1972
 
1046
1973
  for column_name in columns:
1047
- output_columns.append(ColumnSchema(
1048
- name=column_name,
1049
- data_type="unknown",
1050
- nullable=True,
1051
- ordinal=ordinal
1052
- ))
1053
- ordinal += 1
1054
-
1055
- lineage.append(ColumnLineage(
1056
- output_column=column_name,
1057
- input_fields=[ColumnReference(
1058
- namespace="mssql://localhost/InfoTrackerDW",
1059
- table_name=table_name,
1060
- column_name=column_name
1061
- )],
1062
- transformation_type=TransformationType.IDENTITY,
1063
- transformation_description=f"SELECT {alias}.{column_name}"
1064
- ))
1974
+ # Avoid duplicate column names
1975
+ if column_name not in seen_columns:
1976
+ seen_columns.add(column_name)
1977
+ output_columns.append(ColumnSchema(
1978
+ name=column_name,
1979
+ data_type="unknown",
1980
+ nullable=True,
1981
+ ordinal=ordinal
1982
+ ))
1983
+ ordinal += 1
1984
+
1985
+ lineage.append(ColumnLineage(
1986
+ output_column=column_name,
1987
+ input_fields=[ColumnReference(
1988
+ namespace=self._get_namespace_for_table(table_name),
1989
+ table_name=table_name,
1990
+ column_name=column_name
1991
+ )],
1992
+ transformation_type=TransformationType.IDENTITY,
1993
+ transformation_description=f"{alias}.*"
1994
+ ))
1065
1995
  else:
1066
- # Handle unqualified * - expand all tables
1996
+ # Handle unqualified * - expand all tables in stable order
1067
1997
  source_tables = []
1068
1998
  for table in select_stmt.find_all(exp.Table):
1069
1999
  table_name = self._get_table_name(table)
@@ -1071,27 +2001,59 @@ class SqlParser:
1071
2001
  source_tables.append(table_name)
1072
2002
 
1073
2003
  for table_name in source_tables:
1074
- columns = self._infer_table_columns(table_name)
2004
+ columns = self._infer_table_columns_unified(table_name)
1075
2005
 
1076
2006
  for column_name in columns:
1077
- output_columns.append(ColumnSchema(
1078
- name=column_name,
1079
- data_type="unknown",
1080
- nullable=True,
1081
- ordinal=ordinal
1082
- ))
1083
- ordinal += 1
1084
-
1085
- lineage.append(ColumnLineage(
1086
- output_column=column_name,
1087
- input_fields=[ColumnReference(
1088
- namespace="mssql://localhost/InfoTrackerDW",
1089
- table_name=table_name,
1090
- column_name=column_name
1091
- )],
1092
- transformation_type=TransformationType.IDENTITY,
1093
- transformation_description=f"SELECT * (from {table_name})"
1094
- ))
2007
+ # Avoid duplicate column names across tables
2008
+ if column_name not in seen_columns:
2009
+ seen_columns.add(column_name)
2010
+ output_columns.append(ColumnSchema(
2011
+ name=column_name,
2012
+ data_type="unknown",
2013
+ nullable=True,
2014
+ ordinal=ordinal
2015
+ ))
2016
+ ordinal += 1
2017
+
2018
+ lineage.append(ColumnLineage(
2019
+ output_column=column_name,
2020
+ input_fields=[ColumnReference(
2021
+ namespace=self._get_namespace_for_table(table_name),
2022
+ table_name=table_name,
2023
+ column_name=column_name
2024
+ )],
2025
+ transformation_type=TransformationType.IDENTITY,
2026
+ transformation_description="SELECT *"
2027
+ ))
2028
+ elif isinstance(select_expr, exp.Column) and (str(select_expr.this) == "*" or str(select_expr).endswith(".*")):
2029
+ # Handle qualified stars like "o.*" that are parsed as Column objects
2030
+ if hasattr(select_expr, 'table') and select_expr.table:
2031
+ alias = str(select_expr.table)
2032
+ table_name = self._resolve_table_from_alias(alias, select_stmt)
2033
+ if table_name != "unknown":
2034
+ columns = self._infer_table_columns_unified(table_name)
2035
+
2036
+ for column_name in columns:
2037
+ if column_name not in seen_columns:
2038
+ seen_columns.add(column_name)
2039
+ output_columns.append(ColumnSchema(
2040
+ name=column_name,
2041
+ data_type="unknown",
2042
+ nullable=True,
2043
+ ordinal=ordinal
2044
+ ))
2045
+ ordinal += 1
2046
+
2047
+ lineage.append(ColumnLineage(
2048
+ output_column=column_name,
2049
+ input_fields=[ColumnReference(
2050
+ namespace=self._get_namespace_for_table(table_name),
2051
+ table_name=table_name,
2052
+ column_name=column_name
2053
+ )],
2054
+ transformation_type=TransformationType.IDENTITY,
2055
+ transformation_description=f"{alias}.*"
2056
+ ))
1095
2057
  else:
1096
2058
  # Handle explicit column expressions (like "1 as extra_col")
1097
2059
  col_name = self._extract_column_alias(select_expr) or f"col_{ordinal}"
@@ -1172,34 +2134,37 @@ class SqlParser:
1172
2134
  return lineage, output_columns
1173
2135
 
1174
2136
  def _infer_table_columns(self, table_name: str) -> List[str]:
1175
- """Infer table columns from schema registry or fallback to patterns."""
1176
- # First try to get from schema registry
1177
- # Try different namespace combinations
1178
- namespaces_to_try = [
1179
- "mssql://localhost/InfoTrackerDW",
1180
- "dbo",
1181
- "",
1182
- ]
1183
-
1184
- for namespace in namespaces_to_try:
1185
- schema = self.schema_registry.get(namespace, table_name)
1186
- if schema:
1187
- return [col.name for col in schema.columns]
1188
-
1189
- # Fallback to patterns if not in registry
1190
- table_simple = table_name.split('.')[-1].lower()
1191
-
1192
- if 'orders' in table_simple:
1193
- return ['OrderID', 'CustomerID', 'OrderDate', 'OrderStatus']
1194
- elif 'customers' in table_simple:
1195
- return ['CustomerID', 'CustomerName', 'CustomerEmail', 'CustomerPhone']
1196
- elif 'products' in table_simple:
1197
- return ['ProductID', 'ProductName', 'ProductPrice', 'ProductCategory']
1198
- elif 'order_items' in table_simple:
1199
- return ['OrderItemID', 'OrderID', 'ProductID', 'Quantity', 'UnitPrice', 'ExtendedPrice']
2137
+ """Infer table columns using registry-based approach."""
2138
+ return self._infer_table_columns_unified(table_name)
2139
+
2140
+ def _infer_table_columns_unified(self, table_name: str) -> List[str]:
2141
+ """Unified column lookup using registry chain: temp -> cte -> schema -> fallback."""
2142
+ # Clean table name for registry lookup
2143
+ simple_name = table_name.split('.')[-1]
2144
+
2145
+ # 1. Check temp_registry first
2146
+ if simple_name in self.temp_registry:
2147
+ return self.temp_registry[simple_name]
2148
+
2149
+ # 2. Check cte_registry
2150
+ if simple_name in self.cte_registry:
2151
+ return self.cte_registry[simple_name]
2152
+
2153
+ # 3. Check schema_registry
2154
+ namespace = self._get_namespace_for_table(table_name)
2155
+ table_schema = self.schema_registry.get(namespace, table_name)
2156
+ if table_schema and table_schema.columns:
2157
+ return [col.name for col in table_schema.columns]
2158
+
2159
+ # 4. Fallback to deterministic unknown columns (no hardcoding)
2160
+ return [f"unknown_{i+1}" for i in range(3)] # Generate unknown_1, unknown_2, unknown_3
2161
+
2162
+ def _get_namespace_for_table(self, table_name: str) -> str:
2163
+ """Get appropriate namespace for a table based on its name."""
2164
+ if table_name.startswith('tempdb..#'):
2165
+ return "mssql://localhost/tempdb"
1200
2166
  else:
1201
- # Generic fallback
1202
- return ['Column1', 'Column2', 'Column3']
2167
+ return "mssql://localhost/InfoTrackerDW"
1203
2168
 
1204
2169
  def _parse_function_string(self, sql_content: str, object_hint: Optional[str] = None) -> ObjectInfo:
1205
2170
  """Parse CREATE FUNCTION using string-based approach."""
@@ -1243,29 +2208,147 @@ class SqlParser:
1243
2208
 
1244
2209
def _parse_procedure_string(self, sql_content: str, object_hint: Optional[str] = None) -> ObjectInfo:
    """Parse CREATE PROCEDURE using string-based approach.

    If the procedure materializes data (SELECT INTO / INSERT INTO ... SELECT),
    the materialized table becomes the procedure's output object; otherwise
    the last SELECT is treated as the procedure's virtual result set.
    """
    # Normalize SET/GO headers, COLLATE clauses, etc. before any regex work.
    sql_content = self._normalize_tsql(sql_content)
    procedure_name = self._extract_procedure_name(sql_content) or object_hint or "unknown_procedure"
    namespace = "mssql://localhost/InfoTrackerDW"

    # 1) First check whether the SP materializes (SELECT INTO / INSERT INTO ... SELECT)
    materialized_output = self._extract_materialized_output_from_procedure_string(sql_content)
    if materialized_output:
        # 1a) Pull dependencies from the SELECT that follows the INSERT
        #     (so inputs are not empty). Try the stronger path first
        #     (parse the SELECT), then fall back to simple string scanning.
        try:
            lineage_sel, _, deps_sel = self._extract_procedure_lineage_string(sql_content, procedure_name)
            if deps_sel:
                materialized_output.dependencies = set(deps_sel)
        except Exception:
            basic_deps = self._extract_basic_dependencies(sql_content)
            if basic_deps:
                materialized_output.dependencies = set(basic_deps)

        # 1b) BACKFILL the schema from the registry (handle name variants
        #     with and without the database prefix).
        ns = materialized_output.schema.namespace
        name_key = materialized_output.schema.name  # e.g. "dbo.LeadPartner_ref" or "InfoTrackerDW.dbo.LeadPartner_ref"
        known = None
        if hasattr(self.schema_registry, "get"):
            # attempt 1: look up the name as-is
            known = self.schema_registry.get(ns, name_key)
            # attempt 2: add the DB prefix if it is missing
            if not known:
                db = (self.current_database or self.default_database or "InfoTrackerDW")
                parts = name_key.split(".")
                if len(parts) == 2:  # schema.table -> try DB.schema.table
                    name_with_db = f"{db}.{name_key}"
                    known = self.schema_registry.get(ns, name_with_db)
        else:
            # Legacy registry interface: single tuple key.
            known = self.schema_registry.get((ns, name_key))

    if known and getattr(known, "columns", None):
            materialized_output.schema = known
    else:
            # 1c) Fallback: take the columns from the INSERT INTO (...) list.
            cols = self._extract_insert_into_columns(sql_content)
            if cols:
                materialized_output.schema = TableSchema(
                    namespace=ns,
                    name=name_key,
                    columns=[ColumnSchema(name=c, data_type="unknown", nullable=True, ordinal=i)
                             for i, c in enumerate(cols)]
                )

    if materialized_output:
        return materialized_output

    # 2) No materialization — standard path: the last SELECT becomes the
    #    procedure's "virtual" dataset.
    lineage, output_columns, dependencies = self._extract_procedure_lineage_string(sql_content, procedure_name)

    schema = TableSchema(
        namespace=namespace,
        name=procedure_name,
        columns=output_columns
    )

    self.schema_registry.register(schema)

    obj = ObjectInfo(
        name=procedure_name,
        object_type="procedure",
        schema=schema,
        lineage=lineage,
        dependencies=dependencies
    )
    # Mark that the procedure produced only a result set, not a table.
    obj.no_output_reason = "ONLY_PROCEDURE_RESULTSET"
    return obj
2282
+
2283
+
2284
def _extract_materialized_output_from_procedure_string(self, sql_content: str) -> Optional[ObjectInfo]:
    """
    Extract materialized output (SELECT INTO, INSERT INTO) from a procedure body.
    - Returns an ObjectInfo of type "table" with the full DB.schema.table name and correct namespace.
    - Does not use _normalize_table_name_for_output (so the DB part is not lost).
    """
    import re
    from .models import ObjectInfo, TableSchema  # local import to be safe

    # 1) Normalize and strip comments (so the regexes do not match junk)
    s = self._normalize_tsql(sql_content)
    s = re.sub(r'/\*.*?\*/', '', s, flags=re.S)  # block comments
    lines = s.splitlines()
    s = "\n".join(line for line in lines if not line.lstrip().startswith('--'))

    # Helper: build the full table name and namespace from a raw token.
    def _to_obj(table_token: str) -> Optional[ObjectInfo]:
        tok = (table_token or "").strip().rstrip(';')
        # Temp tables are not persistent outputs.
        if tok.startswith('#') or tok.lower().startswith('tempdb..#'):
            return None
        # 1) normalize the identifier (strip brackets/quotes)
        norm = self._normalize_table_ident(tok)  # e.g. EDW_CORE.dbo.LeadPartner_ref
        # 2) full name including DB (current/default DB added when missing)
        full_name = self._get_full_table_name(norm)  # -> DB.schema.table
        # 3) derive the namespace from the DB part
        try:
            db, sch, tbl = self._split_fqn(full_name)  # -> (DB, schema, table)
        except Exception:
            # Emergency path: split the name manually.
            parts = full_name.split('.')
            if len(parts) == 3:
                db, sch, tbl = parts
            elif len(parts) == 2:
                db = (self.current_database or self.default_database or "InfoTrackerDW")
                sch, tbl = parts
                full_name = f"{db}.{sch}.{tbl}"
            else:
                db = (self.current_database or self.default_database or "InfoTrackerDW")
                sch = "dbo"
                tbl = parts[0]
                full_name = f"{db}.{sch}.{tbl}"
        ns = f"mssql://localhost/{db or (self.current_database or self.default_database or 'InfoTrackerDW')}"

        return ObjectInfo(
            name=full_name,
            object_type="table",
            schema=TableSchema(namespace=ns, name=full_name, columns=[]),
            lineage=[],
            dependencies=set()
        )

    # 2) SELECT ... INTO <table>
    # (return the first "persistent" match)
    for m in re.finditer(r'(?is)\bSELECT\s+.*?\bINTO\s+([^\s,()\r\n;]+)', s):
        obj = _to_obj(m.group(1))
        if obj:
            return obj

    # 3) INSERT INTO <table> [ (cols...) ] SELECT ...
    for m in re.finditer(r'(?is)\bINSERT\s+INTO\s+([^\s,()\r\n;]+)', s):
        obj = _to_obj(m.group(1))
        if obj:
            return obj

    return None
+ return None
2350
+
2351
+
1269
2352
  def _extract_function_name(self, sql_content: str) -> Optional[str]:
1270
2353
  """Extract function name from CREATE FUNCTION statement."""
1271
2354
  match = re.search(r'CREATE\s+(?:OR\s+ALTER\s+)?FUNCTION\s+([^\s\(]+)', sql_content, re.IGNORECASE)
@@ -1315,8 +2398,22 @@ class SqlParser:
1315
2398
  lineage = []
1316
2399
  output_columns = []
1317
2400
  dependencies = set()
1318
-
1319
- # Find the last SELECT statement in the procedure body
2401
+ m = re.search(r'(?is)INSERT\s+INTO\s+[^\s(]+(?:\s*\([^)]*\))?\s+SELECT\b(.*)$', sql_content)
2402
+ if m:
2403
+ select_sql = "SELECT " + m.group(1)
2404
+ try:
2405
+ parsed = sqlglot.parse(select_sql, read=self.dialect)
2406
+ if parsed and isinstance(parsed[0], exp.Select):
2407
+ lineage, output_columns = self._extract_column_lineage(parsed[0], procedure_name)
2408
+ deps = self._extract_dependencies(parsed[0])
2409
+ dependencies.update(deps)
2410
+ except Exception:
2411
+ # Fallback: chociaż dependencies ze string-parsera
2412
+ dependencies.update(self._extract_basic_dependencies(select_sql))
2413
+
2414
+
2415
+ # For procedures, extract dependencies from all SQL statements in the procedure body
2416
+ # First try to find the last SELECT statement for lineage
1320
2417
  last_select_sql = self._find_last_select_string(sql_content)
1321
2418
  if last_select_sql:
1322
2419
  try:
@@ -1325,23 +2422,129 @@ class SqlParser:
1325
2422
  lineage, output_columns = self._extract_column_lineage(parsed[0], procedure_name)
1326
2423
  dependencies = self._extract_dependencies(parsed[0])
1327
2424
  except Exception:
1328
- # Fallback to basic analysis
2425
+ # Fallback to basic analysis with string-based lineage
1329
2426
  output_columns = self._extract_basic_select_columns(last_select_sql)
2427
+ lineage = self._extract_basic_lineage_from_select(last_select_sql, output_columns, procedure_name)
1330
2428
  dependencies = self._extract_basic_dependencies(last_select_sql)
1331
2429
 
2430
+ # Additionally, extract dependencies from the entire procedure body
2431
+ # This catches tables used in SELECT INTO, JOIN, etc.
2432
+ procedure_dependencies = self._extract_basic_dependencies(sql_content)
2433
+ dependencies.update(procedure_dependencies)
2434
+
1332
2435
  return lineage, output_columns, dependencies
1333
2436
 
2437
+ def _extract_insert_into_columns(self, sql_content: str) -> list[str]:
2438
+ m = re.search(r'(?is)INSERT\s+INTO\s+[^\s(]+\s*\((.*?)\)', sql_content)
2439
+ if not m:
2440
+ return []
2441
+ inner = m.group(1)
2442
+ cols = []
2443
+ for raw in inner.split(','):
2444
+ col = raw.strip()
2445
+ # zbij aliasy i nawiasy, zostaw samą nazwę
2446
+ col = col.split('.')[-1]
2447
+ col = re.sub(r'[^\w]', '', col)
2448
+ if col:
2449
+ cols.append(col)
2450
+ return cols
2451
+
2452
+
2453
+
2454
+ def _extract_first_create_statement(self, sql_content: str, statement_type: str) -> str:
2455
+ """Extract the first CREATE statement of the specified type."""
2456
+ patterns = {
2457
+ 'FUNCTION': [
2458
+ r'CREATE\s+(?:OR\s+ALTER\s+)?FUNCTION\s+.*?(?=CREATE\s+(?:OR\s+ALTER\s+)?(?:FUNCTION|PROCEDURE)|$)',
2459
+ r'CREATE\s+FUNCTION\s+.*?(?=CREATE\s+(?:FUNCTION|PROCEDURE)|$)'
2460
+ ],
2461
+ 'PROCEDURE': [
2462
+ r'CREATE\s+(?:OR\s+ALTER\s+)?PROCEDURE\s+.*?(?=CREATE\s+(?:OR\s+ALTER\s+)?(?:FUNCTION|PROCEDURE)|$)',
2463
+ r'CREATE\s+PROCEDURE\s+.*?(?=CREATE\s+(?:FUNCTION|PROCEDURE)|$)'
2464
+ ]
2465
+ }
2466
+
2467
+ if statement_type not in patterns:
2468
+ return ""
2469
+
2470
+ for pattern in patterns[statement_type]:
2471
+ match = re.search(pattern, sql_content, re.DOTALL | re.IGNORECASE)
2472
+ if match:
2473
+ return match.group(0).strip()
2474
+
2475
+ return ""
2476
+
2477
def _extract_tvf_lineage_string(self, sql_text: str, function_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema], Set[str]]:
    """Extract TVF lineage using string-based approach as fallback.

    Pulls the SELECT out of the RETURN clause, parses it with sqlglot,
    and derives column lineage plus base-table dependencies (CTEs and
    temp tables expanded). On parse failure only coarse string-based
    dependencies are returned.
    """
    lineage: List[ColumnLineage] = []
    output_columns: List[ColumnSchema] = []
    dependencies: Set[str] = set()

    # Extract SELECT statement from RETURN clause using string patterns.
    select_string = self._extract_select_from_return_string(sql_text)

    if select_string:
        try:
            # Use the parser's configured dialect (consistent with every
            # other sqlglot.parse call in this class) instead of
            # hardcoding TSQL.
            statements = sqlglot.parse(select_string, read=self.dialect)
            if statements:
                select_stmt = statements[0]

                # Register CTE schemas before lineage extraction.
                self._process_ctes(select_stmt)

                # Extract lineage and raw dependencies.
                lineage, output_columns = self._extract_column_lineage(select_stmt, function_name)
                raw_deps = self._extract_dependencies(select_stmt)

                # Expand CTEs and temp tables to their base tables.
                for dep in raw_deps:
                    dependencies.update(self._expand_dependency_to_base_tables(dep, select_stmt))
        except Exception:
            # If parsing fails, fall back to coarse string extraction.
            dependencies.update(self._extract_basic_dependencies(sql_text))

    return lineage, output_columns, dependencies
2510
+
1334
2511
  def _extract_select_from_return_string(self, sql_content: str) -> Optional[str]:
1335
- """Extract SELECT statement from RETURN clause using regex."""
1336
- # Handle RETURN (SELECT ...)
1337
- match = re.search(r'RETURN\s*\(\s*(SELECT.*?)\s*\)(?:\s*;)?$', sql_content, re.IGNORECASE | re.DOTALL)
1338
- if match:
1339
- return match.group(1).strip()
2512
+ """Extract SELECT statement from RETURN clause using enhanced regex."""
2513
+ # Remove comments first
2514
+ cleaned_sql = re.sub(r'--.*?(?=\n|$)', '', sql_content, flags=re.MULTILINE)
2515
+ cleaned_sql = re.sub(r'/\*.*?\*/', '', cleaned_sql, flags=re.DOTALL)
2516
+
2517
+ # Updated patterns for different RETURN formats with better handling
2518
+ patterns = [
2519
+ # RETURNS TABLE AS RETURN (SELECT
2520
+ r'RETURNS\s+TABLE\s+AS\s+RETURN\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2521
+ # RETURNS TABLE RETURN (SELECT
2522
+ r'RETURNS\s+TABLE\s+RETURN\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2523
+ # RETURNS TABLE RETURN SELECT
2524
+ r'RETURNS\s+TABLE\s+RETURN\s+(SELECT.*?)(?=[\s;]*(?:END|$))',
2525
+ # RETURN AS \n (\n SELECT
2526
+ r'RETURN\s+AS\s*\n\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2527
+ # RETURN \n ( \n SELECT
2528
+ r'RETURN\s*\n\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2529
+ # RETURN AS ( SELECT
2530
+ r'RETURN\s+AS\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2531
+ # RETURN ( SELECT
2532
+ r'RETURN\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2533
+ # AS \n RETURN \n ( \n SELECT
2534
+ r'AS\s*\n\s*RETURN\s*\n\s*\(\s*(SELECT.*?)(?=\)[\s;]*(?:END|$))',
2535
+ # RETURN SELECT (simple case)
2536
+ r'RETURN\s+(SELECT.*?)(?=[\s;]*(?:END|$))',
2537
+ # Fallback - original pattern with end of string
2538
+ r'RETURN\s*\(\s*(SELECT.*?)\s*\)(?:\s*;)?$'
2539
+ ]
1340
2540
 
1341
- # Handle RETURN AS (SELECT ...)
1342
- match = re.search(r'RETURN\s+AS\s*\(\s*(SELECT.*?)\s*\)', sql_content, re.IGNORECASE | re.DOTALL)
1343
- if match:
1344
- return match.group(1).strip()
2541
+ for pattern in patterns:
2542
+ match = re.search(pattern, cleaned_sql, re.DOTALL | re.IGNORECASE)
2543
+ if match:
2544
+ select_statement = match.group(1).strip()
2545
+ # Check if it looks like a valid SELECT statement
2546
+ if select_statement.upper().strip().startswith('SELECT'):
2547
+ return select_statement
1345
2548
 
1346
2549
  return None
1347
2550
 
@@ -1403,21 +2606,254 @@ class SqlParser:
1403
2606
  ))
1404
2607
 
1405
2608
  return output_columns
2609
+
2610
def _extract_basic_lineage_from_select(self, select_sql: str, output_columns: List[ColumnSchema], object_name: str) -> List[ColumnLineage]:
    """Best-effort, regex-only lineage for a SELECT statement.

    Pairs each output column with the corresponding expression in the
    SELECT list and resolves its source table via FROM/JOIN aliases.
    Any failure is logged at debug level and the partial result returned.
    """
    lineage: List[ColumnLineage] = []

    try:
        # Alias -> table map from FROM and JOIN clauses.
        alias_map = self._extract_table_aliases_from_select(select_sql)

        # Grab the projection (everything between SELECT and FROM).
        projection = re.search(r'SELECT\s+(.*?)\s+FROM', select_sql, re.IGNORECASE | re.DOTALL)
        if projection is None:
            return lineage

        expressions = [item.strip() for item in projection.group(1).split(',')]

        for output_col, expression in zip(output_columns, expressions):
            src_table, src_column, kind = self._parse_column_expression(expression, alias_map)
            if not (src_table and src_column):
                continue
            lineage.append(ColumnLineage(
                column_name=output_col.name,
                table_name=object_name,
                source_column=src_column,
                source_table=src_table,
                transformation_type=kind,
                transformation_description=f"Column derived from {src_table}.{src_column}"
            ))

    except Exception as e:
        logger.debug(f"Basic lineage extraction failed: {e}")

    return lineage
+ return lineage
1406
2644
 
2645
+ def _extract_table_aliases_from_select(self, select_sql: str) -> Dict[str, str]:
2646
+ """Extract table aliases from FROM and JOIN clauses."""
2647
+ aliases = {}
2648
+
2649
+ # Find FROM clause and all JOIN clauses
2650
+ from_join_pattern = r'(?i)\b(?:FROM|JOIN)\s+([^\s]+)(?:\s+AS\s+)?(\w+)?'
2651
+ matches = re.findall(from_join_pattern, select_sql)
2652
+
2653
+ for table_name, alias in matches:
2654
+ clean_table = table_name.strip()
2655
+ clean_alias = alias.strip() if alias else None
2656
+
2657
+ if clean_alias:
2658
+ aliases[clean_alias] = clean_table
2659
+ else:
2660
+ # If no alias, use the table name itself
2661
+ table_short = clean_table.split('.')[-1] # Get last part after dots
2662
+ aliases[table_short] = clean_table
2663
+
2664
+ return aliases
2665
+
2666
def _parse_column_expression(self, col_expr: str, table_aliases: Dict[str, str]) -> tuple[Optional[str], Optional[str], TransformationType]:
    """Classify one SELECT-list expression and resolve its source column.

    Returns ``(source_table, source_column, transformation_type)``; the
    first two are None when no column reference can be recognised.
    """
    col_expr = col_expr.strip()

    # Strip an output alias so only the source expression is analysed.
    # Use a case-insensitive split: the original checked upper() but split
    # case-sensitively, so lowercase " as " aliases were never removed.
    if re.search(r'\s+AS\s+', col_expr, re.IGNORECASE):
        col_expr = re.split(r'\s+AS\s+', col_expr, flags=re.IGNORECASE)[0].strip()
    elif ' ' in col_expr and not any(func in col_expr.upper() for func in ['SUM', 'COUNT', 'MAX', 'MIN', 'AVG', 'CAST', 'CASE']):
        # Implicit alias ("expr alias"): keep everything except the last word.
        parts = col_expr.split()
        if len(parts) > 1:
            col_expr = ' '.join(parts[:-1]).strip()

    # Coarse transformation classification from surface syntax.
    if any(func in col_expr.upper() for func in ['SUM(', 'COUNT(', 'MAX(', 'MIN(', 'AVG(']):
        transformation_type = TransformationType.AGGREGATION
    elif 'CASE' in col_expr.upper():
        transformation_type = TransformationType.CONDITIONAL
    elif any(op in col_expr for op in ['+', '-', '*', '/']):
        transformation_type = TransformationType.ARITHMETIC
    else:
        transformation_type = TransformationType.IDENTITY

    # Qualified reference, e.g. "c.CustomerID".
    col_match = re.search(r'(\w+)\.(\w+)', col_expr)
    if col_match:
        alias = col_match.group(1)
        column = col_match.group(2)

        if alias in table_aliases:
            table_name = table_aliases[alias]
            # Default bare table names into the dbo schema.
            if not table_name.startswith('dbo.') and '.' not in table_name:
                table_name = f"dbo.{table_name}"
            return table_name, column, transformation_type

    # Fall back to a bare column name with an unknown source table.
    simple_col_match = re.search(r'\b(\w+)\b', col_expr)
    if simple_col_match:
        return "unknown_table", simple_col_match.group(1), transformation_type

    return None, None, transformation_type
2710
+
1407
2711
def _extract_basic_dependencies(self, sql_content: str) -> Set[str]:
    """Regex-based extraction of input table dependencies from SQL.

    Scans FROM / JOIN / UPDATE / DELETE FROM / MERGE INTO / EXEC clauses,
    filters out SQL keywords, built-in functions, type names and temp
    tables, qualifies every surviving name as database.schema.table, and
    excludes tables the statement itself writes (INSERT INTO / CREATE /
    SELECT INTO targets are outputs, not inputs).
    """
    dependencies: Set[str] = set()

    # Remove comments to avoid false matches.
    cleaned_sql = re.sub(r'--.*?(?=\n|$)', '', sql_content, flags=re.MULTILINE)
    cleaned_sql = re.sub(r'/\*.*?\*/', '', cleaned_sql, flags=re.DOTALL)

    # Input-side clause patterns (match schema-qualified or bare names).
    from_pattern = r'FROM\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    join_pattern = r'JOIN\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    update_pattern = r'UPDATE\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    delete_from_pattern = r'DELETE\s+FROM\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    merge_into_pattern = r'MERGE\s+INTO\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    exec_pattern = r'EXEC\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'

    sql_keywords = {
        'select','from','join','on','where','group','having','order','into',
        'update','delete','merge','as','and','or','not','case','when','then','else',
        'distinct','top','with','nolock','commit','rollback','transaction','begin','try','catch','exists'
    }
    builtin_functions = {
        'getdate','sysdatetime','xact_state','row_number','count','sum','min','max','avg',
        'cast','convert','try_convert','coalesce','isnull','iif','len','substring','replace',
        'upper','lower','ltrim','rtrim','trim','dateadd','datediff','format','hashbytes','md5'
    }
    sql_types = {
        'varchar','nvarchar','char','nchar','text','ntext',
        'int','bigint','smallint','tinyint','numeric','decimal','money','smallmoney','float','real',
        'bit','binary','varbinary','image',
        'datetime','datetime2','smalldatetime','date','time','datetimeoffset',
        'uniqueidentifier','xml','cursor','table'
    }

    # Output-side targets: written by the statement, so never dependencies.
    insert_pattern = r'INSERT\s+INTO\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    create_pattern = r'CREATE\s+(?:OR\s+ALTER\s+)?(?:TABLE|VIEW|PROCEDURE|FUNCTION)\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'
    select_into_pattern = r'INTO\s+([^\s\(\),]+(?:\.[^\s\(\),]+)*)'

    insert_targets: Set[str] = set()
    # One loop replaces the three previously duplicated collection blocks.
    for target_pattern in (insert_pattern, create_pattern, select_into_pattern):
        for raw_target in re.findall(target_pattern, cleaned_sql, re.IGNORECASE):
            target = self._normalize_table_ident(raw_target.strip())
            if target.startswith('#'):
                continue
            parts = self._get_full_table_name(target).split('.')
            if len(parts) >= 2:
                insert_targets.add(f"{parts[-2]}.{parts[-1]}")

    # Hoisted out of the per-match loop below: the import is loop-invariant.
    from .openlineage_utils import sanitize_name

    # Note: the previous version also scanned a function-call pattern and a
    # SELECT..FROM pattern whose results were never used; those dead scans
    # have been removed.
    all_matches = (
        re.findall(from_pattern, cleaned_sql, re.IGNORECASE)
        + re.findall(join_pattern, cleaned_sql, re.IGNORECASE)
        + re.findall(update_pattern, cleaned_sql, re.IGNORECASE)
        + re.findall(delete_from_pattern, cleaned_sql, re.IGNORECASE)
        + re.findall(merge_into_pattern, cleaned_sql, re.IGNORECASE)
        + re.findall(exec_pattern, cleaned_sql, re.IGNORECASE)
    )

    for match in all_matches:
        table_name = match.strip()
        if not table_name:
            continue

        # Function-call syntax ("NAME(...)") is not a table reference.
        if re.search(r'\w+\s*\(', table_name):
            continue

        # Skip SQL keywords, built-in functions and type names.
        lowered = table_name.lower()
        if lowered in sql_keywords or lowered in builtin_functions or lowered in sql_types:
            continue

        # Strip a trailing alias ("table AS t" or "table t").
        if ' AS ' in table_name.upper():
            table_name = table_name.split(' AS ')[0].strip()
        elif ' ' in table_name and not '.' in table_name.split()[-1]:
            table_name = table_name.split()[0]

        # Remove brackets and normalize the identifier.
        table_name = self._normalize_table_ident(table_name)

        # Temp tables are not tracked as dependencies.
        if table_name.startswith('#') or table_name.lower() in sql_keywords:
            continue

        full_name = sanitize_name(self._get_full_table_name(table_name))

        # Always use fully qualified database.schema.table so topological
        # sorting stays consistent.
        parts = full_name.split('.')
        if len(parts) >= 3:
            qualified_name = full_name  # already database.schema.table
        elif len(parts) == 2:
            # schema.table -> prepend the active/default database
            db_to_use = self.current_database or self.default_database or "InfoTrackerDW"
            qualified_name = f"{db_to_use}.{full_name}"
        else:
            # bare table -> prepend default database and dbo schema
            db_to_use = self.current_database or self.default_database or "InfoTrackerDW"
            qualified_name = f"{db_to_use}.dbo.{table_name}"

        # Exclude anything this statement itself writes.
        check_parts = qualified_name.split('.')
        if f"{check_parts[-2]}.{check_parts[-1]}" in insert_targets:
            continue
        dependencies.add(qualified_name)

    return dependencies
1423
2859
 
@@ -1440,14 +2876,38 @@ class SqlParser:
1440
2876
  # Find the SELECT statement in the RETURN clause
1441
2877
  select_stmt = self._extract_select_from_return(statement)
1442
2878
  if select_stmt:
2879
+ # Process CTEs first
2880
+ self._process_ctes(select_stmt)
2881
+
2882
+ # Extract lineage and expand dependencies
1443
2883
  lineage, output_columns = self._extract_column_lineage(select_stmt, function_name)
1444
- dependencies = self._extract_dependencies(select_stmt)
2884
+ raw_deps = self._extract_dependencies(select_stmt)
2885
+
2886
+ # Expand CTEs and temp tables to base tables
2887
+ for dep in raw_deps:
2888
+ expanded_deps = self._expand_dependency_to_base_tables(dep, select_stmt)
2889
+ dependencies.update(expanded_deps)
1445
2890
 
1446
2891
  # Handle multi-statement TVF (RETURN @table TABLE)
1447
2892
  elif "RETURNS @" in sql_text.upper():
1448
- # Extract the table variable definition and find INSERT statements
2893
+ # Extract the table variable definition and find all statements
1449
2894
  output_columns = self._extract_table_variable_schema(statement)
1450
- lineage, dependencies = self._extract_mstvf_lineage(statement, function_name, output_columns)
2895
+ lineage, raw_deps = self._extract_mstvf_lineage(statement, function_name, output_columns)
2896
+
2897
+ # Expand dependencies for multi-statement TVF
2898
+ for dep in raw_deps:
2899
+ expanded_deps = self._expand_dependency_to_base_tables(dep, statement)
2900
+ dependencies.update(expanded_deps)
2901
+
2902
+ # If AST-based extraction failed, fall back to string-based approach
2903
+ if not dependencies and not lineage:
2904
+ try:
2905
+ lineage, output_columns, dependencies = self._extract_tvf_lineage_string(sql_text, function_name)
2906
+ except Exception:
2907
+ pass
2908
+
2909
+ # Remove any CTE references from final dependencies
2910
+ dependencies = {dep for dep in dependencies if not self._is_cte_reference(dep)}
1451
2911
 
1452
2912
  return lineage, output_columns, dependencies
1453
2913
 
@@ -1510,24 +2970,83 @@ class SqlParser:
1510
2970
  lineage = []
1511
2971
  dependencies = set()
1512
2972
 
1513
- # Find INSERT statements into the @table variable
2973
+ # Parse the entire function body to find all SQL statements
1514
2974
  sql_text = str(statement)
1515
- insert_matches = re.finditer(r'INSERT\s+INTO\s+@\w+.*?SELECT(.*?)(?:FROM|$)', sql_text, re.IGNORECASE | re.DOTALL)
1516
2975
 
1517
- for match in insert_matches:
1518
- try:
1519
- select_part = "SELECT" + match.group(1)
1520
- parsed = sqlglot.parse(select_part, read=self.dialect)
1521
- if parsed and isinstance(parsed[0], exp.Select):
1522
- select_stmt = parsed[0]
1523
- stmt_lineage, _ = self._extract_column_lineage(select_stmt, function_name)
1524
- lineage.extend(stmt_lineage)
1525
- dependencies.update(self._extract_dependencies(select_stmt))
1526
- except Exception:
1527
- continue
2976
+ # Find INSERT, SELECT, UPDATE, DELETE statements
2977
+ stmt_patterns = [
2978
+ r'INSERT\s+INTO\s+@\w+.*?(?=(?:INSERT|SELECT|UPDATE|DELETE|RETURN|END|\Z))',
2979
+ r'(?<!INSERT\s+INTO\s+@\w+.*?)SELECT\s+.*?(?=(?:INSERT|SELECT|UPDATE|DELETE|RETURN|END|\Z))',
2980
+ r'UPDATE\s+.*?(?=(?:INSERT|SELECT|UPDATE|DELETE|RETURN|END|\Z))',
2981
+ r'DELETE\s+.*?(?=(?:INSERT|SELECT|UPDATE|DELETE|RETURN|END|\Z))',
2982
+ r'EXEC\s+.*?(?=(?:INSERT|SELECT|UPDATE|DELETE|RETURN|END|\Z))'
2983
+ ]
2984
+
2985
+ for pattern in stmt_patterns:
2986
+ matches = re.finditer(pattern, sql_text, re.IGNORECASE | re.DOTALL)
2987
+ for match in matches:
2988
+ try:
2989
+ stmt_sql = match.group(0).strip()
2990
+ if not stmt_sql:
2991
+ continue
2992
+
2993
+ # Parse the statement
2994
+ parsed_stmts = sqlglot.parse(stmt_sql, read=self.dialect)
2995
+ if parsed_stmts:
2996
+ for parsed_stmt in parsed_stmts:
2997
+ if isinstance(parsed_stmt, exp.Select):
2998
+ stmt_lineage, _ = self._extract_column_lineage(parsed_stmt, function_name)
2999
+ lineage.extend(stmt_lineage)
3000
+ stmt_deps = self._extract_dependencies(parsed_stmt)
3001
+ dependencies.update(stmt_deps)
3002
+ elif isinstance(parsed_stmt, exp.Insert):
3003
+ # Handle INSERT statements
3004
+ if hasattr(parsed_stmt, 'expression') and isinstance(parsed_stmt.expression, exp.Select):
3005
+ stmt_lineage, _ = self._extract_column_lineage(parsed_stmt.expression, function_name)
3006
+ lineage.extend(stmt_lineage)
3007
+ stmt_deps = self._extract_dependencies(parsed_stmt.expression)
3008
+ dependencies.update(stmt_deps)
3009
+ except Exception as e:
3010
+ logger.debug(f"Failed to parse statement in MSTVF: {e}")
3011
+ continue
1528
3012
 
1529
3013
  return lineage, dependencies
1530
3014
 
3015
+ def _expand_dependency_to_base_tables(self, dep_name: str, context_stmt: exp.Expression) -> Set[str]:
3016
+ """Expand dependency to base tables, resolving CTEs and temp tables."""
3017
+ expanded = set()
3018
+
3019
+ # Check if this is a CTE reference
3020
+ simple_name = dep_name.split('.')[-1]
3021
+ if simple_name in self.cte_registry:
3022
+ # This is a CTE - find its definition and get base dependencies
3023
+ if isinstance(context_stmt, exp.Select) and context_stmt.args.get('with'):
3024
+ with_clause = context_stmt.args.get('with')
3025
+ if hasattr(with_clause, 'expressions'):
3026
+ for cte in with_clause.expressions:
3027
+ if hasattr(cte, 'alias') and str(cte.alias) == simple_name:
3028
+ if isinstance(cte.this, exp.Select):
3029
+ cte_deps = self._extract_dependencies(cte.this)
3030
+ for cte_dep in cte_deps:
3031
+ expanded.update(self._expand_dependency_to_base_tables(cte_dep, cte.this))
3032
+ break
3033
+ return expanded
3034
+
3035
+ # Check if this is a temp table reference
3036
+ if simple_name in self.temp_registry:
3037
+ # For temp tables, return the temp table name itself (it's a base table)
3038
+ expanded.add(dep_name)
3039
+ return expanded
3040
+
3041
+ # It's a regular table - return as is
3042
+ expanded.add(dep_name)
3043
+ return expanded
3044
+
3045
+ def _is_cte_reference(self, dep_name: str) -> bool:
3046
+ """Check if a dependency name refers to a CTE."""
3047
+ simple_name = dep_name.split('.')[-1]
3048
+ return simple_name in self.cte_registry
3049
+
1531
3050
  def _find_last_select_in_procedure(self, statement: exp.Create) -> Optional[exp.Select]:
1532
3051
  """Find the last SELECT statement in a procedure body."""
1533
3052
  sql_text = str(statement)