data-contract-validator 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,115 @@
1
+ """
2
+ Core data models for validation results and schema definitions.
3
+ """
4
+
5
+ from enum import Enum
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Optional, Any
8
+
9
+
10
class IssueSeverity(Enum):
    """Closed set of severity levels a validation issue can carry."""

    # Each value is the lowercase string form of the level.
    CRITICAL = "critical"  # will break the API
    WARNING = "warning"  # might cause problems
    INFO = "info"  # informational only
16
+
17
+
18
@dataclass
class ValidationIssue:
    """A single finding reported for a table (and optionally one column)."""

    severity: IssueSeverity
    table: str
    column: Optional[str]
    message: str
    category: str = "Unknown"
    suggested_fix: Optional[str] = None
    source_value: Optional[str] = None
    target_value: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict; the severity is stored as its string value."""
        payload: Dict[str, Any] = {"severity": self.severity.value}
        plain_fields = (
            "table",
            "column",
            "message",
            "category",
            "suggested_fix",
            "source_value",
            "target_value",
        )
        for field_name in plain_fields:
            payload[field_name] = getattr(self, field_name)
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ValidationIssue":
        """Build an issue from a dict, substituting defaults for absent keys.

        Missing ``severity`` falls back to ``"warning"``, missing ``table`` and
        ``category`` to ``"Unknown"``, missing ``message`` to ``""``; the
        remaining keys default to ``None``.
        """
        # Keys whose fallback is simply None.
        kwargs: Dict[str, Any] = {
            key: data.get(key)
            for key in ("column", "suggested_fix", "source_value", "target_value")
        }
        kwargs["severity"] = IssueSeverity(data.get("severity", "warning"))
        kwargs["table"] = data.get("table", "Unknown")
        kwargs["message"] = data.get("message", "")
        kwargs["category"] = data.get("category", "Unknown")
        return cls(**kwargs)
57
+
58
+
59
@dataclass
class Schema:
    """A named table together with its list of column descriptor dicts."""

    name: str
    columns: List[Dict[str, Any]]
    source: str = "unknown"
    metadata: Optional[Dict[str, Any]] = None

    def get_column(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the first column whose ``"name"`` entry equals *name*, or ``None``."""
        candidates = (column for column in self.columns if column.get("name") == name)
        return next(candidates, None)

    def column_names(self) -> List[str]:
        """Return the column names in order, omitting columns with no (or an empty) name."""
        raw_names = (column.get("name") for column in self.columns)
        return [column_name for column_name in raw_names if column_name]
78
+
79
+
80
@dataclass
class ValidationResult:
    """Outcome of a contract validation run: status, issues and the schemas compared."""

    success: bool
    issues: List[ValidationIssue]
    source_schemas: Dict[str, Schema]
    target_schemas: Dict[str, Schema]
    summary: Optional[str] = None

    def _issues_with(self, level: IssueSeverity) -> List[ValidationIssue]:
        """Return the issues whose severity equals *level*, in original order."""
        return [issue for issue in self.issues if issue.severity == level]

    @property
    def critical_issues(self) -> List[ValidationIssue]:
        """Issues with CRITICAL severity."""
        return self._issues_with(IssueSeverity.CRITICAL)

    @property
    def warnings(self) -> List[ValidationIssue]:
        """Issues with WARNING severity."""
        return self._issues_with(IssueSeverity.WARNING)

    @property
    def info_items(self) -> List[ValidationIssue]:
        """Issues with INFO severity."""
        return self._issues_with(IssueSeverity.INFO)

    def to_dict(self) -> Dict[str, Any]:
        """Return a dict with the status, summary, per-severity counts and serialized issues."""
        report: Dict[str, Any] = {}
        report["success"] = self.success
        report["summary"] = self.summary
        report["total_issues"] = len(self.issues)
        report["critical_issues"] = len(self.critical_issues)
        report["warnings"] = len(self.warnings)
        report["info_items"] = len(self.info_items)
        report["issues"] = [entry.to_dict() for entry in self.issues]
        return report
@@ -0,0 +1,187 @@
1
+ """
2
+ Core validation logic for comparing schemas.
3
+ """
4
+
5
+ from typing import Dict, List
6
+ from .models import ValidationResult, ValidationIssue, IssueSeverity, Schema
7
+ from ..extractors.base import BaseExtractor
8
+
9
+
10
class ContractValidator:
    """
    Main contract validator that compares schemas from different sources.

    Every table the target declares is looked up in the source; missing
    tables, missing columns and incompatible column types are recorded as
    ``ValidationIssue`` objects.
    """

    # Type names inside one family are treated as interchangeable, in either
    # direction (source -> target or target -> source).
    _TYPE_FAMILIES = (
        frozenset({"varchar", "string", "str", "text"}),
        frozenset({"integer", "int", "bigint"}),
        frozenset({"float", "double", "decimal", "numeric", "real"}),
        frozenset({"boolean", "bool"}),
        frozenset({"timestamp", "datetime"}),
    )

    def __init__(
        self, source_extractor: "BaseExtractor", target_extractor: "BaseExtractor"
    ):
        """
        Initialize validator with source and target extractors.

        Args:
            source_extractor: Extractor for source schemas (e.g., DBT)
            target_extractor: Extractor for target schemas (e.g., FastAPI)
        """
        self.source_extractor = source_extractor
        self.target_extractor = target_extractor
        self.issues: List["ValidationIssue"] = []

    def validate(self) -> "ValidationResult":
        """
        Run validation and return results.

        Returns:
            ValidationResult with success status and any issues found.
            ``success`` is True when no CRITICAL issue was recorded.
        """
        print("🔍 Starting contract validation...")

        # Extract schemas
        print("📊 Extracting source schemas...")
        source_schemas = self.source_extractor.extract_schemas()

        print("🎯 Extracting target schemas...")
        target_schemas = self.target_extractor.extract_schemas()

        print(f" Source: {len(source_schemas)} schemas")
        print(f" Target: {len(target_schemas)} schemas")

        # Start from a fresh list so results of an earlier run are not mutated.
        self.issues = []

        # Validate each target schema against source
        print("🔍 Validating schema compatibility...")
        for table_name, target_schema in target_schemas.items():
            self._validate_table(table_name, target_schema, source_schemas)

        success = not any(
            issue.severity == IssueSeverity.CRITICAL for issue in self.issues
        )
        summary = self._generate_summary(success, self.issues)

        return ValidationResult(
            success=success,
            issues=self.issues,
            source_schemas=source_schemas,
            target_schemas=target_schemas,
            summary=summary,
        )

    def _validate_table(
        self,
        table_name: str,
        target_schema: "Schema",
        source_schemas: Dict[str, "Schema"],
    ):
        """
        Validate a single target table against the source schemas.

        Appends to ``self.issues``:
          * a CRITICAL issue when the source has no table of that name;
          * a CRITICAL (required) or WARNING (optional) issue per target
            column the source lacks;
          * a WARNING per column whose types are not compatible.
        """
        print(f" 🔍 Validating table: {table_name}")

        # Check if source provides this table
        source_schema = source_schemas.get(table_name)
        if source_schema is None:
            self.issues.append(
                ValidationIssue(
                    severity=IssueSeverity.CRITICAL,
                    table=table_name,
                    column=None,
                    message=f"Target expects table '{table_name}' but source doesn't provide it",
                    category="Missing Table",
                    suggested_fix=f"Create a source model that outputs table '{table_name}'",
                )
            )
            print(f" ❌ Table '{table_name}' missing in source")
            return

        # Index columns by name. Columns without a name are skipped rather
        # than raising KeyError, matching how Schema.column_names treats them.
        source_columns = {
            col["name"]: col for col in source_schema.columns if col.get("name")
        }
        target_columns = {
            col["name"]: col for col in target_schema.columns if col.get("name")
        }

        for col_name, target_col in target_columns.items():
            source_col = source_columns.get(col_name)

            if source_col is None:
                # A target column is considered required unless it says otherwise.
                is_required = target_col.get("required", True)
                severity = (
                    IssueSeverity.CRITICAL if is_required else IssueSeverity.WARNING
                )
                verb = "REQUIRES" if is_required else "expects"
                self.issues.append(
                    ValidationIssue(
                        severity=severity,
                        table=table_name,
                        column=col_name,
                        message=f"Target {verb} column '{col_name}' but source doesn't provide it",
                        category="Missing Column",
                        suggested_fix=f"Add column '{col_name}' to source model for table '{table_name}'",
                    )
                )
                continue

            # Column exists on both sides: check type compatibility
            source_type = source_col.get("type")
            target_type = target_col.get("type")
            if not self._types_compatible(source_type, target_type):
                self.issues.append(
                    ValidationIssue(
                        severity=IssueSeverity.WARNING,
                        table=table_name,
                        column=col_name,
                        message=f"Type mismatch: source provides '{source_type}' but target expects '{target_type}'",
                        category="Type Mismatch",
                        source_value=source_type,
                        target_value=target_type,
                        suggested_fix=f"Update target model to expect '{source_type}' or fix source column type",
                    )
                )

        # Log results for this table
        table_issues = [i for i in self.issues if i.table == table_name]
        if not table_issues:
            print(" ✅ All requirements satisfied")
            return

        critical = [i for i in table_issues if i.severity == IssueSeverity.CRITICAL]
        warnings = [i for i in table_issues if i.severity == IssueSeverity.WARNING]
        if critical:
            print(f" 🚨 {len(critical)} critical issues")
        if warnings:
            print(f" ⚠️ {len(warnings)} warnings")

    @staticmethod
    def _base_type(type_name: str) -> str:
        """Lower-case *type_name* and drop any ``(...)`` length/precision suffix."""
        return type_name.lower().split("(", 1)[0].strip()

    def _types_compatible(self, source_type: str, target_type: str) -> bool:
        """
        Check if source and target types are compatible.

        Comparison is case-insensitive and ignores a parenthesised
        length/precision suffix (``varchar(255)`` is treated as ``varchar``).
        Two types are compatible when they are equal or belong to the same
        family in ``_TYPE_FAMILIES``; the check is symmetric.

        Returns:
            True when compatible, or when either type is missing/empty
            (unknown types are not validated).
        """
        if not source_type or not target_type:
            return True  # Skip validation if types are unknown

        source_base = self._base_type(source_type)
        target_base = self._base_type(target_type)

        if source_base == target_base:
            return True

        return any(
            source_base in family and target_base in family
            for family in self._TYPE_FAMILIES
        )

    def _generate_summary(self, success: bool, issues: List["ValidationIssue"]) -> str:
        """Return a one-line summary: total issue count on success, critical count on failure."""
        if success:
            return f"✅ Validation passed with {len(issues)} non-critical issues"

        critical = [i for i in issues if i.severity == IssueSeverity.CRITICAL]
        return f"❌ Validation failed with {len(critical)} critical issues"
@@ -0,0 +1,14 @@
1
+ # data_contract_validator/extractors/__init__.py
2
+ """
3
+ Schema extractors for different frameworks.
4
+ """
5
+
6
+ from .base import BaseExtractor
7
+ from .dbt import DBTExtractor
8
+ from .fastapi import FastAPIExtractor
9
+
10
# Names exported by the extractors package (what ``from ... import *`` exposes).
__all__ = [
    "BaseExtractor",
    "DBTExtractor",
    "FastAPIExtractor",
]
@@ -0,0 +1,45 @@
1
+ """
2
+ Base extractor interface for schema extraction.
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Dict, Any
7
+ from ..core.models import Schema
8
+
9
+
10
class BaseExtractor(ABC):
    """Base class for all schema extractors."""

    @abstractmethod
    def extract_schemas(self) -> Dict[str, "Schema"]:
        """
        Extract schemas from the source.

        Returns:
            Dict mapping table names to Schema objects
        """

    def _python_to_sql_type(self, python_type: str) -> str:
        """
        Convert a Python type-hint string to a SQL type name.

        Matching is case-insensitive. A leading ``typing.`` qualifier and any
        number of ``Optional[...]`` wrappers are removed, and a parametrized
        generic is looked up by its base name (``List[int]`` -> ``list``).

        Args:
            python_type: Type hint as text, e.g. ``"int"`` or ``"Optional[str]"``.

        Returns:
            The mapped SQL type, or ``"varchar"`` when the type is not recognised.
        """
        type_mappings = {
            "str": "varchar",
            "int": "integer",
            "float": "float",
            "bool": "boolean",
            "datetime": "timestamp",
            "date": "date",
            "list": "json",
            "dict": "json",
        }

        normalized = python_type.strip().lower()

        # Peel wrappers one layer at a time. Each pass strictly shortens the
        # string, so the loop always terminates (a bare "Optional" simply
        # falls through to the default instead of recursing forever).
        while True:
            if normalized.startswith("typing."):
                normalized = normalized[len("typing."):]
            elif normalized.startswith("optional[") and normalized.endswith("]"):
                # Remove only the outermost bracket pair so inner generics survive.
                normalized = normalized[len("optional["):-1].strip()
            else:
                break

        # Drop generic parameters: "list[str]" -> "list", "dict[str, any]" -> "dict".
        base_name = normalized.split("[", 1)[0].strip()

        return type_mappings.get(base_name, "varchar")

    def _normalize_column_name(self, name: str) -> str:
        """Normalize a column name for comparison: lower-case, surrounding whitespace removed."""
        return name.lower().strip()
@@ -0,0 +1,213 @@
1
+ # data_contract_validator/extractors/dbt.py
2
+ """
3
+ DBT schema extractor - simplified version of your working code.
4
+ """
5
+
6
+ import json
7
+ import subprocess
8
+ import re
9
+ from pathlib import Path
10
+ from typing import Dict, List, Optional, Any
11
+
12
+ from .base import BaseExtractor
13
+ from ..core.models import Schema
14
+
15
+
16
class DBTExtractor(BaseExtractor):
    """
    Extract schemas from DBT projects.

    ``dbt parse`` is attempted first; when it succeeds and
    ``target/manifest.json`` exists the manifest is used. Otherwise the SQL
    files under ``models/`` are parsed heuristically with regular expressions.
    """

    # Folder names below ``models/`` whose SQL files are not treated as models.
    _SKIPPED_DIRS = frozenset({"tests", "analysis", "macros"})

    # Patterns used to strip non-SQL text before looking for the SELECT list.
    _JINJA_COMMENT = re.compile(r"\{#.*?#\}", re.DOTALL)
    _SQL_LINE_COMMENT = re.compile(r"--[^\n]*")
    _SQL_BLOCK_COMMENT = re.compile(r"/\*.*?\*/", re.DOTALL)
    _JINJA_TAG = re.compile(r"\{\{.*?\}\}|\{%.*?%\}", re.DOTALL)

    # Word boundaries keep identifiers such as ``from_date`` from being
    # mistaken for the FROM keyword.
    _SELECT_LIST = re.compile(
        r"\bselect\s+(.*?)\s+from\b", re.DOTALL | re.IGNORECASE
    )
    _IDENTIFIER = re.compile(r"[A-Za-z_]\w*")

    # Ordered (pattern, sql_type) rules applied to the upper-cased expression.
    # Functions must be followed by "(" and keywords must stand alone, so a
    # column merely *named* ``updated_at`` or ``account_id`` is not misread.
    _TYPE_RULES = (
        (re.compile(r"\b(?:COUNT|SUM|ROW_NUMBER)\s*\("), "integer"),
        (re.compile(r"\bAVG\s*\("), "float"),
        (re.compile(r"\b(?:CONCAT|UPPER|LOWER)\s*\("), "varchar"),
        (
            re.compile(r"\b(?:CURRENT_)?TIMESTAMP(?:TZ|_NTZ|_LTZ|_TZ)?\b"),
            "timestamp",
        ),
        (
            re.compile(r"\b(?:CURRENT_DATE|DATE)\b|\b(?:TO_DATE|DATE_TRUNC)\s*\("),
            "date",
        ),
        (re.compile(r"\b(?:TRUE|FALSE|BOOLEAN)\b"), "boolean"),
    )

    def __init__(self, project_path: str = "."):
        """
        Args:
            project_path: Root directory of the DBT project.
        """
        self.project_path = Path(project_path)
        self.target_dir = self.project_path / "target"
        self.manifest_path = self.target_dir / "manifest.json"
        self.models_path = self.project_path / "models"

    def extract_schemas(self) -> Dict[str, Schema]:
        """Extract schemas from DBT project.

        Returns:
            Dict mapping model/table names to Schema objects.
        """
        print(f"🔍 Extracting DBT schemas from {self.project_path}")

        # Try manifest first, fallback to SQL parsing
        if self._try_compile_dbt() and self.manifest_path.exists():
            print(" 📋 Using manifest.json")
            return self._extract_from_manifest()

        print(" 📄 Using SQL file parsing")
        return self._extract_from_sql_files()

    def _try_compile_dbt(self) -> bool:
        """Run ``dbt parse`` for the project; return True only on exit code 0.

        A missing ``dbt`` executable, any other OS-level launch failure, or
        exceeding the 60 second timeout yields False.
        """
        try:
            result = subprocess.run(
                ["dbt", "parse", "--project-dir", str(self.project_path)],
                capture_output=True,
                text=True,
                timeout=60,
            )
        except (OSError, subprocess.SubprocessError):
            # OSError: dbt not installed / not executable.
            # SubprocessError: includes TimeoutExpired.
            return False
        return result.returncode == 0

    def _extract_from_manifest(self) -> Dict[str, Schema]:
        """Extract schemas from manifest.json.

        Only nodes whose ``resource_type`` is ``"model"`` are used; the table
        name is the node's ``alias`` when set, otherwise its ``name``.
        """
        # JSON is read as UTF-8 rather than the platform's locale encoding.
        with open(self.manifest_path, "r", encoding="utf-8") as f:
            manifest = json.load(f)

        schemas = {}
        for node in manifest.get("nodes", {}).values():
            if node.get("resource_type") != "model":
                continue

            model_name = node.get("alias") or node.get("name")
            columns = [
                {
                    "name": col_name,
                    "type": col_info.get("data_type", "varchar"),
                    "required": True,
                    "nullable": False,
                }
                for col_name, col_info in node.get("columns", {}).items()
            ]
            schemas[model_name] = Schema(
                name=model_name, columns=columns, source="dbt_manifest"
            )

        print(f" ✅ Found {len(schemas)} tables in manifest")
        return schemas

    def _extract_from_sql_files(self) -> Dict[str, Schema]:
        """Extract schemas from SQL files directly.

        Every ``*.sql`` file below ``models/`` is read, except files located
        in a ``tests``, ``analysis`` or ``macros`` sub-folder. Files that
        cannot be read or parsed are reported and skipped.
        """
        schemas = {}
        # Sorted so that processing (and which file wins on a duplicate stem)
        # does not depend on filesystem ordering.
        sql_files = sorted(self.models_path.rglob("*.sql"))

        print(f" 🔍 Found {len(sql_files)} SQL files to analyze")

        for sql_file in sql_files:
            model_name = sql_file.stem

            # Compare whole directory names relative to models/, so neither
            # the project's own location nor a file name that merely contains
            # "tests"/"analysis"/"macros" causes a model to be dropped.
            parent_dirs = sql_file.relative_to(self.models_path).parts[:-1]
            if self._SKIPPED_DIRS.intersection(parent_dirs):
                continue

            try:
                with open(sql_file, "r", encoding="utf-8") as f:
                    sql_content = f.read()

                columns = self._extract_columns_from_sql(sql_content)
                if columns:
                    schemas[model_name] = Schema(
                        name=model_name, columns=columns, source="sql_parsing"
                    )
                    print(f" 📋 {model_name}: {len(columns)} columns")

            except Exception as e:
                # Best effort: one unreadable model must not abort the run.
                print(f" ❌ Error parsing {model_name}: {e}")

        return schemas

    def _extract_columns_from_sql(self, sql_content: str) -> List[Dict[str, Any]]:
        """Extract columns from SQL content - simplified, regex-based.

        Comments and Jinja blocks are removed, then the column list of the
        last ``SELECT ... FROM`` in the text is split and each entry turned
        into a column dict. ``*`` entries and entries without a recognisable
        name are dropped.

        Returns:
            List of column dicts (``name``, ``type``, ``required``,
            ``nullable``); empty when no SELECT list is found.
        """
        # Remove comments and Jinja
        cleaned = self._JINJA_COMMENT.sub("", sql_content)
        cleaned = self._SQL_LINE_COMMENT.sub("", cleaned)
        cleaned = self._SQL_BLOCK_COMMENT.sub("", cleaned)
        cleaned = self._JINJA_TAG.sub("", cleaned)

        select_matches = list(self._SELECT_LIST.finditer(cleaned))
        if not select_matches:
            return []

        # Use the last SELECT (after CTEs)
        select_content = select_matches[-1].group(1).strip()

        columns = []
        for col_text in self._split_columns(select_content):
            col_text = col_text.strip()
            if not col_text or col_text == "*":
                continue

            column_name = self._extract_column_name(col_text)
            if column_name:
                columns.append(
                    {
                        "name": column_name,
                        "type": self._infer_data_type(col_text),
                        "required": True,
                        "nullable": False,
                    }
                )

        return columns

    def _split_columns(self, select_clause: str) -> List[str]:
        """Split SELECT columns by comma, ignoring commas inside parentheses.

        Returns:
            The stripped, non-empty column expressions in order.
        """
        columns = []
        current_chars = []
        paren_depth = 0

        for char in select_clause:
            if char == "," and paren_depth == 0:
                # Top-level comma: finish the current column.
                piece = "".join(current_chars).strip()
                if piece:
                    columns.append(piece)
                current_chars = []
                continue

            if char == "(":
                paren_depth += 1
            elif char == ")":
                paren_depth -= 1
            current_chars.append(char)

        piece = "".join(current_chars).strip()
        if piece:
            columns.append(piece)

        return columns

    def _extract_column_name(self, col_text: str) -> Optional[str]:
        """Extract the lower-cased output column name from one SELECT entry.

        Tried in order: an ``AS alias`` suffix, a trailing ``table.column``,
        a bare identifier, and finally an implicit alias (last
        whitespace-separated token, if it is an identifier).

        Returns:
            The column name, or None when none can be determined.
        """
        col_text = col_text.strip()

        # Check for AS alias
        as_match = re.search(r"\s+as\s+(\w+)$", col_text, re.IGNORECASE)
        if as_match:
            return as_match.group(1).lower()

        # Handle table.column format
        table_match = re.search(r"(\w+)\.(\w+)$", col_text)
        if table_match:
            return table_match.group(2).lower()

        # Simple column name
        simple_match = re.search(r"^(\w+)$", col_text)
        if simple_match:
            return simple_match.group(1).lower()

        # For complex expressions, treat the trailing token as an implicit
        # alias, but only when it is an identifier (not "1", ")" and so on).
        parts = col_text.split()
        if len(parts) > 1 and self._IDENTIFIER.fullmatch(parts[-1]):
            return parts[-1].lower()

        return None

    def _infer_data_type(self, expression: str) -> str:
        """Infer a SQL data type from a SELECT expression.

        The first matching rule in ``_TYPE_RULES`` wins; only function calls
        and stand-alone keywords are considered, not substrings of
        identifiers.

        Returns:
            The inferred type, or ``"varchar"`` when no rule matches.
        """
        expr_upper = expression.upper()

        for pattern, sql_type in self._TYPE_RULES:
            if pattern.search(expr_upper):
                return sql_type

        return "varchar"