data-contract-validator 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_contract_validator/__init__.py +24 -0
- data_contract_validator/cli.py +672 -0
- data_contract_validator/core/__init__.py +0 -0
- data_contract_validator/core/models.py +115 -0
- data_contract_validator/core/validator.py +187 -0
- data_contract_validator/extractors/__init__.py +14 -0
- data_contract_validator/extractors/base.py +45 -0
- data_contract_validator/extractors/dbt.py +213 -0
- data_contract_validator/extractors/fastapi.py +200 -0
- data_contract_validator/integrations/__init__.py +0 -0
- data_contract_validator/py.typed +2 -0
- data_contract_validator/templates/github-actions-template.yml +75 -0
- data_contract_validator-1.0.0.dist-info/METADATA +344 -0
- data_contract_validator-1.0.0.dist-info/RECORD +18 -0
- data_contract_validator-1.0.0.dist-info/WHEEL +5 -0
- data_contract_validator-1.0.0.dist-info/entry_points.txt +3 -0
- data_contract_validator-1.0.0.dist-info/licenses/LICENSE +21 -0
- data_contract_validator-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core data models for validation results and schema definitions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import Dict, List, Optional, Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class IssueSeverity(Enum):
    """Closed set of severity levels that a validation issue can carry."""

    # Will break API
    CRITICAL = "critical"
    # Might cause issues
    WARNING = "warning"
    # Good to know
    INFO = "info"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class ValidationIssue:
    """One finding produced while checking a table or column."""

    severity: IssueSeverity
    table: str
    column: Optional[str]
    message: str
    category: str = "Unknown"
    suggested_fix: Optional[str] = None
    source_value: Optional[str] = None
    target_value: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain mapping of this issue for JSON serialization.

        The severity is emitted as its string value; every other field is
        copied through unchanged.
        """
        payload: Dict[str, Any] = {"severity": self.severity.value}
        plain_fields = (
            "table",
            "column",
            "message",
            "category",
            "suggested_fix",
            "source_value",
            "target_value",
        )
        for field_name in plain_fields:
            payload[field_name] = getattr(self, field_name)
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ValidationIssue":
        """Build an issue from a mapping shaped like the output of ``to_dict``.

        Missing keys fall back to: severity ``"warning"``, table and category
        ``"Unknown"``, an empty message, and ``None`` for everything else.
        """
        severity = IssueSeverity(data.get("severity", "warning"))
        nullable = {
            key: data.get(key)
            for key in ("column", "suggested_fix", "source_value", "target_value")
        }
        return cls(
            severity=severity,
            table=data.get("table", "Unknown"),
            message=data.get("message", ""),
            category=data.get("category", "Unknown"),
            **nullable,
        )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class Schema:
    """A named table together with its column definitions."""

    name: str
    columns: List[Dict[str, Any]]
    source: str = "unknown"
    metadata: Optional[Dict[str, Any]] = None

    def get_column(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the first column whose ``"name"`` equals *name*, else ``None``."""
        candidates = (column for column in self.columns if column.get("name") == name)
        return next(candidates, None)

    def column_names(self) -> List[str]:
        """Return the column names in order, skipping missing or empty names."""
        names = []
        for column in self.columns:
            column_name = column.get("name")
            if column_name:
                names.append(column_name)
        return names
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class ValidationResult:
    """Outcome of one contract validation run."""

    success: bool
    issues: List[ValidationIssue]
    source_schemas: Dict[str, Schema]
    target_schemas: Dict[str, Schema]
    summary: Optional[str] = None

    def _with_severity(self, level: IssueSeverity) -> List[ValidationIssue]:
        """Return the issues whose severity equals *level*, in original order."""
        return [issue for issue in self.issues if issue.severity == level]

    @property
    def critical_issues(self) -> List[ValidationIssue]:
        """Issues with CRITICAL severity."""
        return self._with_severity(IssueSeverity.CRITICAL)

    @property
    def warnings(self) -> List[ValidationIssue]:
        """Issues with WARNING severity."""
        return self._with_severity(IssueSeverity.WARNING)

    @property
    def info_items(self) -> List[ValidationIssue]:
        """Issues with INFO severity."""
        return self._with_severity(IssueSeverity.INFO)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain mapping with the status, per-severity counts and issues."""
        serialized_issues = [issue.to_dict() for issue in self.issues]
        critical_count = len(self.critical_issues)
        warning_count = len(self.warnings)
        info_count = len(self.info_items)
        return {
            "success": self.success,
            "summary": self.summary,
            "total_issues": len(self.issues),
            "critical_issues": critical_count,
            "warnings": warning_count,
            "info_items": info_count,
            "issues": serialized_issues,
        }
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core validation logic for comparing schemas.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Dict, List
|
|
6
|
+
from .models import ValidationResult, ValidationIssue, IssueSeverity, Schema
|
|
7
|
+
from ..extractors.base import BaseExtractor
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ContractValidator:
    """
    Main contract validator that compares schemas from different sources.

    Every table the target extractor reports is looked up by name in the
    source extractor's output. A missing table or a missing required column
    is CRITICAL; a missing optional column or a type mismatch is a WARNING.
    """

    # Type names inside one family are treated as interchangeable, in either
    # direction. Membership is symmetric on purpose: the result must not
    # depend on which side happens to be the source.
    _TYPE_FAMILIES = (
        frozenset({"varchar", "string", "str", "text"}),
        frozenset({"integer", "int", "bigint"}),
        frozenset({"float", "double", "decimal", "numeric", "real"}),
        frozenset({"boolean", "bool"}),
        frozenset({"timestamp", "datetime"}),
    )

    def __init__(
        self, source_extractor: "BaseExtractor", target_extractor: "BaseExtractor"
    ):
        """
        Initialize validator with source and target extractors.

        Args:
            source_extractor: Extractor for source schemas (e.g., DBT)
            target_extractor: Extractor for target schemas (e.g., FastAPI)
        """
        self.source_extractor = source_extractor
        self.target_extractor = target_extractor
        self.issues: "List[ValidationIssue]" = []

    def validate(self) -> "ValidationResult":
        """
        Run validation and return results.

        Progress is printed to stdout. ``self.issues`` is reset on every call.

        Returns:
            ValidationResult with success status and any issues found;
            ``success`` is True when no CRITICAL issue was recorded.
        """
        print("🔍 Starting contract validation...")

        # Extract schemas
        print("📊 Extracting source schemas...")
        source_schemas = self.source_extractor.extract_schemas()

        print("🎯 Extracting target schemas...")
        target_schemas = self.target_extractor.extract_schemas()

        print(f" Source: {len(source_schemas)} schemas")
        print(f" Target: {len(target_schemas)} schemas")

        # Reset issues so repeated calls do not accumulate findings
        self.issues = []

        # Validate each target schema against source
        print("🔍 Validating schema compatibility...")
        for table_name, target_schema in target_schemas.items():
            self._validate_table(table_name, target_schema, source_schemas)

        success = not any(
            issue.severity == IssueSeverity.CRITICAL for issue in self.issues
        )
        summary = self._generate_summary(success, self.issues)

        return ValidationResult(
            success=success,
            issues=self.issues,
            source_schemas=source_schemas,
            target_schemas=target_schemas,
            summary=summary,
        )

    @staticmethod
    def _named_columns(schema: "Schema") -> dict:
        """Map column name -> column dict, skipping columns without a name."""
        return {col["name"]: col for col in schema.columns if col.get("name")}

    def _validate_table(
        self,
        table_name: str,
        target_schema: "Schema",
        source_schemas: "Dict[str, Schema]",
    ):
        """Validate a single target table and append findings to ``self.issues``."""
        print(f" 🔍 Validating table: {table_name}")

        # Check if source provides this table
        source_schema = source_schemas.get(table_name)
        if source_schema is None:
            self.issues.append(
                ValidationIssue(
                    severity=IssueSeverity.CRITICAL,
                    table=table_name,
                    column=None,
                    message=f"Target expects table '{table_name}' but source doesn't provide it",
                    category="Missing Table",
                    suggested_fix=f"Create a source model that outputs table '{table_name}'",
                )
            )
            print(f" ❌ Table '{table_name}' missing in source")
            return

        # Nameless column entries are ignored rather than raising KeyError,
        # matching how Schema.column_names() treats them.
        source_columns = self._named_columns(source_schema)
        target_columns = self._named_columns(target_schema)

        for col_name, target_col in target_columns.items():
            source_col = source_columns.get(col_name)

            if source_col is None:
                # A column is considered required unless the target says otherwise
                is_required = target_col.get("required", True)
                severity = (
                    IssueSeverity.CRITICAL if is_required else IssueSeverity.WARNING
                )
                self.issues.append(
                    ValidationIssue(
                        severity=severity,
                        table=table_name,
                        column=col_name,
                        message=f"Target {'REQUIRES' if is_required else 'expects'} column '{col_name}' but source doesn't provide it",
                        category="Missing Column",
                        suggested_fix=f"Add column '{col_name}' to source model for table '{table_name}'",
                    )
                )
                continue

            # Check type compatibility
            source_type = source_col.get("type")
            target_type = target_col.get("type")
            if not self._types_compatible(source_type, target_type):
                self.issues.append(
                    ValidationIssue(
                        severity=IssueSeverity.WARNING,
                        table=table_name,
                        column=col_name,
                        message=f"Type mismatch: source provides '{source_type}' but target expects '{target_type}'",
                        category="Type Mismatch",
                        source_value=source_type,
                        target_value=target_type,
                        suggested_fix=f"Update target model to expect '{source_type}' or fix source column type",
                    )
                )

        # Log results for this table
        table_issues = [i for i in self.issues if i.table == table_name]
        if not table_issues:
            print(" ✅ All requirements satisfied")
            return

        critical = [i for i in table_issues if i.severity == IssueSeverity.CRITICAL]
        warnings = [i for i in table_issues if i.severity == IssueSeverity.WARNING]
        if critical:
            print(f" 🚨 {len(critical)} critical issues")
        if warnings:
            print(f" ⚠️ {len(warnings)} warnings")

    @staticmethod
    def _normalize_type(type_name: str) -> str:
        """Lower-case a type name and drop any ``(...)`` parameters.

        For example ``"VARCHAR(255)"`` becomes ``"varchar"``.
        """
        return type_name.lower().partition("(")[0].strip()

    def _types_compatible(self, source_type: str, target_type: str) -> bool:
        """Check if source and target types are compatible.

        Args:
            source_type: Type name reported by the source; may be None or empty.
            target_type: Type name expected by the target; may be None or empty.

        Returns:
            True when either type is unknown, when the normalized names are
            equal, or when both belong to the same type family.
        """
        if not source_type or not target_type:
            return True  # Skip validation if types are unknown

        source_base = self._normalize_type(source_type)
        target_base = self._normalize_type(target_type)

        if source_base == target_base:
            return True

        return any(
            source_base in family and target_base in family
            for family in self._TYPE_FAMILIES
        )

    def _generate_summary(self, success: bool, issues: "List[ValidationIssue]") -> str:
        """Generate the one-line validation summary for *issues*."""
        if success:
            return f"✅ Validation passed with {len(issues)} non-critical issues"

        critical = [i for i in issues if i.severity == IssueSeverity.CRITICAL]
        return f"❌ Validation failed with {len(critical)} critical issues"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# data_contract_validator/extractors/__init__.py
|
|
2
|
+
"""
|
|
3
|
+
Schema extractors for different frameworks.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from .base import BaseExtractor
|
|
7
|
+
from .dbt import DBTExtractor
|
|
8
|
+
from .fastapi import FastAPIExtractor
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"BaseExtractor",
|
|
12
|
+
"DBTExtractor",
|
|
13
|
+
"FastAPIExtractor",
|
|
14
|
+
]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base extractor interface for schema extraction.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Dict, Any
|
|
7
|
+
from ..core.models import Schema
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseExtractor(ABC):
    """Base class for all schema extractors."""

    @abstractmethod
    def extract_schemas(self) -> "Dict[str, Schema]":
        """
        Extract schemas from the source.

        Returns:
            Dict mapping table names to Schema objects
        """

    def _python_to_sql_type(self, python_type: str) -> str:
        """Convert a Python type hint, given as text, to a SQL type name.

        Matching is case-insensitive. A leading ``typing.`` qualifier and any
        number of ``Optional[...]`` wrappers are removed, and a parameterised
        container such as ``List[str]`` is looked up by its base name.

        Args:
            python_type: Textual type hint, e.g. ``"int"`` or
                ``"Optional[List[str]]"``.

        Returns:
            The mapped SQL type, or ``"varchar"`` when the hint is not
            recognised.
        """
        type_mappings = {
            "str": "varchar",
            "int": "integer",
            "float": "float",
            "bool": "boolean",
            "datetime": "timestamp",
            "date": "date",
            "list": "json",
            "dict": "json",
        }

        normalized = python_type.strip().lower()

        # Peel wrappers one at a time. Each pass strictly shortens the
        # string, so a name that merely contains "optional" (for example
        # "OptionalSettings") cannot cause unbounded recursion.
        while True:
            if normalized.startswith("typing."):
                normalized = normalized[len("typing."):]
            if not (normalized.startswith("optional[") and normalized.endswith("]")):
                break
            # Remove only the wrapper's own closing bracket so that nested
            # generics such as list[str] stay intact.
            normalized = normalized[len("optional["):-1].strip()

        # "list[str]" -> "list", "dict[str, int]" -> "dict"
        base_name = normalized.partition("[")[0].strip()

        return type_mappings.get(base_name, "varchar")

    def _normalize_column_name(self, name: str) -> str:
        """Normalize column names for comparison (lower-case, trimmed)."""
        return name.lower().strip()
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
# data_contract_validator/extractors/dbt.py
|
|
2
|
+
"""
|
|
3
|
+
DBT schema extractor: reads model schemas from manifest.json, falling back to parsing the model SQL files.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import subprocess
|
|
8
|
+
import re
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Dict, List, Optional, Any
|
|
11
|
+
|
|
12
|
+
from .base import BaseExtractor
|
|
13
|
+
from ..core.models import Schema
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class DBTExtractor(BaseExtractor):
    """Extract schemas from DBT projects.

    ``dbt parse`` is attempted first; when it succeeds and
    ``target/manifest.json`` exists, the schemas are read from the manifest.
    Otherwise the ``*.sql`` files under ``models/`` are parsed heuristically.
    """

    # Substrings that mark non-model SQL. They are matched against the path
    # relative to the models directory, never the absolute path.
    _SKIPPED_PATH_MARKERS = ("tests", "analysis", "macros")

    # "select <columns> from" with word boundaries, so identifiers such as
    # "from_date" are not mistaken for the FROM keyword.
    _SELECT_PATTERN = re.compile(
        r"\bselect\s+(.*?)\s+from\b", re.DOTALL | re.IGNORECASE
    )

    # Ordered inference rules applied to the upper-cased expression.
    # Function names must be followed by "(" and keywords must not be embedded
    # in a longer alphanumeric word, so "ACCOUNT_ID" is not a COUNT call and
    # "UPDATED_AT" is not a DATE.
    _TYPE_RULES = (
        (re.compile(r"\b(?:COUNT|SUM|ROW_NUMBER)\s*\("), "integer"),
        (re.compile(r"\bAVG\s*\("), "float"),
        (re.compile(r"\b(?:CONCAT|UPPER|LOWER)\s*\("), "varchar"),
        (re.compile(r"(?<![A-Z0-9])TIMESTAMP(?![A-Z0-9])"), "timestamp"),
        (re.compile(r"(?<![A-Z0-9])DATE(?![A-Z0-9])"), "date"),
        (re.compile(r"(?<![A-Z0-9])(?:TRUE|FALSE|BOOLEAN)(?![A-Z0-9])"), "boolean"),
    )

    def __init__(self, project_path: str = "."):
        """Record the project root and the paths derived from it.

        Args:
            project_path: Root directory of the DBT project.
        """
        self.project_path = Path(project_path)
        self.target_dir = self.project_path / "target"
        self.manifest_path = self.target_dir / "manifest.json"
        self.models_path = self.project_path / "models"

    def extract_schemas(self) -> Dict[str, Schema]:
        """Extract schemas from DBT project.

        Returns:
            Dict mapping model names to Schema objects.
        """
        print(f"🔍 Extracting DBT schemas from {self.project_path}")

        # Try manifest first, fallback to SQL parsing
        if self._try_compile_dbt() and self.manifest_path.exists():
            print(" 📋 Using manifest.json")
            return self._extract_from_manifest()

        print(" 📄 Using SQL file parsing")
        return self._extract_from_sql_files()

    def _try_compile_dbt(self) -> bool:
        """Run ``dbt parse`` for the project and report whether it succeeded.

        Returns:
            True when the command exits with status 0. False when it exits
            non-zero, cannot be started, times out after 60 seconds, or its
            output cannot be decoded.
        """
        try:
            result = subprocess.run(
                ["dbt", "parse", "--project-dir", str(self.project_path)],
                capture_output=True,
                text=True,
                timeout=60,
            )
        except (OSError, ValueError, subprocess.SubprocessError):
            # Best effort: a missing or failing dbt only triggers the SQL
            # fallback. KeyboardInterrupt/SystemExit are not swallowed.
            return False
        return result.returncode == 0

    def _extract_from_manifest(self) -> Dict[str, Schema]:
        """Extract schemas from manifest.json.

        Only nodes whose ``resource_type`` is ``"model"`` are used; each is
        keyed by its ``alias`` when set, otherwise by its ``name``.
        """
        with open(self.manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)

        schemas = {}
        for node in manifest.get("nodes", {}).values():
            if node.get("resource_type") != "model":
                continue

            model_name = node.get("alias") or node.get("name")
            columns = [
                {
                    "name": col_name,
                    "type": col_info.get("data_type", "varchar"),
                    "required": True,
                    "nullable": False,
                }
                for col_name, col_info in node.get("columns", {}).items()
            ]
            schemas[model_name] = Schema(
                name=model_name, columns=columns, source="dbt_manifest"
            )

        print(f" ✅ Found {len(schemas)} tables in manifest")
        return schemas

    def _extract_from_sql_files(self) -> Dict[str, Schema]:
        """Extract schemas from SQL files directly.

        Every ``*.sql`` file below the models directory is parsed, except
        files whose path relative to that directory contains one of the
        skip markers. Files that fail to parse are reported and skipped.
        """
        schemas = {}
        sql_files = list(self.models_path.rglob("*.sql"))

        print(f" 🔍 Found {len(sql_files)} SQL files to analyze")

        for sql_file in sql_files:
            model_name = sql_file.stem

            # Skip test/analysis files. Only the part of the path inside the
            # models directory is inspected, so the location of the checkout
            # (e.g. /srv/tests/project) cannot exclude every model.
            relative_path = str(sql_file.relative_to(self.models_path))
            if any(marker in relative_path for marker in self._SKIPPED_PATH_MARKERS):
                continue

            try:
                with open(sql_file, "r", encoding="utf-8") as f:
                    sql_content = f.read()

                columns = self._extract_columns_from_sql(sql_content)
                if columns:
                    schemas[model_name] = Schema(
                        name=model_name, columns=columns, source="sql_parsing"
                    )
                    print(f" 📋 {model_name}: {len(columns)} columns")

            except Exception as e:
                print(f" ❌ Error parsing {model_name}: {e}")

        return schemas

    @staticmethod
    def _mask_nested(sql: str) -> str:
        """Return *sql* with everything nested in parentheses replaced by ``_``.

        The outermost parentheses themselves are kept. The result has the
        same length as the input, so offsets found in the mask can be used
        to slice the original text.
        """
        masked = []
        depth = 0
        for char in sql:
            if char == "(":
                masked.append(char if depth == 0 else "_")
                depth += 1
            elif char == ")":
                depth = max(depth - 1, 0)
                masked.append(char if depth == 0 else "_")
            else:
                masked.append(char if depth == 0 else "_")
        return "".join(masked)

    def _extract_columns_from_sql(self, sql_content: str) -> List[Dict[str, Any]]:
        """Extract columns from SQL content - simplified, heuristic parser.

        Comments and Jinja blocks are stripped, the final SELECT list is
        located, and each entry is turned into a column dict. ``*`` entries
        and entries without a recoverable name are dropped.

        Returns:
            List of dicts with ``name``, ``type``, ``required`` and
            ``nullable`` keys; empty when no SELECT ... FROM is found.
        """
        # Remove comments and Jinja. "--" comments are cut up to, but not
        # including, the newline, which also covers a comment on the last
        # line of a file that has no trailing newline.
        cleaned = re.sub(r"--[^\n]*", "", sql_content)
        cleaned = re.sub(r"/\*.*?\*/", "", cleaned, flags=re.DOTALL)
        cleaned = re.sub(r"\{#.*?#\}", "", cleaned, flags=re.DOTALL)
        cleaned = re.sub(r"\{%.*?%\}", "", cleaned, flags=re.DOTALL)
        cleaned = re.sub(r"\{\{.*?\}\}", "", cleaned, flags=re.DOTALL)

        # Prefer SELECTs outside any parentheses: CTE bodies and subqueries
        # are parenthesised, so the last top-level SELECT is the model's
        # output. Fall back to any SELECT when none is at top level.
        select_matches = list(
            self._SELECT_PATTERN.finditer(self._mask_nested(cleaned))
        ) or list(self._SELECT_PATTERN.finditer(cleaned))

        if not select_matches:
            return []

        # Mask and original share offsets, so slice the original text.
        final_select = select_matches[-1]
        select_content = cleaned[final_select.start(1):final_select.end(1)].strip()

        columns = []
        for col_text in self._split_columns(select_content):
            col_text = col_text.strip()
            if not col_text or col_text == "*":
                continue

            column_name = self._extract_column_name(col_text)
            if column_name:
                columns.append(
                    {
                        "name": column_name,
                        "type": self._infer_data_type(col_text),
                        "required": True,
                        "nullable": False,
                    }
                )

        return columns

    def _split_columns(self, select_clause: str) -> List[str]:
        """Split SELECT columns by comma, handling nested functions.

        Commas inside parentheses or inside single-quoted string literals do
        not split. Empty entries are dropped and each entry is stripped.
        """
        columns = []
        current = []
        paren_depth = 0
        in_literal = False

        def flush() -> None:
            text = "".join(current).strip()
            if text:
                columns.append(text)
            current.clear()

        for char in select_clause:
            if char == "'":
                # A doubled quote ('') toggles twice and so stays inside
                in_literal = not in_literal
            elif not in_literal:
                if char == "(":
                    paren_depth += 1
                elif char == ")":
                    # Clamp so a stray ")" cannot disable splitting
                    paren_depth = max(paren_depth - 1, 0)
                elif char == "," and paren_depth == 0:
                    flush()
                    continue
            current.append(char)

        flush()
        return columns

    def _extract_column_name(self, col_text: str) -> Optional[str]:
        """Extract clean column name from column definition.

        Tried in order: an ``AS alias`` suffix, a trailing ``table.column``,
        a bare identifier, and finally the last whitespace-separated token
        when it contains no ``(``.

        Returns:
            The lower-cased name, or None when none can be determined.
        """
        col_text = col_text.strip()

        # Check for AS alias
        as_match = re.search(r"\s+as\s+(\w+)$", col_text, re.IGNORECASE)
        if as_match:
            return as_match.group(1).lower()

        # Handle table.column format
        table_match = re.search(r"(\w+)\.(\w+)$", col_text)
        if table_match:
            return table_match.group(2).lower()

        # Simple column name
        simple_match = re.search(r"^(\w+)$", col_text)
        if simple_match:
            return simple_match.group(1).lower()

        # For complex expressions, try to extract an implicit alias
        parts = col_text.split()
        if len(parts) > 1 and "(" not in parts[-1]:
            return parts[-1].lower()

        return None

    def _infer_data_type(self, expression: str) -> str:
        """Infer data type from SQL expression.

        The first matching rule in ``_TYPE_RULES`` wins.

        Returns:
            The inferred SQL type name, or ``"varchar"`` when no rule matches.
        """
        expr_upper = expression.upper()

        for pattern, sql_type in self._TYPE_RULES:
            if pattern.search(expr_upper):
                return sql_type

        return "varchar"
|