cfa-kernel 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98) hide show
  1. cfa/__init__.py +39 -0
  2. cfa/_lazy.py +39 -0
  3. cfa/adapters/__init__.py +104 -0
  4. cfa/adapters/autogen.py +19 -0
  5. cfa/adapters/crewai.py +19 -0
  6. cfa/adapters/dspy.py +19 -0
  7. cfa/adapters/langgraph.py +19 -0
  8. cfa/adapters/openai_agents.py +19 -0
  9. cfa/audit/__init__.py +15 -0
  10. cfa/audit/context.py +205 -0
  11. cfa/audit/hashing.py +41 -0
  12. cfa/audit/trail.py +194 -0
  13. cfa/backends/__init__.py +132 -0
  14. cfa/backends/dbt.py +338 -0
  15. cfa/backends/pyspark.py +240 -0
  16. cfa/backends/sql.py +270 -0
  17. cfa/behavior/__init__.py +49 -0
  18. cfa/behavior/llm.py +244 -0
  19. cfa/behavior/spec.py +235 -0
  20. cfa/behavior/systematizer.py +222 -0
  21. cfa/cli/__init__.py +296 -0
  22. cfa/cli/__main__.py +6 -0
  23. cfa/cli/_helpers.py +109 -0
  24. cfa/cli/core/__init__.py +0 -0
  25. cfa/cli/core/evaluate.py +72 -0
  26. cfa/cli/core/validate.py +29 -0
  27. cfa/cli/formatters.py +280 -0
  28. cfa/cli/governance/__init__.py +0 -0
  29. cfa/cli/governance/audit.py +65 -0
  30. cfa/cli/governance/catalog.py +28 -0
  31. cfa/cli/governance/policy.py +119 -0
  32. cfa/cli/governance/rules.py +42 -0
  33. cfa/cli/governance/signature.py +31 -0
  34. cfa/cli/infrastructure/__init__.py +0 -0
  35. cfa/cli/infrastructure/backend_list.py +24 -0
  36. cfa/cli/infrastructure/storage.py +87 -0
  37. cfa/cli/project/__init__.py +0 -0
  38. cfa/cli/project/init.py +73 -0
  39. cfa/cli/project/lifecycle.py +92 -0
  40. cfa/cli/project/status.py +75 -0
  41. cfa/cli/project/taxonomy.py +38 -0
  42. cfa/cli/reporting/__init__.py +0 -0
  43. cfa/cli/reporting/report.py +109 -0
  44. cfa/cli/reporting/serve.py +43 -0
  45. cfa/config.py +103 -0
  46. cfa/core/__init__.py +19 -0
  47. cfa/core/codegen.py +65 -0
  48. cfa/core/conditions.py +129 -0
  49. cfa/core/kernel.py +224 -0
  50. cfa/core/phases/__init__.py +0 -0
  51. cfa/core/phases/runner.py +477 -0
  52. cfa/core/planner.py +290 -0
  53. cfa/execution/__init__.py +12 -0
  54. cfa/execution/partial.py +339 -0
  55. cfa/execution/state_projection.py +216 -0
  56. cfa/governance/__init__.py +76 -0
  57. cfa/lifecycle/__init__.py +51 -0
  58. cfa/mcp/__init__.py +347 -0
  59. cfa/mcp/__main__.py +4 -0
  60. cfa/normalizer/__init__.py +15 -0
  61. cfa/normalizer/base.py +441 -0
  62. cfa/normalizer/llm.py +426 -0
  63. cfa/observability/__init__.py +14 -0
  64. cfa/observability/indices.py +177 -0
  65. cfa/observability/metrics.py +91 -0
  66. cfa/observability/notify.py +79 -0
  67. cfa/observability/otel.py +81 -0
  68. cfa/observability/promotion.py +367 -0
  69. cfa/policy/__init__.py +12 -0
  70. cfa/policy/bundle.py +317 -0
  71. cfa/policy/catalog.py +117 -0
  72. cfa/policy/engine.py +306 -0
  73. cfa/reporting/__init__.py +42 -0
  74. cfa/reporting/charts.py +223 -0
  75. cfa/reporting/engine.py +456 -0
  76. cfa/resolution/__init__.py +62 -0
  77. cfa/runtime/__init__.py +13 -0
  78. cfa/runtime/gate.py +287 -0
  79. cfa/sandbox/__init__.py +189 -0
  80. cfa/sandbox/executor.py +92 -0
  81. cfa/sandbox/mock.py +89 -0
  82. cfa/sandbox/panic.py +52 -0
  83. cfa/storage/__init__.py +591 -0
  84. cfa/testing/__init__.py +60 -0
  85. cfa/testing/asserts.py +77 -0
  86. cfa/testing/evaluate.py +168 -0
  87. cfa/testing/fixtures.py +89 -0
  88. cfa/testing/markers.py +36 -0
  89. cfa/types.py +489 -0
  90. cfa/validation/__init__.py +14 -0
  91. cfa/validation/runtime.py +285 -0
  92. cfa/validation/signature.py +146 -0
  93. cfa/validation/static.py +252 -0
  94. cfa_kernel-0.1.0.dist-info/METADATA +32 -0
  95. cfa_kernel-0.1.0.dist-info/RECORD +98 -0
  96. cfa_kernel-0.1.0.dist-info/WHEEL +4 -0
  97. cfa_kernel-0.1.0.dist-info/entry_points.txt +3 -0
  98. cfa_kernel-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,285 @@
1
+ """
2
+ CFA Runtime Validation
3
+ ======================
4
+ Post-execution behavioral validation.
5
+
6
+ Validates sandbox execution results against:
7
+ - Cardinality bounds (rows output within expected range)
8
+ - Cost ceiling (actual DBU vs declared maximum)
9
+ - Schema contract (output columns match expectation)
10
+ - Null ratio thresholds (per-column null limits)
11
+ - Shuffle size limits (data movement budget)
12
+
13
+ Produces Runtime Faults (FaultFamily.RUNTIME) — never throws exceptions.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Any
20
+
21
+ from cfa.sandbox import ExecutionMetrics, SandboxOutcome, SandboxResult
22
+ from cfa.types import (
23
+ Fault,
24
+ FaultFamily,
25
+ FaultSeverity,
26
+ PolicyAction,
27
+ StateSignature,
28
+ )
29
+
30
+ # ── Thresholds ──────────────────────────────────────────────────────────────
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class RuntimeThresholds:
35
+ """Configurable thresholds for runtime validation."""
36
+
37
+ max_null_ratio: float = 0.10 # 10% null per column
38
+ max_shuffle_mb: float = 500.0 # 500 MB shuffle budget
39
+ min_rows: int = 0 # minimum expected rows (0 = no check)
40
+ max_rows: int = 0 # maximum expected rows (0 = no check)
41
+ max_cost_dbu: float | None = None # cost ceiling (overridden by signature)
42
+ required_output_columns: tuple[str, ...] = ()
43
+ forbidden_output_columns: tuple[str, ...] = ()
44
+
45
+
46
+ # ── Runtime Validation Result ───────────────────────────────────────────────
47
+
48
+
49
+ @dataclass
50
+ class RuntimeValidationResult:
51
+ """Result of runtime behavioral validation."""
52
+
53
+ passed: bool = True
54
+ faults: list[Fault] = field(default_factory=list)
55
+ checks_performed: int = 0
56
+ metrics_snapshot: dict[str, Any] = field(default_factory=dict)
57
+
58
+ @property
59
+ def fault_codes(self) -> list[str]:
60
+ return [f.code for f in self.faults]
61
+
62
+ def add_fault(self, fault: Fault) -> None:
63
+ self.faults.append(fault)
64
+ if fault.severity in (FaultSeverity.HIGH, FaultSeverity.CRITICAL):
65
+ self.passed = False
66
+
67
+
68
+ # ── Runtime Validator ───────────────────────────────────────────────────────
69
+
70
+
71
+ class RuntimeValidator:
72
+ """
73
+ Validates SandboxResult metrics against thresholds and signature constraints.
74
+
75
+ Called after sandbox execution completes (or partially completes).
76
+ Produces RUNTIME faults that feed back into the decision engine.
77
+ """
78
+
79
+ def __init__(self, thresholds: RuntimeThresholds | None = None) -> None:
80
+ self.thresholds = thresholds or RuntimeThresholds()
81
+
82
+ def validate(
83
+ self,
84
+ sandbox_result: SandboxResult,
85
+ signature: StateSignature,
86
+ schema_contract: dict[str, Any] | None = None,
87
+ ) -> RuntimeValidationResult:
88
+ result = RuntimeValidationResult()
89
+ metrics = sandbox_result.aggregate_metrics
90
+
91
+ # Snapshot metrics for audit
92
+ result.metrics_snapshot = {
93
+ "rows_output": metrics.rows_output,
94
+ "shuffle_mb": metrics.shuffle_mb,
95
+ "cost_dbu": metrics.cost_dbu,
96
+ "duration_seconds": metrics.duration_seconds,
97
+ "null_counts": dict(metrics.null_counts),
98
+ "output_schema": list(metrics.output_schema),
99
+ }
100
+
101
+ # Skip validation if sandbox panicked (environmental fault already captured)
102
+ if sandbox_result.outcome == SandboxOutcome.PANIC:
103
+ result.checks_performed = 0
104
+ return result
105
+
106
+ # ── Check 1: Cardinality ────────────────────────────────────────
107
+ self._check_cardinality(metrics, result)
108
+
109
+ # ── Check 2: Cost ceiling ───────────────────────────────────────
110
+ self._check_cost(metrics, signature, result)
111
+
112
+ # ── Check 3: Null ratio ─────────────────────────────────────────
113
+ self._check_null_ratio(metrics, result)
114
+
115
+ # ── Check 4: Shuffle budget ─────────────────────────────────────
116
+ self._check_shuffle(metrics, result)
117
+
118
+ # ── Check 5: Schema contract ────────────────────────────────────
119
+ self._check_schema(metrics, schema_contract, result)
120
+
121
+ # ── Check 6: Output columns from signature ──────────────────────
122
+ self._check_output_columns(metrics, result)
123
+
124
+ return result
125
+
126
+ def _check_cardinality(self, metrics: ExecutionMetrics, result: RuntimeValidationResult) -> None:
127
+ result.checks_performed += 1
128
+
129
+ if self.thresholds.min_rows > 0 and metrics.rows_output < self.thresholds.min_rows:
130
+ result.add_fault(Fault(
131
+ code="RUNTIME_CARDINALITY_BELOW_MINIMUM",
132
+ family=FaultFamily.RUNTIME,
133
+ severity=FaultSeverity.HIGH,
134
+ stage="runtime_validation",
135
+ message=(
136
+ f"Output rows ({metrics.rows_output}) below minimum "
137
+ f"threshold ({self.thresholds.min_rows})."
138
+ ),
139
+ mandatory_action=PolicyAction.BLOCK,
140
+ detected_before_execution=False,
141
+ ))
142
+
143
+ if self.thresholds.max_rows > 0 and metrics.rows_output > self.thresholds.max_rows:
144
+ result.add_fault(Fault(
145
+ code="RUNTIME_CARDINALITY_ABOVE_MAXIMUM",
146
+ family=FaultFamily.RUNTIME,
147
+ severity=FaultSeverity.HIGH,
148
+ stage="runtime_validation",
149
+ message=(
150
+ f"Output rows ({metrics.rows_output}) above maximum "
151
+ f"threshold ({self.thresholds.max_rows})."
152
+ ),
153
+ mandatory_action=PolicyAction.BLOCK,
154
+ detected_before_execution=False,
155
+ ))
156
+
157
+ def _check_cost(
158
+ self, metrics: ExecutionMetrics, signature: StateSignature, result: RuntimeValidationResult
159
+ ) -> None:
160
+ result.checks_performed += 1
161
+
162
+ # Signature ceiling takes precedence over threshold default
163
+ ceiling = signature.constraints.max_cost_dbu or self.thresholds.max_cost_dbu
164
+ if ceiling is not None and metrics.cost_dbu > ceiling:
165
+ result.add_fault(Fault(
166
+ code="RUNTIME_COST_CEILING_EXCEEDED",
167
+ family=FaultFamily.RUNTIME,
168
+ severity=FaultSeverity.CRITICAL,
169
+ stage="runtime_validation",
170
+ message=(
171
+ f"Execution cost ({metrics.cost_dbu:.2f} DBU) exceeds "
172
+ f"ceiling ({ceiling:.2f} DBU)."
173
+ ),
174
+ mandatory_action=PolicyAction.BLOCK,
175
+ detected_before_execution=False,
176
+ ))
177
+
178
+ def _check_null_ratio(self, metrics: ExecutionMetrics, result: RuntimeValidationResult) -> None:
179
+ result.checks_performed += 1
180
+
181
+ if metrics.rows_output == 0:
182
+ return
183
+
184
+ for column, count in metrics.null_counts.items():
185
+ ratio = count / metrics.rows_output
186
+ if ratio > self.thresholds.max_null_ratio:
187
+ result.add_fault(Fault(
188
+ code=f"RUNTIME_NULL_RATIO_EXCEEDED_{column.upper()}",
189
+ family=FaultFamily.RUNTIME,
190
+ severity=FaultSeverity.WARNING,
191
+ stage="runtime_validation",
192
+ message=(
193
+ f"Column '{column}' null ratio ({ratio:.2%}) exceeds "
194
+ f"threshold ({self.thresholds.max_null_ratio:.2%})."
195
+ ),
196
+ mandatory_action=PolicyAction.APPROVE,
197
+ detected_before_execution=False,
198
+ ))
199
+
200
+ def _check_shuffle(self, metrics: ExecutionMetrics, result: RuntimeValidationResult) -> None:
201
+ result.checks_performed += 1
202
+
203
+ if metrics.shuffle_mb > self.thresholds.max_shuffle_mb:
204
+ result.add_fault(Fault(
205
+ code="RUNTIME_SHUFFLE_BUDGET_EXCEEDED",
206
+ family=FaultFamily.RUNTIME,
207
+ severity=FaultSeverity.HIGH,
208
+ stage="runtime_validation",
209
+ message=(
210
+ f"Shuffle size ({metrics.shuffle_mb:.1f} MB) exceeds "
211
+ f"budget ({self.thresholds.max_shuffle_mb:.1f} MB)."
212
+ ),
213
+ mandatory_action=PolicyAction.BLOCK,
214
+ detected_before_execution=False,
215
+ ))
216
+
217
+ def _check_schema(
218
+ self,
219
+ metrics: ExecutionMetrics,
220
+ schema_contract: dict[str, Any] | None,
221
+ result: RuntimeValidationResult,
222
+ ) -> None:
223
+ if not schema_contract:
224
+ return
225
+
226
+ result.checks_performed += 1
227
+
228
+ required = set(schema_contract.get("required_columns", []))
229
+ forbidden = set(schema_contract.get("forbidden_columns", []))
230
+ actual = set(metrics.output_schema)
231
+
232
+ missing = required - actual
233
+ if missing:
234
+ result.add_fault(Fault(
235
+ code="RUNTIME_SCHEMA_MISSING_COLUMNS",
236
+ family=FaultFamily.RUNTIME,
237
+ severity=FaultSeverity.HIGH,
238
+ stage="runtime_validation",
239
+ message=f"Output missing required columns: {sorted(missing)}.",
240
+ mandatory_action=PolicyAction.BLOCK,
241
+ detected_before_execution=False,
242
+ ))
243
+
244
+ leaked = forbidden & actual
245
+ if leaked:
246
+ result.add_fault(Fault(
247
+ code="RUNTIME_SCHEMA_FORBIDDEN_COLUMNS",
248
+ family=FaultFamily.RUNTIME,
249
+ severity=FaultSeverity.CRITICAL,
250
+ stage="runtime_validation",
251
+ message=f"Output contains forbidden columns: {sorted(leaked)}.",
252
+ mandatory_action=PolicyAction.BLOCK,
253
+ detected_before_execution=False,
254
+ ))
255
+
256
+ def _check_output_columns(self, metrics: ExecutionMetrics, result: RuntimeValidationResult) -> None:
257
+ if not self.thresholds.required_output_columns and not self.thresholds.forbidden_output_columns:
258
+ return
259
+
260
+ result.checks_performed += 1
261
+ actual = set(metrics.output_schema)
262
+
263
+ missing = set(self.thresholds.required_output_columns) - actual
264
+ if missing:
265
+ result.add_fault(Fault(
266
+ code="RUNTIME_MISSING_REQUIRED_OUTPUT_COLUMNS",
267
+ family=FaultFamily.RUNTIME,
268
+ severity=FaultSeverity.HIGH,
269
+ stage="runtime_validation",
270
+ message=f"Output missing required columns: {sorted(missing)}.",
271
+ mandatory_action=PolicyAction.BLOCK,
272
+ detected_before_execution=False,
273
+ ))
274
+
275
+ leaked = set(self.thresholds.forbidden_output_columns) & actual
276
+ if leaked:
277
+ result.add_fault(Fault(
278
+ code="RUNTIME_FORBIDDEN_OUTPUT_COLUMNS",
279
+ family=FaultFamily.RUNTIME,
280
+ severity=FaultSeverity.CRITICAL,
281
+ stage="runtime_validation",
282
+ message=f"Output contains forbidden columns: {sorted(leaked)}.",
283
+ mandatory_action=PolicyAction.BLOCK,
284
+ detected_before_execution=False,
285
+ ))
@@ -0,0 +1,146 @@
1
+ """StateSignature validation utilities.
2
+
3
+ Validation is intentionally separate from ``StateSignature.from_dict`` so legacy
4
+ internal callers can keep deserializing permissively while CLI/API boundaries can
5
+ enforce a strict contract for external systems.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+ from cfa.types import DatasetClassification, TargetLayer
14
+
15
# Accepted string values for ``target_layer`` and ``datasets[*].classification``,
# derived from the enum members so they stay in sync with cfa.types.
_LAYERS = {member.value for member in TargetLayer}
_CLASSIFICATIONS = {member.value for member in DatasetClassification}
17
+
18
+
19
@dataclass(frozen=True)
class SignatureValidationIssue:
    """A single problem found while validating a signature payload.

    ``path`` locates the offending field in dotted/indexed notation
    (e.g. ``datasets[0].name``); ``message`` says what is wrong with it.
    """

    path: str
    message: str
23
+
24
+
25
@dataclass(frozen=True)
class SignatureValidationResult:
    """Outcome of validating a signature payload.

    ``valid`` summarizes the outcome; ``issues`` lists every problem found.
    """

    valid: bool
    issues: list[SignatureValidationIssue] = field(default_factory=list)

    @property
    def messages(self) -> list[str]:
        """Human-readable ``"<path>: <message>"`` line for each issue."""
        return list(map("{0.path}: {0.message}".format, self.issues))
33
+
34
+
35
def unwrap_signature_data(data: dict[str, Any]) -> dict[str, Any]:
    """Return the signature object nested under a common API/CLI wrapper key.

    ``{"signature": {...}}`` and ``{"state_signature": {...}}`` are unwrapped.
    The ``signature`` key is consulted first and, when present, shadows
    ``state_signature`` even if its value is not a dict. Anything that does not
    unwrap to a dict (including non-dict input) is returned unchanged.
    """
    if not isinstance(data, dict):
        return data
    inner = data["signature"] if "signature" in data else data.get("state_signature")
    if isinstance(inner, dict):
        return inner
    return data
42
+
43
+
44
def validate_signature_data(
    data: dict[str, Any] | None,
    *,
    require_datasets: bool = False,
) -> SignatureValidationResult:
    """Validate a raw (deserialized) StateSignature payload against the strict contract.

    The payload may be wrapped as ``{"signature": {...}}`` or
    ``{"state_signature": {...}}``; it is unwrapped before validation.

    Args:
        data: Parsed payload (presumably from a CLI file or API body).
            ``None`` is reported as an empty file.
        require_datasets: When ``True``, ``datasets`` must be present and
            non-empty.

    Returns:
        A :class:`SignatureValidationResult` whose ``valid`` flag is ``True``
        only if no issues were found. Malformed values of any JSON shape are
        reported as issues rather than raised.
    """
    if data is None:
        return SignatureValidationResult(
            valid=False,
            issues=[SignatureValidationIssue("signature", "file is empty")],
        )
    if not isinstance(data, dict):
        return SignatureValidationResult(
            valid=False,
            issues=[SignatureValidationIssue("signature", "must be an object")],
        )

    sig = unwrap_signature_data(data)
    if not isinstance(sig, dict):
        return SignatureValidationResult(
            valid=False,
            issues=[SignatureValidationIssue("signature", "must be an object")],
        )

    issues: list[SignatureValidationIssue] = []

    _require_non_empty_string(sig, "domain", issues)
    _require_non_empty_string(sig, "intent", issues)

    if not _is_allowed_value(sig.get("target_layer"), _LAYERS):
        issues.append(SignatureValidationIssue("target_layer", f"must be one of {sorted(_LAYERS)}"))

    _validate_datasets(sig.get("datasets"), issues, require_datasets=require_datasets)
    _validate_constraints(sig.get("constraints", {}), issues)
    _validate_execution_context(sig.get("execution_context"), issues)

    return SignatureValidationResult(valid=not issues, issues=issues)


def _is_allowed_value(value: Any, allowed: set[Any]) -> bool:
    """Membership test that treats unhashable values (JSON arrays/objects) as not allowed."""
    try:
        return value in allowed
    except TypeError:
        # ``x in set`` hashes x; a list/dict from external input would otherwise crash
        # the validator instead of producing an issue.
        return False


def _is_non_negative_number(value: Any) -> bool:
    """True for non-bool ints/floats that are >= 0; NaN fails because it compares false."""
    return isinstance(value, (int, float)) and not isinstance(value, bool) and value >= 0


def _has_only_non_empty_strings(items: list[Any]) -> bool:
    """True when every element is a string containing non-whitespace characters."""
    return all(isinstance(item, str) and item.strip() for item in items)


def _validate_datasets(
    datasets: Any,
    issues: list[SignatureValidationIssue],
    *,
    require_datasets: bool,
) -> None:
    """Validate the optional ``datasets`` list and each dataset entry in it."""
    if datasets is None:
        if require_datasets:
            issues.append(SignatureValidationIssue("datasets", "is required"))
        return
    if not isinstance(datasets, list):
        issues.append(SignatureValidationIssue("datasets", "must be a list"))
        return
    if require_datasets and not datasets:
        issues.append(SignatureValidationIssue("datasets", "must contain at least one dataset"))
        return

    for idx, dataset in enumerate(datasets):
        base = f"datasets[{idx}]"
        if not isinstance(dataset, dict):
            issues.append(SignatureValidationIssue(base, "must be an object"))
            continue
        _require_non_empty_string(dataset, f"{base}.name", issues, key="name")
        if not _is_allowed_value(dataset.get("classification", "internal"), _CLASSIFICATIONS):
            issues.append(SignatureValidationIssue(f"{base}.classification", f"must be one of {sorted(_CLASSIFICATIONS)}"))
        if not _is_non_negative_number(dataset.get("size_gb", 0.0)):
            issues.append(SignatureValidationIssue(f"{base}.size_gb", "must be a non-negative number"))
        pii_columns = dataset.get("pii_columns", [])
        if not isinstance(pii_columns, list):
            issues.append(SignatureValidationIssue(f"{base}.pii_columns", "must be a list of strings"))
        elif not _has_only_non_empty_strings(pii_columns):
            issues.append(SignatureValidationIssue(f"{base}.pii_columns", "must contain only non-empty strings"))
        partition_column = dataset.get("partition_column")
        if partition_column is not None and not isinstance(partition_column, str):
            issues.append(SignatureValidationIssue(f"{base}.partition_column", "must be a string or null"))


def _validate_constraints(constraints: Any, issues: list[SignatureValidationIssue]) -> None:
    """Validate the ``constraints`` object (the caller treats an absent key as ``{}``)."""
    if not isinstance(constraints, dict):
        issues.append(SignatureValidationIssue("constraints", "must be an object"))
        return

    for key in ("no_pii_raw", "merge_key_required", "enforce_types"):
        if key in constraints and not isinstance(constraints[key], bool):
            issues.append(SignatureValidationIssue(f"constraints.{key}", "must be a boolean"))

    partition_by = constraints.get("partition_by", [])
    if not isinstance(partition_by, list):
        issues.append(SignatureValidationIssue("constraints.partition_by", "must be a list of strings"))
    elif not _has_only_non_empty_strings(partition_by):
        issues.append(SignatureValidationIssue("constraints.partition_by", "must contain only non-empty strings"))

    max_cost = constraints.get("max_cost_dbu")
    if max_cost is not None and not _is_non_negative_number(max_cost):
        issues.append(SignatureValidationIssue("constraints.max_cost_dbu", "must be a non-negative number or null"))

    custom = constraints.get("custom", {})
    if custom is not None and not isinstance(custom, dict):
        issues.append(SignatureValidationIssue("constraints.custom", "must be an object"))


def _validate_execution_context(ctx: Any, issues: list[SignatureValidationIssue]) -> None:
    """Validate the mandatory ``execution_context`` object and its three version identifiers."""
    if not isinstance(ctx, dict):
        issues.append(SignatureValidationIssue("execution_context", "is required and must be an object"))
        return
    for key in ("policy_bundle_version", "catalog_snapshot_version", "context_registry_version_id"):
        _require_non_empty_string(ctx, f"execution_context.{key}", issues, key=key)
134
+
135
+
136
+ def _require_non_empty_string(
137
+ data: dict[str, Any],
138
+ path: str,
139
+ issues: list[SignatureValidationIssue],
140
+ *,
141
+ key: str | None = None,
142
+ ) -> None:
143
+ lookup = key or path
144
+ value = data.get(lookup)
145
+ if not isinstance(value, str) or not value.strip():
146
+ issues.append(SignatureValidationIssue(path, "is required and must be a non-empty string"))
@@ -0,0 +1,252 @@
1
+ """
2
+ CFA Static Validation
3
+ =====================
4
+ Analyzes generated code BEFORE execution.
5
+ Detects violations that can be identified without running the job.
6
+
7
+ Belongs to the Static Safety Faults family (Invariant I6).
8
+
9
+ Checks:
10
+ 1. Forbidden tokens (backend-declared)
11
+ 2. Required patterns (filter(), merge()) based on Signature constraints
12
+ 3. Schema contract (expected columns, forbidden columns)
13
+
14
+ Forbidden tokens are declared by each backend via ``BackendCapabilities``.
15
+ New backends automatically bring their own validation rules — no central registry.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import re
21
+ from dataclasses import dataclass, field
22
+ from typing import Any
23
+
24
+ from cfa.core.codegen import GeneratedCode
25
+ from cfa.types import (
26
+ Fault,
27
+ FaultFamily,
28
+ FaultSeverity,
29
+ PolicyAction,
30
+ StateSignature,
31
+ )
32
+
33
+ # ── Validation Result ────────────────────────────────────────────────────────
34
+
35
+
36
@dataclass
class StaticValidationResult:
    """Outcome of static (pre-execution) analysis of generated code.

    As populated by the static validator, ``faults`` holds the blocking
    findings and ``warnings`` the non-blocking ones.
    """

    passed: bool
    faults: list[Fault] = field(default_factory=list)
    warnings: list[Fault] = field(default_factory=list)
    checks_performed: int = 0

    @property
    def fault_codes(self) -> list[str]:
        """Codes of the blocking faults, in recorded order."""
        return [blocker.code for blocker in self.faults]

    @property
    def is_blocked(self) -> bool:
        """Convenience inverse of ``passed``."""
        return not self.passed
52
+
53
+
54
+ # ── Validation Rules ─────────────────────────────────────────────────────────
55
+
56
+
57
@dataclass(frozen=True)
class ForbiddenToken:
    """A pattern that must not appear in generated code.

    Backends declare these in ``BackendCapabilities.forbidden_tokens``.
    """

    # Literal substring, or a regular expression when ``is_regex`` is set.
    pattern: str
    # Code of the fault emitted when the pattern is found.
    fault_code: str
    severity: FaultSeverity
    # Human-readable fault message.
    message: str
    is_regex: bool = False
69
+
70
+
71
@dataclass(frozen=True)
class RequiredPattern:
    """A pattern that generated code is expected to contain under some condition."""

    # Literal substring, or a regular expression when ``is_regex`` is set.
    pattern: str
    # Code of the fault emitted when the pattern is missing.
    fault_code: str
    # Human-readable fault message.
    message: str
    # Free-text description of when this pattern is required.
    condition_description: str
    is_regex: bool = False
78
+
79
+
80
+ # ── Static Validator ─────────────────────────────────────────────────────────
81
+
82
+
83
class StaticValidator:
    """Analyzes generated code before execution.

    Forbidden tokens come from the backend that generated the code,
    queried via ``backend.get_capabilities().forbidden_tokens``.

    If no backend is provided, a minimal set of common-sense defaults
    is used (``import os``, ``import subprocess``). Tokens passed explicitly
    to the constructor take precedence over both.
    """

    # Fallback token set; a tuple so the shared class-level default is immutable.
    _MINIMAL_FORBIDDEN: tuple[ForbiddenToken, ...] = (
        ForbiddenToken("import os", "STATIC_FORBIDDEN_IMPORT_OS",
                       FaultSeverity.CRITICAL, "os module import forbidden."),
        ForbiddenToken("import subprocess", "STATIC_FORBIDDEN_IMPORT_SUBPROCESS",
                       FaultSeverity.CRITICAL, "subprocess module import forbidden."),
    )

    def __init__(
        self,
        forbidden_tokens: list[ForbiddenToken] | None = None,
    ) -> None:
        """Create a validator.

        Args:
            forbidden_tokens: Explicit token list. When not ``None`` it is used
                as-is (an empty list disables token checks) instead of the
                backend's tokens or the minimal defaults.
        """
        self._explicit_tokens = forbidden_tokens

    def validate(
        self,
        code: GeneratedCode,
        signature: StateSignature,
        schema_contract: dict[str, Any] | None = None,
        *,
        backend: Any | None = None,
    ) -> StaticValidationResult:
        """Run all static checks on ``code`` and split findings by severity.

        Args:
            code: Generated code; ``code.code`` is the source text and
                ``code.language`` selects SQL-aware pattern matching.
            signature: Supplies PII columns and partition/merge constraints.
            schema_contract: Optional mapping with ``forbidden_columns``;
                the schema check is skipped when falsy.
            backend: Optional object exposing ``get_capabilities()`` whose
                ``forbidden_tokens`` are used when no explicit tokens were given.

        Returns:
            A :class:`StaticValidationResult`: CRITICAL/HIGH faults go to
            ``faults`` and make ``passed`` false; every other fault goes to
            ``warnings``.
        """
        faults: list[Fault] = []
        checks = 0

        # 1. Forbidden tokens (from backend or explicit override or minimal defaults)
        checks += 1
        tokens = self._resolve_tokens(code, backend)
        faults.extend(self._check_forbidden_tokens(code.code, tokens))

        # 2. Raw PII references
        checks += 1
        faults.extend(self._check_pii_references(code.code, signature))

        # 3. Required patterns
        checks += 1
        faults.extend(self._check_required_patterns(code.code, signature, language=code.language))

        # 4. Schema contract
        if schema_contract:
            checks += 1
            faults.extend(self._check_schema_contract(code.code, schema_contract))

        blocking_levels = (FaultSeverity.CRITICAL, FaultSeverity.HIGH)
        blocking = [f for f in faults if f.severity in blocking_levels]
        # Partition on "not blocking" instead of an explicit WARNING/INFO allow-list so
        # that a fault with any other severity level is reported, never silently dropped.
        non_blocking = [f for f in faults if f.severity not in blocking_levels]

        return StaticValidationResult(
            passed=len(blocking) == 0,
            faults=blocking,
            warnings=non_blocking,
            checks_performed=checks,
        )

    def _resolve_tokens(
        self, code: GeneratedCode, backend: Any | None
    ) -> list[ForbiddenToken]:
        """Pick the token set: explicit > backend capabilities > minimal defaults.

        Always returns a fresh list so callers cannot mutate the configured tokens.
        ``code`` is currently unused.
        """
        if self._explicit_tokens is not None:
            return list(self._explicit_tokens)
        if backend is not None and hasattr(backend, "get_capabilities"):
            return list(backend.get_capabilities().forbidden_tokens)
        return list(self._MINIMAL_FORBIDDEN)

    # ── Private checks ────────────────────────────────────────────────────

    def _check_forbidden_tokens(
        self, src: str, tokens: list[ForbiddenToken]
    ) -> list[Fault]:
        """Emit one fault per token found in ``src``.

        Regex tokens are searched case-insensitively; literal tokens are
        case-sensitive substring matches.
        """
        faults: list[Fault] = []
        for token in tokens:
            if token.is_regex:
                found = bool(re.search(token.pattern, src, re.IGNORECASE))
            else:
                found = token.pattern in src
            if found:
                faults.append(Fault(
                    code=token.fault_code,
                    family=FaultFamily.STATIC,
                    severity=token.severity,
                    stage="static_validation",
                    message=token.message,
                    mandatory_action=PolicyAction.BLOCK,
                ))
        return faults

    def _check_pii_references(self, src: str, signature: StateSignature) -> list[Fault]:
        """Flag ``F.col(<pii column>)`` references not near an anonymization call.

        Runs only when ``signature.constraints.no_pii_raw`` is set. Both single-
        and double-quoted column literals are matched. A reference is tolerated
        when ``sha2(`` or ``.drop(`` appears within 50 characters before or 30
        characters after it.

        NOTE(review): heuristic only — other access forms such as ``df["col"]``,
        ``col("col")`` or SQL column names are not detected here.
        """
        faults: list[Fault] = []
        if not signature.constraints.no_pii_raw:
            return faults

        for ds in signature.datasets:
            for col in ds.pii_columns:
                # Group 1 captures the opening quote; ``\1`` requires the same closing quote.
                pattern = r"F\.col\(\s*(['\"])" + re.escape(col) + r"\1\s*\)"
                for match in re.finditer(pattern, src):
                    context_start = max(0, match.start() - 50)
                    context = src[context_start:match.end() + 30]
                    if "sha2(" in context or ".drop(" in context:
                        continue
                    faults.append(Fault(
                        code=f"STATIC_RAW_PII_REFERENCE_{col.upper()}",
                        family=FaultFamily.STATIC,
                        severity=FaultSeverity.CRITICAL,
                        stage="static_validation",
                        message=f"Raw PII column '{col}' referenced without anonymization context.",
                        mandatory_action=PolicyAction.BLOCK,
                    ))
        return faults

    def _check_required_patterns(self, src: str, signature: StateSignature, *, language: str = "") -> list[Fault]:
        """Require a partition filter and/or a merge write when the signature demands them.

        For ``language == "sql"`` a ``WHERE`` / ``MERGE INTO`` clause also counts.
        """
        faults: list[Fault] = []
        is_sql = language == "sql"

        # Partition filter
        if signature.constraints.partition_by:
            has_filter = (
                ".filter(" in src or ".where(" in src
                or (is_sql and bool(re.search(r"\bWHERE\b", src, re.IGNORECASE)))
            )
            if not has_filter:
                faults.append(Fault(
                    code="STATIC_MISSING_PARTITION_FILTER",
                    family=FaultFamily.STATIC,
                    severity=FaultSeverity.HIGH,
                    stage="static_validation",
                    message="Partition filter required but not found in code.",
                    mandatory_action=PolicyAction.BLOCK,
                    remediation=("Add a filter/WHERE clause on the temporal column.",),
                ))

        # Merge operation for Silver/Gold
        if signature.writes_to_protected_layer and signature.constraints.merge_key_required:
            lowered = src.lower()
            has_merge = (
                "merge(" in lowered
                or "mergebuilder" in lowered
                or "DeltaTable" in src
                or (is_sql and bool(re.search(r"\bMERGE\s+INTO\b", src, re.IGNORECASE)))
            )
            if not has_merge:
                faults.append(Fault(
                    code="STATIC_MISSING_MERGE_OPERATION",
                    family=FaultFamily.STATIC,
                    severity=FaultSeverity.HIGH,
                    stage="static_validation",
                    message="Merge operation required for Silver/Gold write but not found.",
                    mandatory_action=PolicyAction.BLOCK,
                    remediation=("Use MERGE INTO or DeltaTable.merge() instead of append.",),
                ))

        return faults

    def _check_schema_contract(self, src: str, contract: dict[str, Any]) -> list[Fault]:
        """Flag any ``forbidden_columns`` entry appearing as a whole word in ``src``."""
        faults: list[Fault] = []
        forbidden_cols = contract.get("forbidden_columns", [])
        for col in forbidden_cols:
            if re.search(rf'\b{re.escape(col)}\b', src):
                faults.append(Fault(
                    code=f"STATIC_FORBIDDEN_COLUMN_{col.upper()}",
                    family=FaultFamily.STATIC,
                    severity=FaultSeverity.CRITICAL,
                    stage="static_validation",
                    message=f"Forbidden column '{col}' detected in output.",
                    mandatory_action=PolicyAction.BLOCK,
                ))
        return faults