InfoTracker 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
infotracker/parser.py CHANGED
@@ -3,6 +3,7 @@ SQL parsing and lineage extraction using SQLGlot.
3
3
  """
4
4
  from __future__ import annotations
5
5
 
6
+ import logging
6
7
  import re
7
8
  from typing import List, Optional, Set, Dict, Any
8
9
 
@@ -11,9 +12,11 @@ from sqlglot import expressions as exp
11
12
 
12
13
  from .models import (
13
14
  ColumnReference, ColumnSchema, TableSchema, ColumnLineage,
14
- TransformationType, ObjectInfo, SchemaRegistry
15
+ TransformationType, ObjectInfo, SchemaRegistry, ColumnNode
15
16
  )
16
17
 
18
+ logger = logging.getLogger(__name__)
19
+
17
20
 
18
21
  class SqlParser:
19
22
  """Parser for SQL statements using SQLGlot."""
@@ -21,12 +24,153 @@ class SqlParser:
21
24
  def __init__(self, dialect: str = "tsql"):
22
25
  self.dialect = dialect
23
26
  self.schema_registry = SchemaRegistry()
27
+ self.default_database: Optional[str] = None # Will be set from config
28
+
29
def set_default_database(self, default_database: Optional[str]):
    """Store the database name used when qualifying identifiers.

    Args:
        default_database: Database name to remember on this parser, or
            ``None`` to clear a previously stored value.
    """
    self.default_database = default_database
32
+
33
+ def _preprocess_sql(self, sql: str) -> str:
34
+ """
35
+ Preprocess SQL to remove control lines and join INSERT INTO #temp EXEC patterns.
36
+ """
37
+ import re
38
+
39
+ lines = sql.split('\n')
40
+ processed_lines = []
41
+
42
+ for line in lines:
43
+ stripped_line = line.strip()
44
+
45
+ # Skip lines starting with DECLARE, SET, PRINT (case-insensitive)
46
+ if re.match(r'(?i)^(DECLARE|SET|PRINT)\b', stripped_line):
47
+ continue
48
+
49
+ # Skip IF OBJECT_ID('tempdb..#...') patterns and DROP TABLE #temp patterns
50
+ if (re.match(r"(?i)^IF\s+OBJECT_ID\('tempdb\.\.#", stripped_line) or
51
+ re.match(r'(?i)^DROP\s+TABLE\s+#\w+', stripped_line)):
52
+ continue
53
+
54
+ processed_lines.append(line)
55
+
56
+ # Join the lines back together
57
+ processed_sql = '\n'.join(processed_lines)
58
+
59
+ # Join two-line INSERT INTO #temp + EXEC patterns
60
+ processed_sql = re.sub(
61
+ r'(?i)(INSERT\s+INTO\s+#\w+)\s*\n\s*(EXEC\b)',
62
+ r'\1 \2',
63
+ processed_sql
64
+ )
65
+
66
+ return processed_sql
67
+
68
def _try_insert_exec_fallback(self, sql_content: str, object_hint: Optional[str] = None) -> Optional[ObjectInfo]:
    """
    Regex fallback for ``INSERT INTO #temp EXEC proc`` when SQLGlot fails.

    Args:
        sql_content: SQL text to search.
        object_hint: Accepted for signature parity with the other parse
            helpers; not used here.

    Returns:
        An ``ObjectInfo`` describing the temp table with two placeholder
        columns whose lineage points at the executed procedure, or ``None``
        when the pattern is not present. The temp table schema is also
        registered in ``self.schema_registry``.
    """
    # The procedure token stops at whitespace, "(" or ";" so a statement
    # terminator ("EXEC dbo.proc;") is never captured as part of the name.
    match = re.search(r'(?is)INSERT\s+INTO\s+(#\w+)\s+EXEC\s+([^\s(;]+)', sql_content)
    if not match:
        return None

    temp_table = match.group(1)  # e.g., "#customer_metrics"
    proc_name = match.group(2)   # e.g., "dbo.usp_customer_metrics_dataset"

    # Only a bare (dot-free) name is qualified, and only when a default
    # database is configured.
    if '.' not in proc_name and self.default_database:
        qualified_proc_name = f"{self.default_database}.dbo.{proc_name}"
    else:
        qualified_proc_name = proc_name

    # The real result shape is unknown without running the procedure, so two
    # placeholder columns stand in for it.
    placeholder_columns = [
        ColumnSchema(
            name=f"output_col_{index + 1}",
            data_type="unknown",
            nullable=True,
            ordinal=index
        )
        for index in range(2)
    ]

    schema = TableSchema(
        namespace="tempdb",
        name=temp_table,
        columns=placeholder_columns
    )

    lineage = [
        ColumnLineage(
            output_column=col.name,
            input_fields=[
                ColumnReference(
                    namespace="mssql://localhost/InfoTrackerDW",
                    table_name=qualified_proc_name,
                    column_name="*"
                )
            ],
            transformation_type=TransformationType.EXEC,
            transformation_description=f"INSERT INTO {temp_table} EXEC {proc_name}"
        )
        for col in placeholder_columns
    ]

    self.schema_registry.register(schema)

    return ObjectInfo(
        name=temp_table,
        object_type="temp_table",
        schema=schema,
        lineage=lineage,
        dependencies={qualified_proc_name}
    )
143
+
144
def _find_last_select_string(self, sql_content: str, dialect: str = "tsql") -> str | None:
    """Return the SQL text of the last SELECT node found via the SQLGlot AST.

    Every parsed statement is searched with ``find_all(exp.Select)`` and the
    final node collected is rendered with ``str()``.

    Args:
        sql_content: SQL text to parse.
        dialect: SQLGlot dialect name used for parsing.

    Returns:
        The rendered SELECT, or ``None`` when no SELECT is found or parsing
        raises.
    """
    import sqlglot
    from sqlglot import exp
    try:
        parsed = sqlglot.parse(sql_content, read=dialect)
        selects = []
        for stmt in parsed:
            # sqlglot.parse can yield None for empty statements; calling
            # find_all on it would raise and discard SELECTs already found.
            if stmt is None:
                continue
            selects.extend(stmt.find_all(exp.Select))
        if not selects:
            return None
        # NOTE(review): find_all also yields nested SELECTs, so the final
        # entry may be a subquery of the last statement — confirm intended.
        return str(selects[-1])
    except Exception:
        # Best-effort helper: any parser failure means "no SELECT found".
        return None
24
158
 
25
159
  def parse_sql_file(self, sql_content: str, object_hint: Optional[str] = None) -> ObjectInfo:
26
160
  """Parse a SQL file and extract object information."""
27
161
  try:
28
- # Parse the SQL statement
29
- statements = sqlglot.parse(sql_content, read=self.dialect)
162
+ # First check if this is a function or procedure using string matching
163
+ sql_upper = sql_content.upper()
164
+ if "CREATE FUNCTION" in sql_upper or "CREATE OR ALTER FUNCTION" in sql_upper:
165
+ return self._parse_function_string(sql_content, object_hint)
166
+ elif "CREATE PROCEDURE" in sql_upper or "CREATE OR ALTER PROCEDURE" in sql_upper:
167
+ return self._parse_procedure_string(sql_content, object_hint)
168
+
169
+ # Preprocess the SQL content to handle demo script patterns
170
+ preprocessed_sql = self._preprocess_sql(sql_content)
171
+
172
+ # Parse the SQL statement with SQLGlot
173
+ statements = sqlglot.parse(preprocessed_sql, read=self.dialect)
30
174
  if not statements:
31
175
  raise ValueError("No valid SQL statements found")
32
176
 
@@ -37,10 +181,18 @@ class SqlParser:
37
181
  return self._parse_create_statement(statement, object_hint)
38
182
  elif isinstance(statement, exp.Select) and self._is_select_into(statement):
39
183
  return self._parse_select_into(statement, object_hint)
184
+ elif isinstance(statement, exp.Insert) and self._is_insert_exec(statement):
185
+ return self._parse_insert_exec(statement, object_hint)
40
186
  else:
41
187
  raise ValueError(f"Unsupported statement type: {type(statement)}")
42
188
 
43
189
  except Exception as e:
190
+ # Try fallback for INSERT INTO #temp EXEC pattern
191
+ fallback_result = self._try_insert_exec_fallback(sql_content, object_hint)
192
+ if fallback_result:
193
+ return fallback_result
194
+
195
+ logger.warning("parse failed: %s", e)
44
196
  # Return an object with error information
45
197
  return ObjectInfo(
46
198
  name=object_hint or "unknown",
@@ -58,6 +210,17 @@ class SqlParser:
58
210
  """Check if this is a SELECT INTO statement."""
59
211
  return statement.args.get('into') is not None
60
212
 
213
+ def _is_insert_exec(self, statement: exp.Insert) -> bool:
214
+ """Check if this is an INSERT INTO ... EXEC statement."""
215
+ # Check if the expression is a command (EXEC)
216
+ expression = statement.expression
217
+ return (
218
+ hasattr(expression, 'expressions') and
219
+ expression.expressions and
220
+ isinstance(expression.expressions[0], exp.Command) and
221
+ str(expression.expressions[0]).upper().startswith('EXEC')
222
+ )
223
+
61
224
  def _parse_select_into(self, statement: exp.Select, object_hint: Optional[str] = None) -> ObjectInfo:
62
225
  """Parse SELECT INTO statement."""
63
226
  # Get target table name from INTO clause
@@ -95,12 +258,97 @@ class SqlParser:
95
258
  dependencies=dependencies
96
259
  )
97
260
 
261
def _parse_insert_exec(self, statement: exp.Insert, object_hint: Optional[str] = None) -> ObjectInfo:
    """Parse INSERT INTO ... EXEC statement.

    Args:
        statement: Parsed INSERT whose expression holds the EXEC command.
        object_hint: Fallback name passed through to ``_get_table_name``.

    Returns:
        An ``ObjectInfo`` for the target table with two placeholder columns.
        When a procedure name can be read from the EXEC text, each column
        gets a wildcard lineage entry pointing at it. The schema is also
        registered in ``self.schema_registry``.

    Raises:
        ValueError: If the INSERT expression carries no ``expressions``.
    """
    table_name = self._get_table_name(statement.this, object_hint)
    is_temp = table_name.startswith('#')
    namespace = "tempdb" if is_temp else "mssql://localhost/InfoTrackerDW"

    expression = statement.expression
    if not (hasattr(expression, 'expressions') and expression.expressions):
        raise ValueError("Could not parse INSERT INTO ... EXEC statement")

    exec_text = str(expression.expressions[0])

    dependencies = set()
    procedure_name = None
    if exec_text.upper().startswith('EXEC'):
        parts = exec_text.split()
        if len(parts) > 1:
            # First token after EXEC, without an argument list or a trailing
            # statement terminator.
            candidate = parts[1].strip('()').split('(')[0].rstrip(';')
            if candidate:
                procedure_name = candidate
                dependencies.add(candidate)

    # The procedure's result shape is unknown without executing it, so two
    # placeholder columns are used.
    output_columns = [
        ColumnSchema(
            name=f"output_col_{index + 1}",
            data_type="unknown",
            ordinal=index,
            nullable=True
        )
        for index in range(2)
    ]

    lineage = []
    if procedure_name:
        for col in output_columns:
            lineage.append(ColumnLineage(
                output_column=col.name,
                input_fields=[ColumnReference(
                    namespace="mssql://localhost/InfoTrackerDW",
                    table_name=procedure_name,
                    column_name="*"  # Wildcard: procedure output is unknown
                )],
                transformation_type=TransformationType.EXEC,
                transformation_description=f"INSERT INTO {table_name} EXEC {procedure_name}"
            ))

    schema = TableSchema(
        namespace=namespace,
        name=table_name,
        columns=output_columns
    )

    # Register schema for future reference
    self.schema_registry.register(schema)

    return ObjectInfo(
        name=table_name,
        object_type="temp_table" if is_temp else "table",
        schema=schema,
        lineage=lineage,
        dependencies=dependencies
    )
341
+
98
342
  def _parse_create_statement(self, statement: exp.Create, object_hint: Optional[str] = None) -> ObjectInfo:
99
- """Parse CREATE TABLE or CREATE VIEW statement."""
343
+ """Parse CREATE TABLE, CREATE VIEW, CREATE FUNCTION, or CREATE PROCEDURE statement."""
100
344
  if statement.kind == "TABLE":
101
345
  return self._parse_create_table(statement, object_hint)
102
346
  elif statement.kind == "VIEW":
103
347
  return self._parse_create_view(statement, object_hint)
348
+ elif statement.kind == "FUNCTION":
349
+ return self._parse_create_function(statement, object_hint)
350
+ elif statement.kind == "PROCEDURE":
351
+ return self._parse_create_procedure(statement, object_hint)
104
352
  else:
105
353
  raise ValueError(f"Unsupported CREATE statement: {statement.kind}")
106
354
 
@@ -189,31 +437,117 @@ class SqlParser:
189
437
  dependencies=dependencies
190
438
  )
191
439
 
440
def _parse_create_function(self, statement: exp.Create, object_hint: Optional[str] = None) -> ObjectInfo:
    """Build an ``ObjectInfo`` for a parsed CREATE FUNCTION statement.

    Table-valued functions get lineage, columns and dependencies from
    ``_extract_tvf_lineage`` and have their schema registered. Any other
    function yields an object with an empty, unregistered schema and no
    lineage.
    """
    namespace = "mssql://localhost/InfoTrackerDW"
    function_name = self._get_table_name(statement.this, object_hint)

    if self._is_table_valued_function(statement):
        lineage, output_columns, dependencies = self._extract_tvf_lineage(statement, function_name)
        tvf_schema = TableSchema(
            namespace=namespace,
            name=function_name,
            columns=output_columns
        )
        # Make the function's columns available to later lookups.
        self.schema_registry.register(tvf_schema)
        return ObjectInfo(
            name=function_name,
            object_type="function",
            schema=tvf_schema,
            lineage=lineage,
            dependencies=dependencies
        )

    # Scalar function: no columns, no lineage, nothing registered.
    return ObjectInfo(
        name=function_name,
        object_type="function",
        schema=TableSchema(
            namespace=namespace,
            name=function_name,
            columns=[]
        ),
        lineage=[],
        dependencies=set()
    )
479
+
480
def _parse_create_procedure(self, statement: exp.Create, object_hint: Optional[str] = None) -> ObjectInfo:
    """Build an ``ObjectInfo`` for a parsed CREATE PROCEDURE statement.

    Lineage, output columns and dependencies come from
    ``_extract_procedure_lineage``; the resulting schema is registered in
    ``self.schema_registry`` before the object is returned.
    """
    procedure_name = self._get_table_name(statement.this, object_hint)
    extracted = self._extract_procedure_lineage(statement, procedure_name)
    lineage, output_columns, dependencies = extracted

    procedure_schema = TableSchema(
        namespace="mssql://localhost/InfoTrackerDW",
        name=procedure_name,
        columns=output_columns
    )
    # Make the procedure's result columns available to later lookups.
    self.schema_registry.register(procedure_schema)

    return ObjectInfo(
        name=procedure_name,
        object_type="procedure",
        schema=procedure_schema,
        lineage=lineage,
        dependencies=dependencies
    )
504
+
192
505
def _get_table_name(self, table_expr: exp.Expression, hint: Optional[str] = None) -> str:
    """Return the name for a table expression, qualified where needed.

    Three-part names (catalog.db.name) are returned as written. Two-part and
    bare names, as well as plain identifiers, are passed through
    ``qualify_identifier`` together with ``self.default_database``. Any other
    expression type falls back to ``hint`` or ``"unknown"``.
    """
    from .openlineage_utils import qualify_identifier

    if isinstance(table_expr, exp.Table):
        if table_expr.catalog and table_expr.db:
            # Already fully qualified: database.schema.table
            return f"{table_expr.catalog}.{table_expr.db}.{table_expr.name}"
        if table_expr.db:
            # Two-part name like dbo.table_name
            partial_name = f"{table_expr.db}.{table_expr.name}"
        else:
            partial_name = str(table_expr.name)
        return qualify_identifier(partial_name, self.default_database)

    if isinstance(table_expr, exp.Identifier):
        return qualify_identifier(str(table_expr.this), self.default_database)

    return hint or "unknown"
202
524
 
203
525
  def _extract_column_type(self, column_def: exp.ColumnDef) -> str:
204
526
  """Extract column type from column definition."""
205
527
  if column_def.kind:
206
528
  data_type = str(column_def.kind)
207
- # Convert to match expected format (lowercase for simple types)
208
- if data_type.upper().startswith('VARCHAR'):
209
- data_type = data_type.replace('VARCHAR', 'nvarchar')
210
- elif data_type.upper() == 'INT':
211
- data_type = 'int'
212
- elif data_type.upper() == 'DATE':
213
- data_type = 'date'
214
- elif 'DECIMAL' in data_type.upper():
529
+
530
+ # Type normalization mappings - adjust these as needed for your environment
531
+ # Note: This aggressive normalization can be modified by updating the mappings below
532
+ TYPE_MAPPINGS = {
533
+ 'VARCHAR': 'nvarchar', # SQL Server: VARCHAR -> NVARCHAR
534
+ 'INT': 'int',
535
+ 'DATE': 'date',
536
+ }
537
+
538
+ data_type_upper = data_type.upper()
539
+ for old_type, new_type in TYPE_MAPPINGS.items():
540
+ if data_type_upper.startswith(old_type):
541
+ data_type = data_type.replace(old_type, new_type)
542
+ break
543
+ elif data_type_upper == old_type:
544
+ data_type = new_type
545
+ break
546
+
547
+ if 'DECIMAL' in data_type_upper:
215
548
  # Normalize decimal formatting: "DECIMAL(10, 2)" -> "decimal(10,2)"
216
549
  data_type = data_type.replace(' ', '').lower()
550
+
217
551
  return data_type.lower()
218
552
  return "unknown"
219
553
 
@@ -670,76 +1004,97 @@ class SqlParser:
670
1004
  lineage = []
671
1005
  output_columns = []
672
1006
 
673
- # Get source tables and their aliases
674
- source_tables = []
675
- table_aliases = {}
1007
+ # Process all SELECT expressions, including both stars and explicit columns
1008
+ ordinal = 0
676
1009
 
677
- # Check for explicit aliased star (o.*, c.*)
678
1010
  for select_expr in select_stmt.expressions:
679
- if isinstance(select_expr, exp.Star) and select_expr.table:
680
- # This is an aliased star like o.* or c.*
681
- alias = str(select_expr.table)
682
- table_name = self._resolve_table_from_alias(alias, select_stmt)
683
- if table_name != "unknown":
684
- columns = self._infer_table_columns(table_name)
685
- ordinal = len(output_columns)
1011
+ if isinstance(select_expr, exp.Star):
1012
+ if hasattr(select_expr, 'table') and select_expr.table:
1013
+ # This is an aliased star like o.* or c.*
1014
+ alias = str(select_expr.table)
1015
+ table_name = self._resolve_table_from_alias(alias, select_stmt)
1016
+ if table_name != "unknown":
1017
+ columns = self._infer_table_columns(table_name)
1018
+
1019
+ for column_name in columns:
1020
+ output_columns.append(ColumnSchema(
1021
+ name=column_name,
1022
+ data_type="unknown",
1023
+ nullable=True,
1024
+ ordinal=ordinal
1025
+ ))
1026
+ ordinal += 1
1027
+
1028
+ lineage.append(ColumnLineage(
1029
+ output_column=column_name,
1030
+ input_fields=[ColumnReference(
1031
+ namespace="mssql://localhost/InfoTrackerDW",
1032
+ table_name=table_name,
1033
+ column_name=column_name
1034
+ )],
1035
+ transformation_type=TransformationType.IDENTITY,
1036
+ transformation_description=f"SELECT {alias}.{column_name}"
1037
+ ))
1038
+ else:
1039
+ # Handle unqualified * - expand all tables
1040
+ source_tables = []
1041
+ for table in select_stmt.find_all(exp.Table):
1042
+ table_name = self._get_table_name(table)
1043
+ if table_name != "unknown":
1044
+ source_tables.append(table_name)
686
1045
 
687
- for column_name in columns:
688
- output_columns.append(ColumnSchema(
689
- name=column_name,
690
- data_type="unknown",
691
- nullable=True,
692
- ordinal=ordinal
693
- ))
694
- ordinal += 1
1046
+ for table_name in source_tables:
1047
+ columns = self._infer_table_columns(table_name)
695
1048
 
696
- lineage.append(ColumnLineage(
697
- output_column=column_name,
698
- input_fields=[ColumnReference(
699
- namespace="mssql://localhost/InfoTrackerDW",
700
- table_name=table_name,
701
- column_name=column_name
702
- )],
703
- transformation_type=TransformationType.IDENTITY,
704
- transformation_description=f"SELECT {alias}.{column_name}"
705
- ))
706
- return lineage, output_columns
707
-
708
- # Handle unqualified * - expand all tables
709
- for table in select_stmt.find_all(exp.Table):
710
- table_name = self._get_table_name(table)
711
- if table_name != "unknown":
712
- source_tables.append(table_name)
713
-
714
- if not source_tables:
715
- return lineage, output_columns
716
-
717
- # For unqualified *, expand columns from all tables
718
- ordinal = 0
719
- for table_name in source_tables:
720
- columns = self._infer_table_columns(table_name)
721
-
722
- for column_name in columns:
1049
+ for column_name in columns:
1050
+ output_columns.append(ColumnSchema(
1051
+ name=column_name,
1052
+ data_type="unknown",
1053
+ nullable=True,
1054
+ ordinal=ordinal
1055
+ ))
1056
+ ordinal += 1
1057
+
1058
+ lineage.append(ColumnLineage(
1059
+ output_column=column_name,
1060
+ input_fields=[ColumnReference(
1061
+ namespace="mssql://localhost/InfoTrackerDW",
1062
+ table_name=table_name,
1063
+ column_name=column_name
1064
+ )],
1065
+ transformation_type=TransformationType.IDENTITY,
1066
+ transformation_description=f"SELECT * (from {table_name})"
1067
+ ))
1068
+ else:
1069
+ # Handle explicit column expressions (like "1 as extra_col")
1070
+ col_name = self._extract_column_alias(select_expr) or f"col_{ordinal}"
723
1071
  output_columns.append(ColumnSchema(
724
- name=column_name,
1072
+ name=col_name,
725
1073
  data_type="unknown",
726
1074
  nullable=True,
727
1075
  ordinal=ordinal
728
1076
  ))
729
1077
  ordinal += 1
730
1078
 
731
- lineage.append(ColumnLineage(
732
- output_column=column_name,
733
- input_fields=[ColumnReference(
1079
+ # Try to extract lineage for this column
1080
+ input_refs = self._extract_column_references(select_expr, select_stmt)
1081
+ if not input_refs:
1082
+ # If no specific references found, treat as expression
1083
+ input_refs = [ColumnReference(
734
1084
  namespace="mssql://localhost/InfoTrackerDW",
735
- table_name=table_name,
736
- column_name=column_name
737
- )],
738
- transformation_type=TransformationType.IDENTITY,
739
- transformation_description=f"SELECT * (from {table_name})"
1085
+ table_name="LITERAL",
1086
+ column_name=str(select_expr)
1087
+ )]
1088
+
1089
+ lineage.append(ColumnLineage(
1090
+ output_column=col_name,
1091
+ input_fields=input_refs,
1092
+ transformation_type=TransformationType.EXPRESSION,
1093
+ transformation_description=f"SELECT {str(select_expr)}"
740
1094
  ))
741
1095
 
742
1096
  return lineage, output_columns
1097
+
743
1098
 
744
1099
  def _handle_union_lineage(self, stmt: exp.Expression, view_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema]]:
745
1100
  """Handle UNION operations."""
@@ -790,8 +1145,21 @@ class SqlParser:
790
1145
  return lineage, output_columns
791
1146
 
792
1147
  def _infer_table_columns(self, table_name: str) -> List[str]:
793
- """Infer table columns based on known schemas or naming patterns."""
794
- # This is a simplified approach - you'd typically query the database
1148
+ """Infer table columns from schema registry or fallback to patterns."""
1149
+ # First try to get from schema registry
1150
+ # Try different namespace combinations
1151
+ namespaces_to_try = [
1152
+ "mssql://localhost/InfoTrackerDW",
1153
+ "dbo",
1154
+ "",
1155
+ ]
1156
+
1157
+ for namespace in namespaces_to_try:
1158
+ schema = self.schema_registry.get(namespace, table_name)
1159
+ if schema:
1160
+ return [col.name for col in schema.columns]
1161
+
1162
+ # Fallback to patterns if not in registry
795
1163
  table_simple = table_name.split('.')[-1].lower()
796
1164
 
797
1165
  if 'orders' in table_simple:
@@ -805,3 +1173,407 @@ class SqlParser:
805
1173
  else:
806
1174
  # Generic fallback
807
1175
  return ['Column1', 'Column2', 'Column3']
1176
+
1177
def _parse_function_string(self, sql_content: str, object_hint: Optional[str] = None) -> ObjectInfo:
    """Build an ``ObjectInfo`` for CREATE FUNCTION text using string parsing.

    The name comes from ``_extract_function_name``, then ``object_hint``,
    then ``"unknown_function"``. Table-valued functions get lineage from
    ``_extract_tvf_lineage_string`` and have their schema registered; other
    functions yield an empty, unregistered schema with no lineage.
    """
    namespace = "mssql://localhost/InfoTrackerDW"
    function_name = self._extract_function_name(sql_content) or object_hint or "unknown_function"

    if self._is_table_valued_function_string(sql_content):
        lineage, output_columns, dependencies = self._extract_tvf_lineage_string(sql_content, function_name)
        tvf_schema = TableSchema(
            namespace=namespace,
            name=function_name,
            columns=output_columns
        )
        # Make the function's columns available to later lookups.
        self.schema_registry.register(tvf_schema)
        return ObjectInfo(
            name=function_name,
            object_type="function",
            schema=tvf_schema,
            lineage=lineage,
            dependencies=dependencies
        )

    # Scalar function: no columns, no lineage, nothing registered.
    return ObjectInfo(
        name=function_name,
        object_type="function",
        schema=TableSchema(
            namespace=namespace,
            name=function_name,
            columns=[]
        ),
        lineage=[],
        dependencies=set()
    )
1216
+
1217
def _parse_procedure_string(self, sql_content: str, object_hint: Optional[str] = None) -> ObjectInfo:
    """Build an ``ObjectInfo`` for CREATE PROCEDURE text using string parsing.

    The name comes from ``_extract_procedure_name``, then ``object_hint``,
    then ``"unknown_procedure"``. Lineage, columns and dependencies come from
    ``_extract_procedure_lineage_string``; the schema is registered before
    the object is returned.
    """
    procedure_name = self._extract_procedure_name(sql_content) or object_hint or "unknown_procedure"
    extracted = self._extract_procedure_lineage_string(sql_content, procedure_name)
    lineage, output_columns, dependencies = extracted

    procedure_schema = TableSchema(
        namespace="mssql://localhost/InfoTrackerDW",
        name=procedure_name,
        columns=output_columns
    )
    # Make the procedure's result columns available to later lookups.
    self.schema_registry.register(procedure_schema)

    return ObjectInfo(
        name=procedure_name,
        object_type="procedure",
        schema=procedure_schema,
        lineage=lineage,
        dependencies=dependencies
    )
1241
+
1242
+ def _extract_function_name(self, sql_content: str) -> Optional[str]:
1243
+ """Extract function name from CREATE FUNCTION statement."""
1244
+ match = re.search(r'CREATE\s+(?:OR\s+ALTER\s+)?FUNCTION\s+([^\s\(]+)', sql_content, re.IGNORECASE)
1245
+ return match.group(1).strip() if match else None
1246
+
1247
+ def _extract_procedure_name(self, sql_content: str) -> Optional[str]:
1248
+ """Extract procedure name from CREATE PROCEDURE statement."""
1249
+ match = re.search(r'CREATE\s+(?:OR\s+ALTER\s+)?PROCEDURE\s+([^\s\(]+)', sql_content, re.IGNORECASE)
1250
+ return match.group(1).strip() if match else None
1251
+
1252
+ def _is_table_valued_function_string(self, sql_content: str) -> bool:
1253
+ """Check if this is a table-valued function (returns TABLE)."""
1254
+ sql_upper = sql_content.upper()
1255
+ return "RETURNS TABLE" in sql_upper or "RETURNS @" in sql_upper
1256
+
1257
+ def _extract_tvf_lineage_string(self, sql_content: str, function_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema], Set[str]]:
1258
+ """Extract lineage from a table-valued function using string parsing."""
1259
+ lineage = []
1260
+ output_columns = []
1261
+ dependencies = set()
1262
+
1263
+ sql_upper = sql_content.upper()
1264
+
1265
+ # Handle inline TVF (RETURN AS SELECT or RETURN (SELECT))
1266
+ if "RETURN" in sql_upper and ("AS" in sql_upper or "(" in sql_upper):
1267
+ select_sql = self._extract_select_from_return_string(sql_content)
1268
+ if select_sql:
1269
+ try:
1270
+ parsed = sqlglot.parse(select_sql, read=self.dialect)
1271
+ if parsed and isinstance(parsed[0], exp.Select):
1272
+ lineage, output_columns = self._extract_column_lineage(parsed[0], function_name)
1273
+ dependencies = self._extract_dependencies(parsed[0])
1274
+ except Exception:
1275
+ # Fallback to basic analysis
1276
+ output_columns = self._extract_basic_select_columns(select_sql)
1277
+ dependencies = self._extract_basic_dependencies(select_sql)
1278
+
1279
+ # Handle multi-statement TVF (RETURNS @table TABLE)
1280
+ elif "RETURNS @" in sql_upper:
1281
+ output_columns = self._extract_table_variable_schema_string(sql_content)
1282
+ dependencies = self._extract_basic_dependencies(sql_content)
1283
+
1284
+ return lineage, output_columns, dependencies
1285
+
1286
+ def _extract_procedure_lineage_string(self, sql_content: str, procedure_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema], Set[str]]:
1287
+ """Extract lineage from a procedure using string parsing."""
1288
+ lineage = []
1289
+ output_columns = []
1290
+ dependencies = set()
1291
+
1292
+ # Find the last SELECT statement in the procedure body
1293
+ last_select_sql = self._find_last_select_string(sql_content)
1294
+ if last_select_sql:
1295
+ try:
1296
+ parsed = sqlglot.parse(last_select_sql, read=self.dialect)
1297
+ if parsed and isinstance(parsed[0], exp.Select):
1298
+ lineage, output_columns = self._extract_column_lineage(parsed[0], procedure_name)
1299
+ dependencies = self._extract_dependencies(parsed[0])
1300
+ except Exception:
1301
+ # Fallback to basic analysis
1302
+ output_columns = self._extract_basic_select_columns(last_select_sql)
1303
+ dependencies = self._extract_basic_dependencies(last_select_sql)
1304
+
1305
+ return lineage, output_columns, dependencies
1306
+
1307
+ def _extract_select_from_return_string(self, sql_content: str) -> Optional[str]:
1308
+ """Extract SELECT statement from RETURN clause using regex."""
1309
+ # Handle RETURN (SELECT ...)
1310
+ match = re.search(r'RETURN\s*\(\s*(SELECT.*?)\s*\)(?:\s*;)?$', sql_content, re.IGNORECASE | re.DOTALL)
1311
+ if match:
1312
+ return match.group(1).strip()
1313
+
1314
+ # Handle RETURN AS (SELECT ...)
1315
+ match = re.search(r'RETURN\s+AS\s*\(\s*(SELECT.*?)\s*\)', sql_content, re.IGNORECASE | re.DOTALL)
1316
+ if match:
1317
+ return match.group(1).strip()
1318
+
1319
+ return None
1320
+
1321
def _extract_table_variable_schema_string(self, sql_content: str) -> List[ColumnSchema]:
    """Extract column schema from an ``@name TABLE (...)`` definition.

    The column list is delimited by matching parentheses and split on
    top-level commas only, so types such as ``DECIMAL(10, 2)`` stay in one
    piece. Each entry needs at least a name and a type token; other entries
    are skipped.

    Args:
        sql_content: SQL text containing the table-variable declaration.

    Returns:
        One ``ColumnSchema`` per recognized column, with contiguous ordinals
        and ``nullable=True``. Empty when no declaration is found or its
        parentheses are unbalanced.
    """
    output_columns = []

    header = re.search(r'@\w+\s+TABLE\s*\(', sql_content, re.IGNORECASE)
    if not header:
        return output_columns

    # Walk the text after the opening parenthesis, splitting on commas that
    # are not nested inside a type's own parentheses.
    definitions = []
    current = []
    depth = 1
    closed = False
    for ch in sql_content[header.end():]:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                closed = True
                break
        if ch == ',' and depth == 1:
            definitions.append(''.join(current))
            current = []
        else:
            current.append(ch)
    if not closed:
        return output_columns
    definitions.append(''.join(current))

    for raw_definition in definitions:
        col_def = raw_definition.strip()
        parts = col_def.split()
        if len(parts) < 2:
            continue
        # Type = word plus an optional parenthesized argument list; fall back
        # to the second token when that shape does not match.
        type_match = re.match(r'\S+\s+(\w+(?:\s*\([^)]*\))?)', col_def)
        col_type = type_match.group(1) if type_match else parts[1]
        output_columns.append(ColumnSchema(
            name=parts[0],
            data_type=col_type,
            nullable=True,
            ordinal=len(output_columns)
        ))

    return output_columns
1345
+
1346
+
1347
+
1348
def _extract_basic_select_columns(self, select_sql: str) -> List[ColumnSchema]:
    """Basic regex extraction of output column names from a SELECT.

    Only the text between the first ``SELECT`` and the first following
    ``FROM`` is inspected. Items are split on commas outside parentheses.
    For each item the name is, in order: the text after the last ``AS``
    (any letter case); otherwise the last whitespace-separated token, unless
    the item contains one of SUM/COUNT/MAX/MIN/AVG/CAST/CASE; otherwise the
    part after the last dot with non-word characters removed.

    Args:
        select_sql: SELECT statement text.

    Returns:
        One ``ColumnSchema`` per named item with ``data_type="varchar"``,
        ``nullable=True`` and the item's position as ordinal. Empty when no
        ``SELECT ... FROM`` span is found.
    """
    output_columns = []

    match = re.search(r'SELECT\s+(.*?)\s+FROM', select_sql, re.IGNORECASE | re.DOTALL)
    if not match:
        return output_columns

    # Split on top-level commas so "COALESCE(a, b) AS x" stays one item.
    items = []
    current = []
    depth = 0
    for ch in match.group(1):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth = max(depth - 1, 0)
        if ch == ',' and depth == 0:
            items.append(''.join(current))
            current = []
        else:
            current.append(ch)
    items.append(''.join(current))

    expression_markers = ('SUM', 'COUNT', 'MAX', 'MIN', 'AVG', 'CAST', 'CASE')
    for i, raw_item in enumerate(items):
        col = raw_item.strip()
        # Split case-insensitively so "a as b" yields "b", like "a AS b".
        alias_parts = re.split(r'\s+AS\s+', col, flags=re.IGNORECASE)
        if len(alias_parts) > 1:
            col_name = alias_parts[-1].strip()
        elif ' ' in col and not any(marker in col.upper() for marker in expression_markers):
            col_name = col.split()[-1]  # Last part is usually the alias
        else:
            # Base column name: drop any qualifier and non-word characters.
            col_name = col.split('.')[-1] if '.' in col else col
            col_name = re.sub(r'[^\w]', '', col_name)

        if col_name:
            output_columns.append(ColumnSchema(
                name=col_name,
                data_type="varchar",  # Default type
                nullable=True,
                ordinal=i
            ))

    return output_columns
1379
+
1380
+ def _extract_basic_dependencies(self, sql_content: str) -> Set[str]:
1381
+ """Basic extraction of table dependencies from SQL."""
1382
+ dependencies = set()
1383
+
1384
+ # Find FROM and JOIN clauses
1385
+ from_matches = re.findall(r'FROM\s+([^\s\(]+)', sql_content, re.IGNORECASE)
1386
+ join_matches = re.findall(r'JOIN\s+([^\s\(]+)', sql_content, re.IGNORECASE)
1387
+
1388
+ for match in from_matches + join_matches:
1389
+ table_name = match.strip()
1390
+ # Clean up table name (remove aliases, schema qualifiers for dependency tracking)
1391
+ if ' ' in table_name:
1392
+ table_name = table_name.split()[0]
1393
+ dependencies.add(table_name.lower())
1394
+
1395
+ return dependencies
1396
+
1397
+ def _is_table_valued_function(self, statement: exp.Create) -> bool:
1398
+ """Check if this is a table-valued function (returns TABLE)."""
1399
+ # Simple heuristic: check if the function has RETURNS TABLE
1400
+ sql_text = str(statement).upper()
1401
+ return "RETURNS TABLE" in sql_text or "RETURNS @" in sql_text
1402
+
1403
+ def _extract_tvf_lineage(self, statement: exp.Create, function_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema], Set[str]]:
1404
+ """Extract lineage from a table-valued function."""
1405
+ lineage = []
1406
+ output_columns = []
1407
+ dependencies = set()
1408
+
1409
+ sql_text = str(statement)
1410
+
1411
+ # Handle inline TVF (RETURN AS SELECT)
1412
+ if "RETURN AS" in sql_text.upper() or "RETURN(" in sql_text.upper():
1413
+ # Find the SELECT statement in the RETURN clause
1414
+ select_stmt = self._extract_select_from_return(statement)
1415
+ if select_stmt:
1416
+ lineage, output_columns = self._extract_column_lineage(select_stmt, function_name)
1417
+ dependencies = self._extract_dependencies(select_stmt)
1418
+
1419
+ # Handle multi-statement TVF (RETURN @table TABLE)
1420
+ elif "RETURNS @" in sql_text.upper():
1421
+ # Extract the table variable definition and find INSERT statements
1422
+ output_columns = self._extract_table_variable_schema(statement)
1423
+ lineage, dependencies = self._extract_mstvf_lineage(statement, function_name, output_columns)
1424
+
1425
+ return lineage, output_columns, dependencies
1426
+
1427
+ def _extract_procedure_lineage(self, statement: exp.Create, procedure_name: str) -> tuple[List[ColumnLineage], List[ColumnSchema], Set[str]]:
1428
+ """Extract lineage from a procedure that returns a dataset."""
1429
+ lineage = []
1430
+ output_columns = []
1431
+ dependencies = set()
1432
+
1433
+ # Find the last SELECT statement in the procedure body
1434
+ last_select = self._find_last_select_in_procedure(statement)
1435
+ if last_select:
1436
+ lineage, output_columns = self._extract_column_lineage(last_select, procedure_name)
1437
+ dependencies = self._extract_dependencies(last_select)
1438
+
1439
+ return lineage, output_columns, dependencies
1440
+
1441
+ def _extract_select_from_return(self, statement: exp.Create) -> Optional[exp.Select]:
1442
+ """Extract SELECT statement from RETURN AS clause."""
1443
+ # This is a simplified implementation - in practice would need more robust parsing
1444
+ try:
1445
+ sql_text = str(statement)
1446
+ return_as_match = re.search(r'RETURN\s*\(\s*(SELECT.*?)\s*\)', sql_text, re.IGNORECASE | re.DOTALL)
1447
+ if return_as_match:
1448
+ select_sql = return_as_match.group(1)
1449
+ parsed = sqlglot.parse(select_sql, read=self.dialect)
1450
+ if parsed and isinstance(parsed[0], exp.Select):
1451
+ return parsed[0]
1452
+ except Exception:
1453
+ pass
1454
+ return None
1455
+
1456
+ def _extract_table_variable_schema(self, statement: exp.Create) -> List[ColumnSchema]:
1457
+ """Extract column schema from @table TABLE definition."""
1458
+ # Simplified implementation - would need more robust parsing for production
1459
+ output_columns = []
1460
+ sql_text = str(statement)
1461
+
1462
+ # Look for @Result TABLE (col1 type1, col2 type2, ...)
1463
+ table_def_match = re.search(r'@\w+\s+TABLE\s*\((.*?)\)', sql_text, re.IGNORECASE | re.DOTALL)
1464
+ if table_def_match:
1465
+ columns_def = table_def_match.group(1)
1466
+ # Parse column definitions
1467
+ for i, col_def in enumerate(columns_def.split(',')):
1468
+ col_parts = col_def.strip().split()
1469
+ if len(col_parts) >= 2:
1470
+ col_name = col_parts[0].strip()
1471
+ col_type = col_parts[1].strip()
1472
+ output_columns.append(ColumnSchema(
1473
+ name=col_name,
1474
+ data_type=col_type,
1475
+ nullable=True,
1476
+ ordinal=i
1477
+ ))
1478
+
1479
+ return output_columns
1480
+
1481
+ def _extract_mstvf_lineage(self, statement: exp.Create, function_name: str, output_columns: List[ColumnSchema]) -> tuple[List[ColumnLineage], Set[str]]:
1482
+ """Extract lineage from multi-statement table-valued function."""
1483
+ lineage = []
1484
+ dependencies = set()
1485
+
1486
+ # Find INSERT statements into the @table variable
1487
+ sql_text = str(statement)
1488
+ insert_matches = re.finditer(r'INSERT\s+INTO\s+@\w+.*?SELECT(.*?)(?:FROM|$)', sql_text, re.IGNORECASE | re.DOTALL)
1489
+
1490
+ for match in insert_matches:
1491
+ try:
1492
+ select_part = "SELECT" + match.group(1)
1493
+ parsed = sqlglot.parse(select_part, read=self.dialect)
1494
+ if parsed and isinstance(parsed[0], exp.Select):
1495
+ select_stmt = parsed[0]
1496
+ stmt_lineage, _ = self._extract_column_lineage(select_stmt, function_name)
1497
+ lineage.extend(stmt_lineage)
1498
+ dependencies.update(self._extract_dependencies(select_stmt))
1499
+ except Exception:
1500
+ continue
1501
+
1502
+ return lineage, dependencies
1503
+
1504
+ def _find_last_select_in_procedure(self, statement: exp.Create) -> Optional[exp.Select]:
1505
+ """Find the last SELECT statement in a procedure body."""
1506
+ sql_text = str(statement)
1507
+
1508
+ # Find all SELECT statements that are not part of INSERT/UPDATE/DELETE
1509
+ select_matches = list(re.finditer(r'(?<!INSERT\s)(?<!UPDATE\s)(?<!DELETE\s)SELECT\s+.*?(?=(?:FROM|$))', sql_text, re.IGNORECASE | re.DOTALL))
1510
+
1511
+ if select_matches:
1512
+ # Get the last SELECT statement
1513
+ last_match = select_matches[-1]
1514
+ try:
1515
+ select_sql = last_match.group(0)
1516
+ # Find the FROM clause and complete SELECT
1517
+ from_match = re.search(r'FROM.*?(?=(?:WHERE|GROUP|ORDER|HAVING|;|$))', sql_text[last_match.end():], re.IGNORECASE | re.DOTALL)
1518
+ if from_match:
1519
+ select_sql += from_match.group(0)
1520
+
1521
+ parsed = sqlglot.parse(select_sql, read=self.dialect)
1522
+ if parsed and isinstance(parsed[0], exp.Select):
1523
+ return parsed[0]
1524
+ except Exception:
1525
+ pass
1526
+
1527
+ return None
1528
+
1529
+ def _extract_column_alias(self, select_expr: exp.Expression) -> Optional[str]:
1530
+ """Extract column alias from a SELECT expression."""
1531
+ if hasattr(select_expr, 'alias') and select_expr.alias:
1532
+ return str(select_expr.alias)
1533
+ elif isinstance(select_expr, exp.Alias):
1534
+ return str(select_expr.alias)
1535
+ elif isinstance(select_expr, exp.Column):
1536
+ return str(select_expr.this)
1537
+ else:
1538
+ # Try to extract from the expression itself
1539
+ expr_str = str(select_expr)
1540
+ if ' AS ' in expr_str.upper():
1541
+ parts = expr_str.split()
1542
+ as_idx = -1
1543
+ for i, part in enumerate(parts):
1544
+ if part.upper() == 'AS':
1545
+ as_idx = i
1546
+ break
1547
+ if as_idx >= 0 and as_idx + 1 < len(parts):
1548
+ return parts[as_idx + 1].strip("'\"")
1549
+ return None
1550
+
1551
+ def _extract_column_references(self, select_expr: exp.Expression, select_stmt: exp.Select) -> List[ColumnReference]:
1552
+ """Extract column references from a SELECT expression."""
1553
+ refs = []
1554
+
1555
+ # Find all column references in the expression
1556
+ for column_expr in select_expr.find_all(exp.Column):
1557
+ table_name = "unknown"
1558
+ column_name = str(column_expr.this)
1559
+
1560
+ # Try to resolve table name from table reference or alias
1561
+ if hasattr(column_expr, 'table') and column_expr.table:
1562
+ table_alias = str(column_expr.table)
1563
+ table_name = self._resolve_table_from_alias(table_alias, select_stmt)
1564
+ else:
1565
+ # If no table specified, try to infer from FROM clause
1566
+ tables = []
1567
+ for table in select_stmt.find_all(exp.Table):
1568
+ tables.append(self._get_table_name(table))
1569
+ if len(tables) == 1:
1570
+ table_name = tables[0]
1571
+
1572
+ if table_name != "unknown":
1573
+ refs.append(ColumnReference(
1574
+ namespace="mssql://localhost/InfoTrackerDW",
1575
+ table_name=table_name,
1576
+ column_name=column_name
1577
+ ))
1578
+
1579
+ return refs