kontra 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kontra/__init__.py +1871 -0
- kontra/api/__init__.py +22 -0
- kontra/api/compare.py +340 -0
- kontra/api/decorators.py +153 -0
- kontra/api/results.py +2121 -0
- kontra/api/rules.py +681 -0
- kontra/cli/__init__.py +0 -0
- kontra/cli/commands/__init__.py +1 -0
- kontra/cli/commands/config.py +153 -0
- kontra/cli/commands/diff.py +450 -0
- kontra/cli/commands/history.py +196 -0
- kontra/cli/commands/profile.py +289 -0
- kontra/cli/commands/validate.py +468 -0
- kontra/cli/constants.py +6 -0
- kontra/cli/main.py +48 -0
- kontra/cli/renderers.py +304 -0
- kontra/cli/utils.py +28 -0
- kontra/config/__init__.py +34 -0
- kontra/config/loader.py +127 -0
- kontra/config/models.py +49 -0
- kontra/config/settings.py +797 -0
- kontra/connectors/__init__.py +0 -0
- kontra/connectors/db_utils.py +251 -0
- kontra/connectors/detection.py +323 -0
- kontra/connectors/handle.py +368 -0
- kontra/connectors/postgres.py +127 -0
- kontra/connectors/sqlserver.py +226 -0
- kontra/engine/__init__.py +0 -0
- kontra/engine/backends/duckdb_session.py +227 -0
- kontra/engine/backends/duckdb_utils.py +18 -0
- kontra/engine/backends/polars_backend.py +47 -0
- kontra/engine/engine.py +1205 -0
- kontra/engine/executors/__init__.py +15 -0
- kontra/engine/executors/base.py +50 -0
- kontra/engine/executors/database_base.py +528 -0
- kontra/engine/executors/duckdb_sql.py +607 -0
- kontra/engine/executors/postgres_sql.py +162 -0
- kontra/engine/executors/registry.py +69 -0
- kontra/engine/executors/sqlserver_sql.py +163 -0
- kontra/engine/materializers/__init__.py +14 -0
- kontra/engine/materializers/base.py +42 -0
- kontra/engine/materializers/duckdb.py +110 -0
- kontra/engine/materializers/factory.py +22 -0
- kontra/engine/materializers/polars_connector.py +131 -0
- kontra/engine/materializers/postgres.py +157 -0
- kontra/engine/materializers/registry.py +138 -0
- kontra/engine/materializers/sqlserver.py +160 -0
- kontra/engine/result.py +15 -0
- kontra/engine/sql_utils.py +611 -0
- kontra/engine/sql_validator.py +609 -0
- kontra/engine/stats.py +194 -0
- kontra/engine/types.py +138 -0
- kontra/errors.py +533 -0
- kontra/logging.py +85 -0
- kontra/preplan/__init__.py +5 -0
- kontra/preplan/planner.py +253 -0
- kontra/preplan/postgres.py +179 -0
- kontra/preplan/sqlserver.py +191 -0
- kontra/preplan/types.py +24 -0
- kontra/probes/__init__.py +20 -0
- kontra/probes/compare.py +400 -0
- kontra/probes/relationship.py +283 -0
- kontra/reporters/__init__.py +0 -0
- kontra/reporters/json_reporter.py +190 -0
- kontra/reporters/rich_reporter.py +11 -0
- kontra/rules/__init__.py +35 -0
- kontra/rules/base.py +186 -0
- kontra/rules/builtin/__init__.py +40 -0
- kontra/rules/builtin/allowed_values.py +156 -0
- kontra/rules/builtin/compare.py +188 -0
- kontra/rules/builtin/conditional_not_null.py +213 -0
- kontra/rules/builtin/conditional_range.py +310 -0
- kontra/rules/builtin/contains.py +138 -0
- kontra/rules/builtin/custom_sql_check.py +182 -0
- kontra/rules/builtin/disallowed_values.py +140 -0
- kontra/rules/builtin/dtype.py +203 -0
- kontra/rules/builtin/ends_with.py +129 -0
- kontra/rules/builtin/freshness.py +240 -0
- kontra/rules/builtin/length.py +193 -0
- kontra/rules/builtin/max_rows.py +35 -0
- kontra/rules/builtin/min_rows.py +46 -0
- kontra/rules/builtin/not_null.py +121 -0
- kontra/rules/builtin/range.py +222 -0
- kontra/rules/builtin/regex.py +143 -0
- kontra/rules/builtin/starts_with.py +129 -0
- kontra/rules/builtin/unique.py +124 -0
- kontra/rules/condition_parser.py +203 -0
- kontra/rules/execution_plan.py +455 -0
- kontra/rules/factory.py +103 -0
- kontra/rules/predicates.py +25 -0
- kontra/rules/registry.py +24 -0
- kontra/rules/static_predicates.py +120 -0
- kontra/scout/__init__.py +9 -0
- kontra/scout/backends/__init__.py +17 -0
- kontra/scout/backends/base.py +111 -0
- kontra/scout/backends/duckdb_backend.py +359 -0
- kontra/scout/backends/postgres_backend.py +519 -0
- kontra/scout/backends/sqlserver_backend.py +577 -0
- kontra/scout/dtype_mapping.py +150 -0
- kontra/scout/patterns.py +69 -0
- kontra/scout/profiler.py +801 -0
- kontra/scout/reporters/__init__.py +39 -0
- kontra/scout/reporters/json_reporter.py +165 -0
- kontra/scout/reporters/markdown_reporter.py +152 -0
- kontra/scout/reporters/rich_reporter.py +144 -0
- kontra/scout/store.py +208 -0
- kontra/scout/suggest.py +200 -0
- kontra/scout/types.py +652 -0
- kontra/state/__init__.py +29 -0
- kontra/state/backends/__init__.py +79 -0
- kontra/state/backends/base.py +348 -0
- kontra/state/backends/local.py +480 -0
- kontra/state/backends/postgres.py +1010 -0
- kontra/state/backends/s3.py +543 -0
- kontra/state/backends/sqlserver.py +969 -0
- kontra/state/fingerprint.py +166 -0
- kontra/state/types.py +1061 -0
- kontra/version.py +1 -0
- kontra-0.5.2.dist-info/METADATA +122 -0
- kontra-0.5.2.dist-info/RECORD +124 -0
- kontra-0.5.2.dist-info/WHEEL +5 -0
- kontra-0.5.2.dist-info/entry_points.txt +2 -0
- kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
- kontra-0.5.2.dist-info/top_level.txt +1 -0
kontra/api/results.py
ADDED
@@ -0,0 +1,2121 @@
# src/kontra/api/results.py
"""
Public API result types for Kontra.

These classes wrap the internal state/result types with a cleaner interface
for the public Python API.
"""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import yaml


# --- Unique rule sampling helpers (shared by multiple methods) ---


def _is_unique_rule(rule: Any) -> bool:
    """Check if a rule is a unique rule."""
    return getattr(rule, "name", None) == "unique"


def _filter_samples_polars(
    source: Any,  # pl.DataFrame or pl.LazyFrame
    rule: Any,
    predicate: Any,
    n: int,
) -> Any:  # pl.DataFrame
    """
    Filter samples with special handling for unique rule.

    Works with both DataFrame and LazyFrame sources. Adds _row_index,
    and for unique rules adds _duplicate_count sorted by worst offenders.

    Args:
        source: Polars DataFrame or LazyFrame
        rule: Rule object (used to detect unique rule)
        predicate: Polars expression for filtering
        n: Maximum rows to return

    Returns:
        Polars DataFrame with filtered samples
    """
    import polars as pl

    # Convert DataFrame to LazyFrame if needed
    if isinstance(source, pl.DataFrame):
        lf = source.lazy()
    else:
        lf = source

    # Add row index
    lf = lf.with_row_index("_row_index")

    # Special case: unique rule - add duplicate count, sort by worst offenders
    if _is_unique_rule(rule):
        column = rule.params.get("column")
        return (
            lf.with_columns(
                pl.col(column).count().over(column).alias("_duplicate_count")
            )
            .filter(predicate)
            .sort("_duplicate_count", descending=True)
            .head(n)
            .collect()
        )
    else:
        return lf.filter(predicate).head(n).collect()


def _build_unique_sample_query_sql(
    table: str,
    column: str,
    n: int,
    dialect: str,
) -> str:
    """
    Build SQL query for sampling unique rule violations.

    Returns query that finds duplicate values, orders by worst offenders,
    and includes _duplicate_count and _row_index.

    Args:
        table: Fully qualified table name
        column: Column being checked for uniqueness
        n: Maximum rows to return
        dialect: SQL dialect ("postgres", "mssql")

    Returns:
        SQL query string
    """
    col = f'"{column}"'

    if dialect == "mssql":
        return f"""
            SELECT t.*, dup._duplicate_count,
                   ROW_NUMBER() OVER (ORDER BY dup._duplicate_count DESC) - 1 AS _row_index
            FROM {table} t
            JOIN (
                SELECT {col}, COUNT(*) as _duplicate_count
                FROM {table}
                GROUP BY {col}
                HAVING COUNT(*) > 1
            ) dup ON t.{col} = dup.{col}
            ORDER BY dup._duplicate_count DESC
            OFFSET 0 ROWS FETCH FIRST {n} ROWS ONLY
        """
    else:
        return f"""
            SELECT t.*, dup._duplicate_count,
                   ROW_NUMBER() OVER (ORDER BY dup._duplicate_count DESC) - 1 AS _row_index
            FROM {table} t
            JOIN (
                SELECT {col}, COUNT(*) as _duplicate_count
                FROM {table}
                GROUP BY {col}
                HAVING COUNT(*) > 1
            ) dup ON t.{col} = dup.{col}
            ORDER BY dup._duplicate_count DESC
            LIMIT {n}
        """


# --- End unique rule helpers ---

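The helper above can be exercised directly; a minimal illustrative sketch follows (the table and column names here are hypothetical, not taken from the package):

    # Illustration only - "orders" and "user_id" are made-up identifiers.
    sql = _build_unique_sample_query_sql(table="orders", column="user_id", n=5, dialect="postgres")
    # The generated query joins the table to its duplicate groups (HAVING COUNT(*) > 1),
    # sorts by _duplicate_count descending, and limits the result to 5 rows.
    print(sql)
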
class FailureSamples:
    """
    Collection of sample rows that failed a validation rule.

    This class wraps a list of failing rows with serialization methods.
    It's iterable and indexable like a list.

    Properties:
        rule_id: The rule ID these samples are from
        count: Number of samples in this collection

    Methods:
        to_dict(): Convert to list of dicts
        to_json(): Convert to JSON string
        to_llm(): Token-optimized format for LLM context
    """

    def __init__(self, samples: List[Dict[str, Any]], rule_id: str):
        self._samples = samples
        self.rule_id = rule_id

    def __repr__(self) -> str:
        return f"FailureSamples({self.rule_id}, {len(self._samples)} rows)"

    def __len__(self) -> int:
        return len(self._samples)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return iter(self._samples)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self._samples[index]

    def __bool__(self) -> bool:
        return len(self._samples) > 0

    @property
    def count(self) -> int:
        """Number of sample rows."""
        return len(self._samples)

    def to_dict(self) -> List[Dict[str, Any]]:
        """Convert to list of dicts."""
        return self._samples

    def to_json(self, indent: Optional[int] = None) -> str:
        """Convert to JSON string."""
        return json.dumps(self._samples, indent=indent, default=str)

    def to_llm(self) -> str:
        """
        Token-optimized format for LLM context.

        Example output:
            SAMPLES: COL:email:not_null (2 rows)
            [0] row=1: id=2, email=None, status=active
            [1] row=3: id=4, email=None, status=active

        For unique rule:
            SAMPLES: COL:user_id:unique (2 rows)
            [0] row=5, dupes=3: user_id=123, name=Alice
            [1] row=8, dupes=3: user_id=123, name=Bob
        """
        if not self._samples:
            return f"SAMPLES: {self.rule_id} (0 rows)"

        lines = [f"SAMPLES: {self.rule_id} ({len(self._samples)} rows)"]

        for i, row in enumerate(self._samples[:10]):  # Limit to 10 for token efficiency
            # Extract special columns for prefix
            row_idx = row.get("_row_index")
            dup_count = row.get("_duplicate_count")

            # Build prefix with metadata
            prefix_parts = []
            if row_idx is not None:
                prefix_parts.append(f"row={row_idx}")
            if dup_count is not None:
                prefix_parts.append(f"dupes={dup_count}")
            prefix = ", ".join(prefix_parts) + ": " if prefix_parts else ""

            # Format remaining columns as compact key=value pairs
            parts = []
            for k, v in row.items():
                # Skip special columns (already in prefix)
                if k in ("_row_index", "_duplicate_count"):
                    continue
                if v is None:
                    parts.append(f"{k}=None")
                elif isinstance(v, str) and len(v) > 20:
                    parts.append(f"{k}={v[:20]}...")
                else:
                    parts.append(f"{k}={v}")
            lines.append(f"[{i}] {prefix}" + ", ".join(parts))

        if len(self._samples) > 10:
            lines.append(f"... +{len(self._samples) - 10} more rows")

        return "\n".join(lines)

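A brief usage sketch for FailureSamples as defined above (the rows are invented; in practice they come from ValidationResult.sample_failures()):

    samples = FailureSamples(
        [{"_row_index": 1, "id": 2, "email": None}, {"_row_index": 3, "id": 4, "email": None}],
        rule_id="COL:email:not_null",
    )
    assert samples.count == 2 and bool(samples)   # sized and truthy like a list
    for row in samples:                           # iterable and indexable
        print(row["_row_index"])
    print(samples.to_llm())                       # compact rendering for LLM context
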
class SampleReason:
    """Constants for why samples may be unavailable."""

    UNAVAILABLE_METADATA = "unavailable_from_metadata"  # Preplan tier - knows existence, not location
    UNAVAILABLE_PASSED = "rule_passed"  # No failures to sample
    UNAVAILABLE_UNSUPPORTED = "rule_unsupported"  # dtype, min_rows, etc. - no row-level samples
    TRUNCATED_BUDGET = "budget_exhausted"  # Global budget hit
    TRUNCATED_LIMIT = "per_rule_limit"  # Per-rule cap hit


@dataclass
class RuleResult:
    """
    Result for a single validation rule.

    Properties:
        rule_id: Unique identifier (e.g., "COL:user_id:not_null")
        name: Rule type name (e.g., "not_null")
        passed: Whether the rule passed
        failed_count: Number of failing rows
        violation_rate: Fraction of rows that failed (0.0-1.0), or None if passed
        message: Human-readable result message
        severity: "blocking" | "warning" | "info"
        source: Measurement source ("metadata", "sql", "polars")
        column: Column name if applicable
        context: Consumer-defined metadata (owner, tags, fix_hint, etc.)
        annotations: List of annotations on this rule (opt-in, loaded via get_run_with_annotations)
        severity_weight: User-defined numeric weight (None if unconfigured)

    Sampling properties:
        samples: List of sample failing rows, or None if unavailable
        samples_source: Where samples came from ("sql", "polars"), or None
        samples_reason: Why samples unavailable (see SampleReason)
        samples_truncated: True if more samples exist but were cut off
    """

    rule_id: str
    name: str
    passed: bool
    failed_count: int
    message: str
    severity: str = "blocking"
    source: str = "polars"
    column: Optional[str] = None
    details: Optional[Dict[str, Any]] = None
    context: Optional[Dict[str, Any]] = None
    failure_mode: Optional[str] = None  # Semantic failure type (e.g., "config_error", "null_values")

    # Sampling fields (eager sampling)
    samples: Optional[List[Dict[str, Any]]] = None
    samples_source: Optional[str] = None
    samples_reason: Optional[str] = None
    samples_truncated: bool = False

    # Annotations (opt-in, loaded via get_run_with_annotations)
    annotations: Optional[List[Dict[str, Any]]] = None

    # LLM juice: user-defined severity weight (None if unconfigured)
    severity_weight: Optional[float] = None

    # For violation_rate computation (populated during result creation)
    _total_rows: Optional[int] = field(default=None, repr=False)

    @property
    def violation_rate(self) -> Optional[float]:
        """
        Fraction of rows that failed this rule.

        Returns:
            Float 0.0-1.0, or None if:
            - Rule passed (failed_count == 0)
            - total_rows is 0 or unknown

        Example:
            for rule in result.rules:
                if rule.violation_rate:
                    print(f"{rule.rule_id}: {rule.violation_rate:.2%} of rows failed")
        """
        if self.passed or self.failed_count == 0:
            return None
        if self._total_rows is None or self._total_rows == 0:
            return None
        return self.failed_count / self._total_rows

    def __repr__(self) -> str:
        status = "PASS" if self.passed else "FAIL"
        if self.failed_count > 0:
            return f"RuleResult({self.rule_id}) {status} - {self.failed_count:,} failures"
        return f"RuleResult({self.rule_id}) {status}"

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "RuleResult":
        """Create from engine result dict."""
        rule_id = d.get("rule_id", "")

        # Extract column from rule_id if present
        column = None
        if rule_id.startswith("COL:"):
            parts = rule_id.split(":")
            if len(parts) >= 2:
                column = parts[1]

        # Extract rule name
        name = d.get("rule_name", d.get("name", ""))
        if not name and ":" in rule_id:
            name = rule_id.split(":")[-1]

        return cls(
            rule_id=rule_id,
            name=name,
            passed=d.get("passed", False),
            failed_count=d.get("failed_count", 0),
            message=d.get("message", ""),
            severity=d.get("severity", "blocking"),
            source=d.get("execution_source", d.get("source", "polars")),
            column=column,
            details=d.get("details"),
            context=d.get("context"),
            failure_mode=d.get("failure_mode"),
            # Sampling fields
            samples=d.get("samples"),
            samples_source=d.get("samples_source"),
            samples_reason=d.get("samples_reason"),
            samples_truncated=d.get("samples_truncated", False),
            # LLM juice
            severity_weight=d.get("severity_weight"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        d = {
            "rule_id": self.rule_id,
            "name": self.name,
            "passed": self.passed,
            "failed_count": self.failed_count,
            "message": self.message,
            "severity": self.severity,
            "source": self.source,
        }
        # Include violation_rate if available
        if self.violation_rate is not None:
            d["violation_rate"] = self.violation_rate
        if self.column:
            d["column"] = self.column
        if self.details:
            d["details"] = self.details
        if self.context:
            d["context"] = self.context
        if self.failure_mode:
            d["failure_mode"] = self.failure_mode

        # Sampling fields - always include for clarity
        d["samples"] = self.samples  # None = unavailable, [] = none found
        if self.samples_source:
            d["samples_source"] = self.samples_source
        if self.samples_reason:
            d["samples_reason"] = self.samples_reason
        if self.samples_truncated:
            d["samples_truncated"] = self.samples_truncated

        # Annotations (opt-in)
        if self.annotations is not None:
            d["annotations"] = self.annotations

        # LLM juice (only include if configured)
        if self.severity_weight is not None:
            d["severity_weight"] = self.severity_weight

        return d

    def to_llm(self) -> str:
        """Token-optimized format for LLM context."""
        status = "PASS" if self.passed else "FAIL"
        parts = [f"{self.rule_id}: {status}"]

        if self.failed_count > 0:
            parts.append(f"({self.failed_count:,} failures)")

        # Include violation_rate for failed rules (LLM juice)
        if self.violation_rate is not None:
            parts.append(f"[{self.violation_rate:.1%}]")

        # Include severity weight if configured (LLM juice)
        if self.severity_weight is not None:
            parts.append(f"[w={self.severity_weight}]")

        # Add samples if available
        if self.samples:
            parts.append(f"\n Samples ({len(self.samples)}):")
            for i, row in enumerate(self.samples[:5]):
                # Extract metadata
                row_idx = row.get("_row_index")
                dup_count = row.get("_duplicate_count")
                prefix_parts = []
                if row_idx is not None:
                    prefix_parts.append(f"row={row_idx}")
                if dup_count is not None:
                    prefix_parts.append(f"dupes={dup_count}")
                prefix = ", ".join(prefix_parts) + ": " if prefix_parts else ""

                # Format data columns
                data_parts = []
                for k, v in row.items():
                    if k in ("_row_index", "_duplicate_count"):
                        continue
                    if v is None:
                        data_parts.append(f"{k}=None")
                    elif isinstance(v, str) and len(v) > 15:
                        data_parts.append(f"{k}={v[:15]}...")
                    else:
                        data_parts.append(f"{k}={v}")
                parts.append(f" [{i}] {prefix}" + ", ".join(data_parts[:5]))
            if len(self.samples) > 5:
                parts.append(f" ... +{len(self.samples) - 5} more")
        elif self.samples_reason:
            parts.append(f"\n Samples: {self.samples_reason}")

        # Add annotations if available
        if self.annotations:
            parts.append(f"\n Annotations ({len(self.annotations)}):")
            for ann in self.annotations[:3]:
                ann_type = ann.get("annotation_type", "note")
                actor = ann.get("actor_id", "unknown")
                summary = ann.get("summary", "")[:40]
                if len(ann.get("summary", "")) > 40:
                    summary += "..."
                parts.append(f' [{ann_type}] by {actor}: "{summary}"')
            if len(self.annotations) > 3:
                parts.append(f" ... +{len(self.annotations) - 3} more")

        return " ".join(parts[:2]) + "".join(parts[2:])

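A minimal sketch of building a RuleResult from an engine-style dict and reading its derived fields (all values invented for illustration):

    rr = RuleResult.from_dict({
        "rule_id": "COL:email:not_null",
        "rule_name": "not_null",
        "passed": False,
        "failed_count": 523,
        "message": "523 null values found",
        "severity": "blocking",
    })
    rr._total_rows = 1000        # normally populated by ValidationResult.from_engine_result()
    print(rr.violation_rate)     # 0.523
    print(rr.to_llm())           # one-line FAIL summary with failure count and violation rate
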
@dataclass
class ValidationResult:
    """
    Result of a validation run.

    Properties:
        passed: True if all blocking rules passed
        dataset: Dataset name/path
        total_rows: Number of rows in the validated dataset
        total_rules: Total number of rules evaluated
        passed_count: Number of rules that passed
        failed_count: Number of blocking rules that failed
        warning_count: Number of warning rules that failed
        rules: List of RuleResult objects
        blocking_failures: List of failed blocking rules
        warnings: List of failed warning rules
        quality_score: Deterministic score 0.0-1.0 (None if weights unconfigured)
        data: The validated DataFrame (if loaded), None if preplan/pushdown handled everything
        stats: Optional statistics dict
        annotations: List of run-level annotations (opt-in, loaded via get_run_with_annotations)

    Methods:
        sample_failures(rule_id, n=5): Get sample of failing rows for a rule

    Note:
        Each RuleResult has a `violation_rate` property for per-rule failure rates.
    """

    passed: bool
    dataset: str
    total_rows: int
    total_rules: int
    passed_count: int
    failed_count: int
    warning_count: int
    rules: List[RuleResult]
    stats: Optional[Dict[str, Any]] = None
    annotations: Optional[List[Dict[str, Any]]] = None  # Run-level annotations (opt-in)
    _raw: Optional[Dict[str, Any]] = field(default=None, repr=False)
    # For sample_failures() - lazy evaluation
    _data_source: Optional[Any] = field(default=None, repr=False)
    _rule_objects: Optional[List[Any]] = field(default=None, repr=False)
    # Loaded data (if Polars execution occurred)
    _data: Optional[Any] = field(default=None, repr=False)

    @property
    def data(self) -> Optional["pl.DataFrame"]:
        """
        The validated DataFrame, if data was loaded.

        Returns the Polars DataFrame that was validated when:
        - Polars execution occurred (residual rules needed data)
        - A DataFrame was passed directly to validate()

        Returns None when:
        - All rules were resolved by preplan/pushdown (no data loaded)
        - Data source was a file path and wasn't materialized

        Example:
            result = kontra.validate("data.parquet", rules=[...], preplan="off", pushdown="off")
            if result.passed and result.data is not None:
                # Use the already-loaded data
                process(result.data)
        """
        return self._data

    def __repr__(self) -> str:
        status = "PASSED" if self.passed else "FAILED"
        parts = [f"ValidationResult({self.dataset}) {status}"]
        parts.append(f" Total: {self.total_rules} rules | Passed: {self.passed_count} | Failed: {self.failed_count}")
        if self.warning_count > 0:
            parts.append(f" Warnings: {self.warning_count}")
        if not self.passed:
            blocking = [r.rule_id for r in self.blocking_failures[:3]]
            if blocking:
                parts.append(f" Blocking: {', '.join(blocking)}")
                if len(self.blocking_failures) > 3:
                    parts.append(f" ... and {len(self.blocking_failures) - 3} more")
        return "\n".join(parts)

    @property
    def blocking_failures(self) -> List[RuleResult]:
        """Get all failed blocking rules."""
        return [r for r in self.rules if not r.passed and r.severity == "blocking"]

    @property
    def warnings(self) -> List[RuleResult]:
        """Get all failed warning rules."""
        return [r for r in self.rules if not r.passed and r.severity == "warning"]

    @property
    def quality_score(self) -> Optional[float]:
        """
        Deterministic quality score derived from violation data.

        Formula:
            quality_score = 1.0 - weighted_violation_rate
            weighted_violation_rate = Σ(failed_count * severity_weight) / (total_rows * Σ(weights))

        Returns:
            Float 0.0-1.0, or None if:
            - severity_weights not configured
            - total_rows is 0
            - No rules have weights

        Note:
            This score is pure data - Kontra never interprets it as "good" or "bad".
            Consumers/agents use it for trend reasoning.
        """
        # Check if any rule has a weight (weights configured)
        rules_with_weights = [r for r in self.rules if r.severity_weight is not None]
        if not rules_with_weights:
            return None

        # Avoid division by zero
        if self.total_rows == 0:
            return None

        # Calculate weighted violation sum
        weighted_violations = sum(
            r.failed_count * r.severity_weight
            for r in rules_with_weights
        )

        # Calculate total possible weighted violations
        # (if every row failed every rule)
        total_weight = sum(r.severity_weight for r in rules_with_weights)
        max_weighted_violations = self.total_rows * total_weight

        if max_weighted_violations == 0:
            return 1.0  # No possible violations

        # Quality = 1 - violation_rate
        violation_rate = weighted_violations / max_weighted_violations
        return max(0.0, min(1.0, 1.0 - violation_rate))

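To make the quality_score formula concrete, a small worked example (weights and counts invented): with total_rows = 1000 and three weighted rules - a blocking rule (weight 1.0) with 50 failures, a warning rule (weight 0.5) with 200 failures, and a passing blocking rule (weight 1.0) - the weighted violations are 50*1.0 + 200*0.5 = 150 against a maximum of 1000 * (1.0 + 0.5 + 1.0) = 2500, so quality_score = 1.0 - 150/2500 = 0.94.

    # Arithmetic check of the worked example above, mirroring the property's formula.
    weighted_violations = 50 * 1.0 + 200 * 0.5            # 150.0
    max_weighted_violations = 1000 * (1.0 + 0.5 + 1.0)    # 2500.0
    assert abs((1.0 - weighted_violations / max_weighted_violations) - 0.94) < 1e-9
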
    @classmethod
    def from_engine_result(
        cls,
        result: Dict[str, Any],
        dataset: str = "unknown",
        data_source: Optional[Any] = None,
        rule_objects: Optional[List[Any]] = None,
        sample: int = 5,
        sample_budget: int = 50,
        sample_columns: Optional[Union[List[str], str]] = None,
        severity_weights: Optional[Dict[str, float]] = None,
        data: Optional[Any] = None,
    ) -> "ValidationResult":
        """Create from ValidationEngine.run() result dict.

        Args:
            result: Engine result dict
            dataset: Dataset name (fallback)
            data_source: Original data source for lazy sample_failures()
            rule_objects: Rule objects for sample_failures() predicates
            sample: Per-rule sample cap (0 to disable)
            sample_budget: Global sample cap across all rules
            sample_columns: Columns to include in samples (None=all, list=specific, "relevant"=rule columns)
            severity_weights: User-defined severity weights from config (None if unconfigured)
            data: Loaded DataFrame (if Polars execution occurred)
        """
        summary = result.get("summary", {})
        results_list = result.get("results", [])

        # Convert raw results to RuleResult objects
        rules = [RuleResult.from_dict(r) for r in results_list]

        # Populate context from rule objects if available
        if rule_objects is not None:
            context_map = {
                getattr(r, "rule_id", r.name): getattr(r, "context", {})
                for r in rule_objects
            }
            for rule_result in rules:
                ctx = context_map.get(rule_result.rule_id)
                if ctx:
                    rule_result.context = ctx

        # Populate severity weights from config (LLM juice)
        if severity_weights is not None:
            for rule_result in rules:
                weight = severity_weights.get(rule_result.severity)
                if weight is not None:
                    rule_result.severity_weight = weight

        # Calculate counts
        total = summary.get("total_rules", len(rules))
        passed_count = summary.get("rules_passed", sum(1 for r in rules if r.passed))

        # Count by severity
        blocking_failed = sum(1 for r in rules if not r.passed and r.severity == "blocking")
        warning_failed = sum(1 for r in rules if not r.passed and r.severity == "warning")

        # Extract total_rows from summary
        total_rows = summary.get("total_rows", 0)

        # Populate _total_rows on each rule for violation_rate property
        for rule_result in rules:
            rule_result._total_rows = total_rows

        # Create instance first (need it for sampling methods)
        instance = cls(
            passed=summary.get("passed", blocking_failed == 0),
            dataset=summary.get("dataset_name", dataset),
            total_rows=total_rows,
            total_rules=total,
            passed_count=passed_count,
            failed_count=blocking_failed,
            warning_count=warning_failed,
            rules=rules,
            stats=result.get("stats"),
            _raw=result,
            _data_source=data_source,
            _rule_objects=rule_objects,
            _data=data,
        )

        # Perform eager sampling if enabled
        if sample > 0 and rule_objects is not None:
            instance._perform_eager_sampling(sample, sample_budget, rule_objects, sample_columns)

        return instance

    def _perform_eager_sampling(
        self,
        per_rule_cap: int,
        global_budget: int,
        rule_objects: List[Any],
        sample_columns: Optional[Union[List[str], str]] = None,
    ) -> None:
        """
        Populate samples for each rule (eager sampling).

        Uses batched SQL queries when possible (1 query for all rules).
        Falls back to per-rule Polars sampling when SQL not supported.

        Args:
            per_rule_cap: Max samples per rule
            global_budget: Total samples across all rules
            rule_objects: Rule objects for predicates
            sample_columns: Columns to include (None=all, list=specific, "relevant"=rule columns)
        """
        import polars as pl

        # Build rule_id -> rule_object map
        rule_map = {getattr(r, "rule_id", None): r for r in rule_objects}

        # Sort rules by failed_count descending (worst offenders first)
        sorted_rules = sorted(
            self.rules,
            key=lambda r: r.failed_count if not r.passed else 0,
            reverse=True,
        )

        remaining_budget = global_budget

        # Phase 1: Collect rules that can use SQL batching
        sql_rules: List[Tuple[Any, Any, int, Optional[List[str]]]] = []  # (rule_result, rule_obj, n, columns)
        polars_rules: List[Tuple[Any, Any, Any, int, Optional[List[str]]]] = []  # (rule_result, rule_obj, predicate, n, columns)

        for rule_result in sorted_rules:
            # Handle passing rules
            if rule_result.passed:
                rule_result.samples = []
                rule_result.samples_reason = SampleReason.UNAVAILABLE_PASSED
                continue

            # Check budget
            if remaining_budget <= 0:
                rule_result.samples = None
                rule_result.samples_reason = SampleReason.TRUNCATED_BUDGET
                rule_result.samples_truncated = True
                continue

            # Get corresponding rule object
            rule_obj = rule_map.get(rule_result.rule_id)
            if rule_obj is None:
                rule_result.samples = None
                rule_result.samples_reason = SampleReason.UNAVAILABLE_UNSUPPORTED
                continue

            # Check if rule was resolved via metadata (preplan)
            if rule_result.source == "metadata":
                if not self._can_sample_source():
                    rule_result.samples = None
                    rule_result.samples_reason = SampleReason.UNAVAILABLE_METADATA
                    continue

            # Calculate samples to get for this rule
            n = min(per_rule_cap, remaining_budget)
            remaining_budget -= n  # Reserve budget

            # Determine columns to include
            cols_to_include = self._resolve_sample_columns(sample_columns, rule_obj)

            # Check if rule supports SQL filter (for batching)
            # Note: to_sql_filter can return None to indicate no SQL support for this dialect
            sql_filter_available = False
            if hasattr(rule_obj, "to_sql_filter"):
                # Probe to see if SQL filter is actually available
                sql_filter = rule_obj.to_sql_filter("duckdb")
                if sql_filter is not None:
                    sql_rules.append((rule_result, rule_obj, n, cols_to_include))
                    sql_filter_available = True

            # Fall back to sample_predicate/compile_predicate if SQL not available
            if not sql_filter_available:
                # Check sample_predicate() first (for rules like unique that have
                # different counting vs sampling semantics)
                pred_obj = None
                if hasattr(rule_obj, "sample_predicate"):
                    pred_obj = rule_obj.sample_predicate()
                if pred_obj is None and hasattr(rule_obj, "compile_predicate"):
                    pred_obj = rule_obj.compile_predicate()

                if pred_obj is not None:
                    polars_rules.append((rule_result, rule_obj, pred_obj.expr, n, cols_to_include))
                else:
                    rule_result.samples = None
                    rule_result.samples_reason = SampleReason.UNAVAILABLE_UNSUPPORTED

        # Phase 2: Execute batched SQL sampling if applicable
        if sql_rules:
            self._execute_batched_sql_sampling(sql_rules, per_rule_cap)

        # Phase 3: Execute per-rule Polars sampling for remaining rules
        for rule_result, rule_obj, predicate, n, cols_to_include in polars_rules:
            try:
                samples, samples_source = self._collect_samples_for_rule(rule_obj, predicate, n, cols_to_include)
                rule_result.samples = samples
                rule_result.samples_source = samples_source

                if len(samples) == per_rule_cap and rule_result.failed_count > per_rule_cap:
                    rule_result.samples_truncated = True
                    rule_result.samples_reason = SampleReason.TRUNCATED_LIMIT

            except Exception as e:
                rule_result.samples = None
                rule_result.samples_reason = f"error: {str(e)[:50]}"

    def _execute_batched_sql_sampling(
        self,
        sql_rules: List[Tuple[Any, Any, int, Optional[List[str]]]],
        per_rule_cap: int,
    ) -> None:
        """
        Execute batched SQL sampling for rules that support to_sql_filter().

        Builds a single UNION ALL query for all rules, executes once,
        and distributes results.
        """
        import polars as pl

        source = self._data_source
        if source is None:
            return

        # Determine dialect and source type
        dialect = "duckdb"
        is_parquet = False
        is_database = False
        db_conn = None
        db_table = None
        parquet_path = None

        is_s3 = False

        if isinstance(source, str):
            lower = source.lower()
            if source.startswith("s3://") or source.startswith("http://") or source.startswith("https://"):
                # Remote file - use DuckDB for efficient sampling
                is_parquet = True
                is_s3 = True
                parquet_path = source
            elif lower.endswith(".parquet") or lower.endswith(".csv"):
                # Local file - use Polars (more efficient for local)
                is_parquet = False  # Will fall through to Polars path
        elif hasattr(source, "scheme"):
            handle = source
            if handle.scheme in ("postgres", "postgresql"):
                is_database = True
                dialect = "postgres"
                db_conn = getattr(handle, "external_conn", None)
                db_table = getattr(handle, "table_ref", None) or f'"{handle.schema}"."{handle.path}"'
            elif handle.scheme == "mssql":
                is_database = True
                dialect = "mssql"
                db_conn = getattr(handle, "external_conn", None)
                db_table = getattr(handle, "table_ref", None) or f"[{handle.schema}].[{handle.path}]"
            elif handle.scheme == "s3":
                # S3 via handle - use DuckDB
                is_parquet = True
                is_s3 = True
                parquet_path = handle.uri
            elif handle.scheme in ("", "file"):
                # Local file via handle - use Polars
                is_parquet = False

        # Build list of (rule_id, sql_filter, limit, columns) for batching
        rules_to_sample: List[Tuple[str, str, int, Optional[List[str]]]] = []

        for rule_result, rule_obj, n, cols_to_include in sql_rules:
            sql_filter = rule_obj.to_sql_filter(dialect)
            if sql_filter:
                rules_to_sample.append((rule_result.rule_id, sql_filter, n, cols_to_include))
            else:
                # Rule doesn't support this dialect, mark as unsupported
                rule_result.samples = None
                rule_result.samples_reason = SampleReason.UNAVAILABLE_UNSUPPORTED

        if not rules_to_sample:
            return

        # Execute batched query
        try:
            if is_parquet and is_s3 and parquet_path:
                # S3/remote: Use DuckDB batched sampling (much faster than Polars)
                results = self._batch_sample_parquet_duckdb(parquet_path, rules_to_sample)
                samples_source = "sql"
            elif is_database and db_conn and db_table:
                results = self._batch_sample_db(db_conn, db_table, rules_to_sample, dialect)
                samples_source = "sql"
            else:
                # Fall back to per-rule sampling
                results = {}
                samples_source = "polars"
                for rule_result, rule_obj, n, cols_to_include in sql_rules:
                    if hasattr(rule_obj, "compile_predicate"):
                        pred_obj = rule_obj.compile_predicate()
                        if pred_obj:
                            try:
                                samples, src = self._collect_samples_for_rule(rule_obj, pred_obj.expr, n, cols_to_include)
                                results[rule_result.rule_id] = samples
                                samples_source = src
                            except Exception:
                                results[rule_result.rule_id] = []

            # Distribute results to rule_result objects
            for rule_result, rule_obj, n, cols_to_include in sql_rules:
                samples = results.get(rule_result.rule_id, [])
                rule_result.samples = samples
                rule_result.samples_source = samples_source

                if len(samples) == n and rule_result.failed_count > n:
                    rule_result.samples_truncated = True
                    rule_result.samples_reason = SampleReason.TRUNCATED_LIMIT

        except Exception as e:
            # Batched sampling failed, mark all rules
            for rule_result, rule_obj, n, cols_to_include in sql_rules:
                rule_result.samples = None
                rule_result.samples_reason = f"error: {str(e)[:50]}"

    def _can_sample_source(self) -> bool:
        """
        Check if the data source supports sampling.

        File-based sources (Parquet, CSV, S3) can always be sampled.
        Database sources need a live connection.

        Returns:
            True if sampling is possible, False otherwise.
        """
        import polars as pl

        source = self._data_source

        if source is None:
            return False

        # DataFrame - always sampleable
        if isinstance(source, pl.DataFrame):
            return True

        # String path - file based, always sampleable
        if isinstance(source, str):
            return True

        # DatasetHandle - check scheme and connection
        if hasattr(source, "scheme"):
            scheme = getattr(source, "scheme", None)

            # File-based schemes - always sampleable
            if scheme in (None, "file") or (hasattr(source, "uri") and source.uri):
                uri = getattr(source, "uri", "")
                if uri.lower().endswith((".parquet", ".csv")) or uri.startswith("s3://"):
                    return True

            # BYOC or database with connection - check if connection exists
            if hasattr(source, "external_conn") and source.external_conn is not None:
                return True

            # Database without connection - can't sample
            if scheme in ("postgres", "postgresql", "mssql"):
                return False

        return True  # Default to sampleable

    def _resolve_sample_columns(
        self,
        sample_columns: Optional[Union[List[str], str]],
        rule_obj: Any,
    ) -> Optional[List[str]]:
        """
        Resolve sample_columns to a list of column names.

        Args:
            sample_columns: None (all), list of names, or "relevant"
            rule_obj: Rule object for "relevant" mode

        Returns:
            List of column names to include, or None for all columns
        """
        if sample_columns is None:
            return None

        if isinstance(sample_columns, list):
            return sample_columns

        if sample_columns == "relevant":
            # Get columns from rule's required_columns() if available
            cols = set()
            if hasattr(rule_obj, "required_columns"):
                cols.update(rule_obj.required_columns())

            # Also check params for column names (required_columns() may be incomplete)
            if hasattr(rule_obj, "params"):
                params = rule_obj.params
                if "column" in params:
                    cols.add(params["column"])
                if "left" in params:
                    cols.add(params["left"])
                if "right" in params:
                    cols.add(params["right"])
                if "when_column" in params:
                    cols.add(params["when_column"])

            return list(cols) if cols else None

        # Unknown value - return all columns
        return None

    def _collect_samples_for_rule(
        self,
        rule_obj: Any,
        predicate: Any,
        n: int,
        columns: Optional[List[str]] = None,
    ) -> Tuple[List[Dict[str, Any]], str]:
        """
        Collect sample rows for a single rule.

        Uses the existing sampling infrastructure (SQL pushdown, Parquet predicate, etc.)

        Args:
            rule_obj: Rule object
            predicate: Polars expression for filtering
            n: Number of samples to collect
            columns: Columns to include (None = all)

        Returns:
            Tuple of (samples list, source string "sql" or "polars")
        """
        import polars as pl

        source = self._data_source

        if source is None:
            return [], "polars"

        # Reuse existing loading/filtering logic
        df, load_source = self._load_data_for_sampling(rule_obj, n)

        # If data was already filtered by SQL/DuckDB, just apply column projection
        if load_source == "sql":
            result_df = df.head(n)
            return self._apply_column_projection(result_df, columns), "sql"

        # For Polars path, filter with predicate (unique rule handled by helper)
        result_df = _filter_samples_polars(df, rule_obj, predicate, n)

        # For unique rule, always include _duplicate_count in projection
        if _is_unique_rule(rule_obj) and columns is not None:
            columns = list(columns) + ["_duplicate_count"]

        return self._apply_column_projection(result_df, columns), "polars"

    def _apply_column_projection(
        self,
        df: Any,
        columns: Optional[List[str]],
    ) -> List[Dict[str, Any]]:
        """
        Apply column projection to a DataFrame before converting to dicts.

        Always includes _row_index if present.

        Args:
            df: Polars DataFrame
            columns: Columns to include (None = all)

        Returns:
            List of row dicts
        """
        if columns is None:
            return df.to_dicts()

        # Always include _row_index and _duplicate_count if present
        cols_to_select = set(columns)
        if "_row_index" in df.columns:
            cols_to_select.add("_row_index")
        if "_duplicate_count" in df.columns:
            cols_to_select.add("_duplicate_count")

        # Only select columns that exist in the DataFrame
        available_cols = set(df.columns)
        cols_to_select = cols_to_select & available_cols

        if not cols_to_select:
            return df.to_dicts()

        return df.select(sorted(cols_to_select)).to_dicts()

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary."""
        d = {
            "passed": self.passed,
            "dataset": self.dataset,
            "total_rows": self.total_rows,
            "total_rules": self.total_rules,
            "passed_count": self.passed_count,
            "failed_count": self.failed_count,
            "warning_count": self.warning_count,
            "rules": [r.to_dict() for r in self.rules],
            "stats": self.stats,
        }
        if self.annotations is not None:
            d["annotations"] = self.annotations
        # LLM juice (only include if configured)
        if self.quality_score is not None:
            d["quality_score"] = self.quality_score
        return d

    def to_json(self, indent: Optional[int] = None) -> str:
        """Convert to JSON string."""
        return json.dumps(self.to_dict(), indent=indent, default=str)

    def to_llm(self) -> str:
        """
        Token-optimized format for LLM context.

        Example output:
            VALIDATION: my_contract FAILED (1000 rows)
            BLOCKING: COL:email:not_null (523 nulls), COL:status:allowed_values (12 invalid)
            WARNING: COL:age:range (3 out of bounds)
            PASSED: 15 rules
        """
        lines = []

        status = "PASSED" if self.passed else "FAILED"
        rows_str = f" ({self.total_rows:,} rows)" if self.total_rows > 0 else ""
        score_str = f" [score={self.quality_score:.2f}]" if self.quality_score is not None else ""
        lines.append(f"VALIDATION: {self.dataset} {status}{rows_str}{score_str}")

        # Blocking failures
        blocking = self.blocking_failures
        if blocking:
            parts = []
            for r in blocking[:5]:
                count = f"({r.failed_count:,})" if r.failed_count > 0 else ""
                parts.append(f"{r.rule_id} {count}".strip())
            line = "BLOCKING: " + ", ".join(parts)
            if len(blocking) > 5:
                line += f" ... +{len(blocking) - 5} more"
            lines.append(line)

        # Warnings
        warnings = self.warnings
        if warnings:
            parts = []
            for r in warnings[:5]:
                count = f"({r.failed_count:,})" if r.failed_count > 0 else ""
                parts.append(f"{r.rule_id} {count}".strip())
            line = "WARNING: " + ", ".join(parts)
            if len(warnings) > 5:
                line += f" ... +{len(warnings) - 5} more"
            lines.append(line)

        # Passed summary
        lines.append(f"PASSED: {self.passed_count} rules")

        # Run-level annotations
        if self.annotations:
            lines.append(f"ANNOTATIONS ({len(self.annotations)}):")
            for ann in self.annotations[:3]:
                ann_type = ann.get("annotation_type", "note")
                actor = ann.get("actor_id", "unknown")
                summary = ann.get("summary", "")[:50]
                if len(ann.get("summary", "")) > 50:
                    summary += "..."
                lines.append(f' [{ann_type}] by {actor}: "{summary}"')
            if len(self.annotations) > 3:
                lines.append(f" ... +{len(self.annotations) - 3} more")

        return "\n".join(lines)

def sample_failures(
|
|
1174
|
+
self,
|
|
1175
|
+
rule_id: str,
|
|
1176
|
+
n: int = 5,
|
|
1177
|
+
*,
|
|
1178
|
+
upgrade_tier: bool = False,
|
|
1179
|
+
) -> FailureSamples:
|
|
1180
|
+
"""
|
|
1181
|
+
Get a sample of rows that failed a specific rule.
|
|
1182
|
+
|
|
1183
|
+
If eager sampling is enabled (default), this returns cached samples.
|
|
1184
|
+
Otherwise, it lazily re-queries the data source.
|
|
1185
|
+
|
|
1186
|
+
Args:
|
|
1187
|
+
rule_id: The rule ID to get failures for (e.g., "COL:email:not_null")
|
|
1188
|
+
n: Number of sample rows to return (default: 5, max: 100)
|
|
1189
|
+
upgrade_tier: If True, re-execute rules resolved via metadata
|
|
1190
|
+
tier to get actual samples. Required for preplan rules.
|
|
1191
|
+
|
|
1192
|
+
Returns:
|
|
1193
|
+
FailureSamples: Collection of failing rows with "_row_index" field.
|
|
1194
|
+
Supports to_dict(), to_json(), to_llm() methods.
|
|
1195
|
+
Empty if the rule passed (no failures).
|
|
1196
|
+
|
|
1197
|
+
Raises:
|
|
1198
|
+
ValueError: If rule_id not found or rule doesn't support row-level samples
|
|
1199
|
+
RuntimeError: If data source is unavailable for re-query,
|
|
1200
|
+
or if samples unavailable from metadata tier without upgrade_tier=True
|
|
1201
|
+
|
|
1202
|
+
Example:
|
|
1203
|
+
result = kontra.validate("data.parquet", contract)
|
|
1204
|
+
if not result.passed:
|
|
1205
|
+
samples = result.sample_failures("COL:email:not_null", n=5)
|
|
1206
|
+
for row in samples:
|
|
1207
|
+
print(f"Row {row['_row_index']}: {row}")
|
|
1208
|
+
|
|
1209
|
+
# For metadata-tier rules, use upgrade_tier to get samples:
|
|
1210
|
+
samples = result.sample_failures("COL:id:not_null", upgrade_tier=True)
|
|
1211
|
+
"""
|
|
1212
|
+
import polars as pl
|
|
1213
|
+
|
|
1214
|
+
# Cap n at 100
|
|
1215
|
+
n = min(n, 100)
|
|
1216
|
+
|
|
1217
|
+
# Find the rule result
|
|
1218
|
+
rule_result = None
|
|
1219
|
+
for r in self.rules:
|
|
1220
|
+
if r.rule_id == rule_id:
|
|
1221
|
+
rule_result = r
|
|
1222
|
+
break
|
|
1223
|
+
|
|
1224
|
+
if rule_result is None:
|
|
1225
|
+
raise ValueError(f"Rule not found: {rule_id}")
|
|
1226
|
+
|
|
1227
|
+
# If rule passed, return empty FailureSamples
|
|
1228
|
+
if rule_result.passed:
|
|
1229
|
+
return FailureSamples([], rule_id)
|
|
1230
|
+
|
|
1231
|
+
# Check for cached samples first
|
|
1232
|
+
if rule_result.samples is not None:
|
|
1233
|
+
# Have cached samples - return if n <= cached, else fetch more
|
|
1234
|
+
if len(rule_result.samples) >= n:
|
|
1235
|
+
return FailureSamples(rule_result.samples[:n], rule_id)
|
|
1236
|
+
# Need more samples than cached - fall through to lazy path
|
|
1237
|
+
|
|
1238
|
+
# Handle unavailable samples
|
|
1239
|
+
if rule_result.samples_reason == SampleReason.UNAVAILABLE_METADATA:
|
|
1240
|
+
if not upgrade_tier:
|
|
1241
|
+
raise RuntimeError(
|
|
1242
|
+
f"Samples unavailable for {rule_id}: rule was resolved via metadata tier. "
|
|
1243
|
+
"Use upgrade_tier=True to re-execute and get samples."
|
|
1244
|
+
)
|
|
1245
|
+
# Fall through to lazy path for tier upgrade
|
|
1246
|
+
|
|
1247
|
+
elif rule_result.samples_reason == SampleReason.UNAVAILABLE_UNSUPPORTED:
|
|
1248
|
+
raise ValueError(
|
|
1249
|
+
f"Rule '{rule_result.name}' does not support row-level samples. "
|
|
1250
|
+
"Dataset-level rules (min_rows, max_rows, freshness, etc.) "
|
|
1251
|
+
"cannot identify specific failing rows."
|
|
1252
|
+
)
|
|
1253
|
+
|
|
1254
|
+
# Find the rule object to get the failure predicate
|
|
1255
|
+
if self._rule_objects is None:
|
|
1256
|
+
raise RuntimeError(
|
|
1257
|
+
"sample_failures() requires rule objects. "
|
|
1258
|
+
"This may happen if ValidationResult was created manually."
|
|
1259
|
+
)
|
|
1260
|
+
|
|
1261
|
+
rule_obj = None
|
|
1262
|
+
for r in self._rule_objects:
|
|
1263
|
+
if getattr(r, "rule_id", None) == rule_id:
|
|
1264
|
+
rule_obj = r
|
|
1265
|
+
break
|
|
1266
|
+
|
|
1267
|
+
if rule_obj is None:
|
|
1268
|
+
raise ValueError(f"Rule object not found for: {rule_id}")
|
|
1269
|
+
|
|
1270
|
+
# Get the failure predicate
|
|
1271
|
+
# Check sample_predicate() first (used by rules like unique that have
|
|
1272
|
+
# different counting vs sampling semantics), then fall back to compile_predicate()
|
|
1273
|
+
predicate = None
|
|
1274
|
+
if hasattr(rule_obj, "sample_predicate"):
|
|
1275
|
+
pred_obj = rule_obj.sample_predicate()
|
|
1276
|
+
if pred_obj is not None:
|
|
1277
|
+
predicate = pred_obj.expr
|
|
1278
|
+
if predicate is None and hasattr(rule_obj, "compile_predicate"):
|
|
1279
|
+
pred_obj = rule_obj.compile_predicate()
|
|
1280
|
+
if pred_obj is not None:
|
|
1281
|
+
predicate = pred_obj.expr
|
|
1282
|
+
|
|
1283
|
+
if predicate is None:
|
|
1284
|
+
raise ValueError(
|
|
1285
|
+
f"Rule '{rule_obj.name}' does not support row-level samples. "
|
|
1286
|
+
"Dataset-level rules (min_rows, max_rows, freshness, etc.) "
|
|
1287
|
+
"cannot identify specific failing rows."
|
|
1288
|
+
)
|
|
1289
|
+
|
|
1290
|
+
# Load the data
|
|
1291
|
+
if self._data_source is None:
|
|
1292
|
+
raise RuntimeError(
|
|
1293
|
+
"sample_failures() requires data source reference. "
|
|
1294
|
+
"This may happen if ValidationResult was created manually "
|
|
1295
|
+
"or the data source is no longer available."
|
|
1296
|
+
)
|
|
1297
|
+
|
|
1298
|
+
# Load data based on source type
|
|
1299
|
+
# Try SQL pushdown for database sources
|
|
1300
|
+
df, load_source = self._load_data_for_sampling(rule_obj, n)
|
|
1301
|
+
|
|
1302
|
+
# For non-database sources (or if SQL filter wasn't available),
|
|
1303
|
+
# we need to filter with Polars
|
|
1304
|
+
if load_source != "sql":
|
|
1305
|
+
# Filter to failing rows, add index, limit (unique rule handled by helper)
|
|
1306
|
+
try:
|
|
1307
|
+
failing = _filter_samples_polars(df, rule_obj, predicate, n).to_dicts()
|
|
1308
|
+
except Exception as e:
|
|
1309
|
+
raise RuntimeError(f"Failed to query failing rows: {e}") from e
|
|
1310
|
+
else:
|
|
1311
|
+
# SQL pushdown already applied filter and added row index
|
|
1312
|
+
failing = df.head(n).to_dicts()
|
|
1313
|
+
|
|
1314
|
+
return FailureSamples(failing, rule_id)
|
|
1315
|
+
|
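A minimal usage sketch of the lazy sampling path above, for orientation only: it assumes `contract` is any contract accepted by `kontra.validate` (the path below is a placeholder), and the rule ID is taken from the docstring example.

    import kontra

    contract = "contracts/orders.yml"  # placeholder; any contract accepted by kontra.validate
    result = kontra.validate("data.parquet", contract)
    if not result.passed:
        samples = result.sample_failures("COL:email:not_null", n=5)
        for row in samples:
            print(f"Row {row['_row_index']}: {row}")
        # FailureSamples also exposes to_dict(), to_json(), and to_llm()
        print(samples.to_llm())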
+    def _load_data_for_sampling(
+        self, rule: Any = None, n: int = 5
+    ) -> Tuple["pl.DataFrame", str]:
+        """
+        Load data from the stored data source for sample_failures().
+
+        For database sources with rules that support SQL filters,
+        pushes the filter to SQL for performance.
+
+        Returns:
+            Tuple of (DataFrame, source) where source is "sql" or "polars"
+        """
+        import polars as pl
+
+        source = self._data_source
+
+        if source is None:
+            raise RuntimeError("No data source available")
+
+        # String path/URI
+        if isinstance(source, str):
+            # Try to load as file with predicate pushdown for Parquet
+            if source.lower().endswith(".parquet") or source.startswith("s3://"):
+                return self._load_parquet_with_filter(source, rule, n)
+            elif source.lower().endswith(".csv"):
+                return pl.read_csv(source), "polars"
+            else:
+                # Try parquet first, then CSV
+                try:
+                    return self._load_parquet_with_filter(source, rule, n)
+                except Exception:
+                    try:
+                        return pl.read_csv(source), "polars"
+                    except Exception:
+                        raise RuntimeError(f"Cannot load data from: {source}")
+
+        # Polars DataFrame (was passed directly)
+        if isinstance(source, pl.DataFrame):
+            return source, "polars"
+
+        # DatasetHandle (BYOC or parsed URI)
+        if hasattr(source, "scheme") and hasattr(source, "uri"):
+            # It's a DatasetHandle
+            handle = source
+
+            # Check for BYOC (external connection)
+            if handle.scheme == "byoc" or hasattr(handle, "external_conn"):
+                conn = getattr(handle, "external_conn", None)
+                if conn is None:
+                    raise RuntimeError(
+                        "Database connection is closed. "
+                        "For BYOC, keep the connection open until done with sample_failures()."
+                    )
+                table = getattr(handle, "table_ref", None) or handle.path
+                return self._query_db_with_filter(conn, table, rule, n, "postgres"), "sql"
+
+            elif handle.scheme in ("postgres", "postgresql"):
+                # PostgreSQL via URI
+                if hasattr(handle, "external_conn") and handle.external_conn:
+                    conn = handle.external_conn
+                else:
+                    raise RuntimeError(
+                        "Database connection is not available. "
+                        "For URI-based connections, sample_failures() requires re-connection."
+                    )
+                table = getattr(handle, "table_ref", None) or handle.path
+                return self._query_db_with_filter(conn, table, rule, n, "postgres"), "sql"
+
+            elif handle.scheme == "mssql":
+                # SQL Server
+                if hasattr(handle, "external_conn") and handle.external_conn:
+                    conn = handle.external_conn
+                else:
+                    raise RuntimeError(
+                        "Database connection is not available."
+                    )
+                table = getattr(handle, "table_ref", None) or handle.path
+                return self._query_db_with_filter(conn, table, rule, n, "mssql"), "sql"
+
+            elif handle.scheme in ("file", None) or (handle.uri and not handle.scheme):
+                # File-based
+                uri = handle.uri
+                if uri.lower().endswith(".parquet"):
+                    return self._load_parquet_with_filter(uri, rule, n)
+                elif uri.lower().endswith(".csv"):
+                    return pl.read_csv(uri), "polars"
+                else:
+                    return self._load_parquet_with_filter(uri, rule, n)
+
+        raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+    def _query_db_with_filter(
+        self,
+        conn: Any,
+        table: str,
+        rule: Any,
+        n: int,
+        dialect: str,
+    ) -> "pl.DataFrame":
+        """
+        Query database with SQL filter if rule supports it.
+
+        Uses the rule's to_sql_filter() method to push the filter to SQL,
+        avoiding loading the entire table.
+        """
+        import polars as pl
+
+        sql_filter = None
+
+        # Special case: unique rule needs subquery with table name
+        if _is_unique_rule(rule):
+            column = rule.params.get("column")
+            if column:
+                query = _build_unique_sample_query_sql(table, column, n, dialect)
+                return pl.read_database(query, conn)
+
+        if rule is not None and hasattr(rule, "to_sql_filter"):
+            sql_filter = rule.to_sql_filter(dialect)
+
+        if sql_filter:
+            # Build query with filter and row number
+            # ROW_NUMBER() gives us the original row index
+            if dialect == "mssql":
+                # SQL Server syntax
+                query = f"""
+                    SELECT *, ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) - 1 AS _row_index
+                    FROM {table}
+                    WHERE {sql_filter}
+                    ORDER BY (SELECT NULL)
+                    OFFSET 0 ROWS FETCH FIRST {n} ROWS ONLY
+                """
+            else:
+                # PostgreSQL / DuckDB syntax
+                query = f"""
+                    SELECT *, ROW_NUMBER() OVER () - 1 AS _row_index
+                    FROM {table}
+                    WHERE {sql_filter}
+                    LIMIT {n}
+                """
+            return pl.read_database(query, conn)
+        else:
+            # Fall back to loading all data (rule doesn't support SQL filter)
+            return pl.read_database(f"SELECT * FROM {table}", conn)
+
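As a concrete illustration of the query shape the PostgreSQL/DuckDB branch above emits, here is a standalone sketch; the table name and filter are placeholders standing in for `table` and the value returned by `rule.to_sql_filter(dialect)`.

    # Illustrative only: mirrors the non-mssql branch of _query_db_with_filter
    table = "public.users"        # placeholder table reference
    sql_filter = "email IS NULL"  # placeholder for rule.to_sql_filter("postgres")
    n = 5

    query = f"""
        SELECT *, ROW_NUMBER() OVER () - 1 AS _row_index
        FROM {table}
        WHERE {sql_filter}
        LIMIT {n}
    """
    print(query)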
+    def _load_parquet_with_filter(
+        self,
+        path: str,
+        rule: Any,
+        n: int,
+    ) -> Tuple["pl.DataFrame", str]:
+        """
+        Load Parquet file with predicate pushdown for performance.
+
+        - S3/remote files: Uses DuckDB SQL pushdown (doesn't download whole file)
+        - Local files: Uses Polars scan_parquet (efficient for local)
+
+        Returns:
+            Tuple of (DataFrame, source) where source is "sql" or "polars"
+        """
+        import polars as pl
+
+        # Check if this is a remote file (S3, HTTP)
+        is_remote = path.startswith("s3://") or path.startswith("http://") or path.startswith("https://")
+
+        # For remote files, use DuckDB SQL pushdown (much faster - doesn't download whole file)
+        if is_remote and rule is not None and hasattr(rule, "to_sql_filter"):
+            sql_filter = rule.to_sql_filter("duckdb")
+            if sql_filter:
+                try:
+                    return self._query_parquet_with_duckdb(path, sql_filter, n), "sql"
+                except Exception:
+                    pass  # Fall through to Polars
+
+        # For local files, just return the raw data - caller will filter
+        # (Don't filter here to avoid double-filtering in _collect_samples_for_rule)
+        return pl.read_parquet(path), "polars"
+
+    def _query_parquet_with_duckdb(
+        self,
+        path: str,
+        sql_filter: str,
+        n: int,
+        columns: Optional[List[str]] = None,
+    ) -> "pl.DataFrame":
+        """
+        Query Parquet file using DuckDB with SQL filter.
+
+        Much faster than Polars for S3 files because DuckDB pushes
+        the filter and LIMIT to the row group level.
+        """
+        import duckdb
+        import polars as pl
+
+        con = duckdb.connect()
+
+        # Configure S3 if needed
+        if path.startswith("s3://"):
+            import os
+            con.execute("INSTALL httpfs; LOAD httpfs;")
+            if os.environ.get("AWS_ACCESS_KEY_ID"):
+                con.execute(f"SET s3_access_key_id='{os.environ['AWS_ACCESS_KEY_ID']}';")
+            if os.environ.get("AWS_SECRET_ACCESS_KEY"):
+                con.execute(f"SET s3_secret_access_key='{os.environ['AWS_SECRET_ACCESS_KEY']}';")
+            if os.environ.get("AWS_ENDPOINT_URL"):
+                endpoint = os.environ["AWS_ENDPOINT_URL"].replace("http://", "").replace("https://", "")
+                con.execute(f"SET s3_endpoint='{endpoint}';")
+                con.execute("SET s3_use_ssl=false;")
+                con.execute("SET s3_url_style='path';")
+            if os.environ.get("AWS_REGION"):
+                con.execute(f"SET s3_region='{os.environ['AWS_REGION']}';")
+
+        # Escape path for SQL
+        escaped_path = path.replace("'", "''")
+
+        # Build column list (with projection if specified)
+        if columns:
+            col_list = ", ".join(f'"{c}"' for c in columns)
+        else:
+            col_list = "*"
+
+        # Build query with filter and row number
+        query = f"""
+            SELECT {col_list}, ROW_NUMBER() OVER () - 1 AS _row_index
+            FROM read_parquet('{escaped_path}')
+            WHERE {sql_filter}
+            LIMIT {n}
+        """
+
+        result = con.execute(query).pl()
+        con.close()
+        return result
+
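The httpfs-based pushdown used above can be reproduced standalone with plain DuckDB; this sketch assumes a placeholder S3 URI and credentials already available to DuckDB, and is not part of the package.

    import duckdb

    con = duckdb.connect()
    con.execute("INSTALL httpfs; LOAD httpfs;")
    # Key, endpoint, and region settings would be applied here as in the method above.
    df = con.execute(
        "SELECT *, ROW_NUMBER() OVER () - 1 AS _row_index "
        "FROM read_parquet('s3://example-bucket/data.parquet') "  # placeholder URI
        "WHERE email IS NULL LIMIT 5"
    ).pl()
    con.close()
    print(df)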
+    def _batch_sample_parquet_duckdb(
+        self,
+        path: str,
+        rules_to_sample: List[Tuple[str, str, int, Optional[List[str]]]],
+    ) -> Dict[str, List[Dict[str, Any]]]:
+        """
+        Batch sample multiple rules from Parquet using a single DuckDB query.
+
+        Args:
+            path: Parquet file path or S3 URI
+            rules_to_sample: List of (rule_id, sql_filter, limit, columns)
+
+        Returns:
+            Dict mapping rule_id to list of sample dicts
+        """
+        import duckdb
+        import polars as pl
+
+        if not rules_to_sample:
+            return {}
+
+        con = duckdb.connect()
+
+        # Configure S3 if needed
+        if path.startswith("s3://"):
+            import os
+            con.execute("INSTALL httpfs; LOAD httpfs;")
+            if os.environ.get("AWS_ACCESS_KEY_ID"):
+                con.execute(f"SET s3_access_key_id='{os.environ['AWS_ACCESS_KEY_ID']}';")
+            if os.environ.get("AWS_SECRET_ACCESS_KEY"):
+                con.execute(f"SET s3_secret_access_key='{os.environ['AWS_SECRET_ACCESS_KEY']}';")
+            if os.environ.get("AWS_ENDPOINT_URL"):
+                endpoint = os.environ["AWS_ENDPOINT_URL"].replace("http://", "").replace("https://", "")
+                con.execute(f"SET s3_endpoint='{endpoint}';")
+                con.execute("SET s3_use_ssl=false;")
+                con.execute("SET s3_url_style='path';")
+            if os.environ.get("AWS_REGION"):
+                con.execute(f"SET s3_region='{os.environ['AWS_REGION']}';")
+
+        escaped_path = path.replace("'", "''")
+
+        # Collect all columns needed across all rules
+        all_columns: set = set()
+        for rule_id, sql_filter, limit, columns in rules_to_sample:
+            if columns:
+                all_columns.update(columns)
+
+        # If any rule needs all columns, use *
+        needs_all = any(cols is None for _, _, _, cols in rules_to_sample)
+
+        if needs_all or not all_columns:
+            col_list = "*"
+        else:
+            col_list = ", ".join(f'"{c}"' for c in sorted(all_columns))
+
+        # Build UNION ALL query - one subquery per rule (wrapped in parens for DuckDB)
+        subqueries = []
+        for rule_id, sql_filter, limit, columns in rules_to_sample:
+            escaped_rule_id = rule_id.replace("'", "''")
+            subquery = f"""(
+                SELECT '{escaped_rule_id}' AS _rule_id, {col_list}, ROW_NUMBER() OVER () - 1 AS _row_index
+                FROM read_parquet('{escaped_path}')
+                WHERE {sql_filter}
+                LIMIT {limit}
+            )"""
+            subqueries.append(subquery)
+
+        query = " UNION ALL ".join(subqueries)
+
+        try:
+            result_df = con.execute(query).pl()
+        finally:
+            con.close()
+
+        # Distribute results to each rule
+        results: Dict[str, List[Dict[str, Any]]] = {rule_id: [] for rule_id, _, _, _ in rules_to_sample}
+
+        for row in result_df.to_dicts():
+            rule_id = row.pop("_rule_id")
+            if rule_id in results:
+                results[rule_id].append(row)
+
+        return results
+
+    def _batch_sample_db(
+        self,
+        conn: Any,
+        table: str,
+        rules_to_sample: List[Tuple[str, str, int, Optional[List[str]]]],
+        dialect: str,
+    ) -> Dict[str, List[Dict[str, Any]]]:
+        """
+        Batch sample multiple rules from database using a single query.
+
+        Args:
+            conn: Database connection
+            table: Table name (with schema if needed)
+            rules_to_sample: List of (rule_id, sql_filter, limit, columns)
+            dialect: "postgres" or "mssql"
+
+        Returns:
+            Dict mapping rule_id to list of sample dicts
+        """
+        import polars as pl
+
+        if not rules_to_sample:
+            return {}
+
+        # Collect all columns needed across all rules
+        all_columns: set = set()
+        for rule_id, sql_filter, limit, columns in rules_to_sample:
+            if columns:
+                all_columns.update(columns)
+
+        needs_all = any(cols is None for _, _, _, cols in rules_to_sample)
+
+        if dialect == "mssql":
+            # SQL Server syntax
+            if needs_all or not all_columns:
+                col_list = "*"
+            else:
+                col_list = ", ".join(f"[{c}]" for c in sorted(all_columns))
+
+            subqueries = []
+            for rule_id, sql_filter, limit, columns in rules_to_sample:
+                escaped_rule_id = rule_id.replace("'", "''")
+                subquery = f"""(
+                    SELECT TOP {limit} '{escaped_rule_id}' AS _rule_id, {col_list},
+                           ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) - 1 AS _row_index
+                    FROM {table}
+                    WHERE {sql_filter}
+                )"""
+                subqueries.append(subquery)
+        else:
+            # PostgreSQL syntax
+            if needs_all or not all_columns:
+                col_list = "*"
+            else:
+                col_list = ", ".join(f'"{c}"' for c in sorted(all_columns))
+
+            subqueries = []
+            for rule_id, sql_filter, limit, columns in rules_to_sample:
+                escaped_rule_id = rule_id.replace("'", "''")
+                subquery = f"""(
+                    SELECT '{escaped_rule_id}' AS _rule_id, {col_list},
+                           ROW_NUMBER() OVER () - 1 AS _row_index
+                    FROM {table}
+                    WHERE {sql_filter}
+                    LIMIT {limit}
+                )"""
+                subqueries.append(subquery)
+
+        query = " UNION ALL ".join(subqueries)
+
+        result_df = pl.read_database(query, conn)
+
+        # Distribute results to each rule
+        results: Dict[str, List[Dict[str, Any]]] = {rule_id: [] for rule_id, _, _, _ in rules_to_sample}
+
+        for row in result_df.to_dicts():
+            rule_id = row.pop("_rule_id")
+            if rule_id in results:
+                results[rule_id].append(row)
+
+        return results
+
+
+@dataclass
+class DryRunResult:
+    """
+    Result of a dry run (contract validation without execution).
+
+    Properties:
+        valid: Whether the contract is syntactically valid
+        rules_count: Number of rules that would run
+        columns_needed: Columns the contract requires
+        contract_name: Name of the contract (if any)
+        errors: List of errors found during validation
+        datasource: Datasource from contract
+    """
+
+    valid: bool
+    rules_count: int
+    columns_needed: List[str]
+    contract_name: Optional[str] = None
+    datasource: Optional[str] = None
+    errors: List[str] = field(default_factory=list)
+
+    def __repr__(self) -> str:
+        status = "VALID" if self.valid else "INVALID"
+        parts = [f"DryRunResult({self.contract_name or 'inline'}) {status}"]
+        if self.valid:
+            parts.append(f" Rules: {self.rules_count}, Columns: {len(self.columns_needed)}")
+            if self.columns_needed:
+                cols = ", ".join(self.columns_needed[:5])
+                if len(self.columns_needed) > 5:
+                    cols += f" ... +{len(self.columns_needed) - 5} more"
+                parts.append(f" Needs: {cols}")
+        else:
+            for err in self.errors[:3]:
+                parts.append(f" ERROR: {err}")
+            if len(self.errors) > 3:
+                parts.append(f" ... +{len(self.errors) - 3} more errors")
+        return "\n".join(parts)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            "valid": self.valid,
+            "rules_count": self.rules_count,
+            "columns_needed": self.columns_needed,
+            "contract_name": self.contract_name,
+            "datasource": self.datasource,
+            "errors": self.errors,
+        }
+
+    def to_json(self, indent: Optional[int] = None) -> str:
+        """Convert to JSON string."""
+        return json.dumps(self.to_dict(), indent=indent, default=str)
+
+    def to_llm(self) -> str:
+        """Token-optimized format for LLM context."""
+        if self.valid:
+            cols = ",".join(self.columns_needed[:10])
+            if len(self.columns_needed) > 10:
+                cols += f"...+{len(self.columns_needed) - 10}"
+            return f"DRYRUN: {self.contract_name or 'inline'} VALID rules={self.rules_count} cols=[{cols}]"
+        else:
+            errs = "; ".join(self.errors[:3])
+            return f"DRYRUN: {self.contract_name or 'inline'} INVALID errors=[{errs}]"
+
+
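A short sketch of how the dataclass above renders; the field values are invented for illustration.

    dry = DryRunResult(
        valid=True,
        rules_count=4,
        columns_needed=["id", "email", "status"],
        contract_name="orders",
    )
    print(dry)           # multi-line __repr__ with rule and column counts
    print(dry.to_llm())  # DRYRUN: orders VALID rules=4 cols=[id,email,status]
    print(dry.to_json(indent=2))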
+@dataclass
+class Diff:
+    """
+    Diff between two validation runs.
+
+    Properties:
+        has_changes: Whether there are any changes
+        improved: Fewer failures than before
+        regressed: More failures than before
+        before: Summary of before run
+        after: Summary of after run
+        new_failures: Rules that started failing
+        resolved: Rules that stopped failing
+        count_changes: Rules where failure count changed
+    """
+
+    has_changes: bool
+    improved: bool
+    regressed: bool
+    before: Dict[str, Any]
+    after: Dict[str, Any]
+    new_failures: List[Dict[str, Any]]
+    resolved: List[Dict[str, Any]]
+    regressions: List[Dict[str, Any]]
+    improvements: List[Dict[str, Any]]
+    _state_diff: Optional[Any] = field(default=None, repr=False)
+
+    def __repr__(self) -> str:
+        if self.regressed:
+            status = "REGRESSED"
+        elif self.improved:
+            status = "IMPROVED"
+        else:
+            status = "NO CHANGE"
+
+        contract = self.after.get("contract_name", "unknown")
+        before_date = self.before.get("run_at", "")[:10]
+        after_date = self.after.get("run_at", "")[:10]
+
+        parts = [f"Diff({contract}) {status}"]
+        parts.append(f" {before_date} -> {after_date}")
+        if self.new_failures:
+            parts.append(f" New failures: {len(self.new_failures)}")
+        if self.resolved:
+            parts.append(f" Resolved: {len(self.resolved)}")
+        return "\n".join(parts)
+
+    @property
+    def count_changes(self) -> List[Dict[str, Any]]:
+        """Rules where failure count changed (both regressions and improvements)."""
+        return self.regressions + self.improvements
+
+    @classmethod
+    def from_state_diff(cls, state_diff: "StateDiff") -> "Diff":
+        """Create from internal StateDiff object."""
+        return cls(
+            has_changes=state_diff.has_regressions or state_diff.has_improvements,
+            improved=state_diff.has_improvements and not state_diff.has_regressions,
+            regressed=state_diff.has_regressions,
+            before={
+                "run_at": state_diff.before.run_at.isoformat(),
+                "passed": state_diff.before.summary.passed,
+                "total_rules": state_diff.before.summary.total_rules,
+                "failed_count": state_diff.before.summary.failed_rules,
+                "contract_name": state_diff.before.contract_name,
+            },
+            after={
+                "run_at": state_diff.after.run_at.isoformat(),
+                "passed": state_diff.after.summary.passed,
+                "total_rules": state_diff.after.summary.total_rules,
+                "failed_count": state_diff.after.summary.failed_rules,
+                "contract_name": state_diff.after.contract_name,
+            },
+            new_failures=[rd.to_dict() for rd in state_diff.new_failures],
+            resolved=[rd.to_dict() for rd in state_diff.resolved],
+            regressions=[rd.to_dict() for rd in state_diff.regressions],
+            improvements=[rd.to_dict() for rd in state_diff.improvements],
+            _state_diff=state_diff,
+        )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            "has_changes": self.has_changes,
+            "improved": self.improved,
+            "regressed": self.regressed,
+            "before": self.before,
+            "after": self.after,
+            "new_failures": self.new_failures,
+            "resolved": self.resolved,
+            "regressions": self.regressions,
+            "improvements": self.improvements,
+        }
+
+    def to_json(self, indent: Optional[int] = None) -> str:
+        """Convert to JSON string."""
+        return json.dumps(self.to_dict(), indent=indent, default=str)
+
+    def to_llm(self) -> str:
+        """Token-optimized format for LLM context."""
+        if self._state_diff is not None:
+            return self._state_diff.to_llm()
+
+        # Fallback if no state_diff
+        lines = []
+        contract = self.after.get("contract_name", "unknown")
+
+        if self.regressed:
+            status = "REGRESSION"
+        elif self.improved:
+            status = "IMPROVED"
+        else:
+            status = "NO_CHANGE"
+
+        lines.append(f"DIFF: {contract} {status}")
+        lines.append(f"{self.before.get('run_at', '')[:10]} -> {self.after.get('run_at', '')[:10]}")
+
+        if self.new_failures:
+            lines.append(f"NEW_FAILURES: {len(self.new_failures)}")
+            for nf in self.new_failures[:3]:
+                lines.append(f" - {nf.get('rule_id', '')}")
+
+        if self.resolved:
+            lines.append(f"RESOLVED: {len(self.resolved)}")
+
+        return "\n".join(lines)
+
+
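Given a Diff instance, however it was produced (for example via Diff.from_state_diff), the typical consumption pattern looks like the sketch below; the variable name d is illustrative.

    # d: a Diff obtained elsewhere, e.g. Diff.from_state_diff(state_diff)
    if d.regressed:
        for failure in d.new_failures:
            print("new failure:", failure.get("rule_id"))
    elif d.improved:
        print("resolved rules:", len(d.resolved))
    print(d.to_llm())  # compact, token-optimized summary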
+@dataclass
+class SuggestedRule:
+    """A suggested validation rule from profile analysis."""
+
+    name: str
+    params: Dict[str, Any]
+    confidence: float
+    reason: str
+
+    def __repr__(self) -> str:
+        return f"SuggestedRule({self.name}, confidence={self.confidence:.2f})"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to rule dict format (for inline rules)."""
+        return {
+            "name": self.name,
+            "params": self.params,
+        }
+
+    def to_full_dict(self) -> Dict[str, Any]:
+        """Convert to dict including metadata."""
+        return {
+            "name": self.name,
+            "params": self.params,
+            "confidence": self.confidence,
+            "reason": self.reason,
+        }
+
+
+class Suggestions:
+    """
+    Collection of suggested validation rules from profile analysis.
+
+    Methods:
+        to_yaml(): Export as YAML contract
+        to_json(): Export as JSON
+        to_dict(): Export as list of rule dicts (for inline rules)
+        save(path): Save to file
+        filter(min_confidence=None, name=None): Filter suggestions
+    """
+
+    def __init__(
+        self,
+        rules: List[SuggestedRule],
+        source: str = "unknown",
+    ):
+        self._rules = rules
+        self.source = source
+
+    def __repr__(self) -> str:
+        return f"Suggestions({len(self._rules)} rules from {self.source})"
+
+    def __len__(self) -> int:
+        return len(self._rules)
+
+    def __iter__(self) -> Iterator[SuggestedRule]:
+        return iter(self._rules)
+
+    def __getitem__(self, index: int) -> SuggestedRule:
+        return self._rules[index]
+
+    def filter(
+        self,
+        min_confidence: Optional[float] = None,
+        name: Optional[str] = None,
+    ) -> "Suggestions":
+        """
+        Filter suggestions by criteria.
+
+        Args:
+            min_confidence: Minimum confidence score (0.0-1.0)
+            name: Filter by rule name
+
+        Returns:
+            New Suggestions with filtered rules
+        """
+        filtered = self._rules
+
+        if min_confidence is not None:
+            filtered = [r for r in filtered if r.confidence >= min_confidence]
+
+        if name is not None:
+            filtered = [r for r in filtered if r.name == name]
+
+        return Suggestions(filtered, self.source)
+
+    def to_dict(self) -> List[Dict[str, Any]]:
+        """Convert to list of rule dicts (usable with kontra.validate(rules=...))."""
+        return [r.to_dict() for r in self._rules]
+
+    def to_json(self, indent: Optional[int] = None) -> str:
+        """Convert to JSON string."""
+        return json.dumps(self.to_dict(), indent=indent, default=str)
+
+    def to_yaml(self, contract_name: str = "suggested_contract") -> str:
+        """
+        Convert to YAML contract format.
+
+        Args:
+            contract_name: Name for the contract
+
+        Returns:
+            YAML string
+        """
+        contract = {
+            "name": contract_name,
+            "dataset": self.source,
+            "rules": self.to_dict(),
+        }
+        return yaml.dump(contract, default_flow_style=False, sort_keys=False)
+
+    def save(self, path: Union[str, Path]) -> None:
+        """
+        Save suggestions to file.
+
+        Args:
+            path: Output path (YAML format)
+        """
+        path = Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.write_text(self.to_yaml(contract_name=path.stem))
+
+    @classmethod
+    def from_profile(
+        cls,
+        profile: "DatasetProfile",
+        min_confidence: float = 0.5,
+    ) -> "Suggestions":
+        """
+        Generate rule suggestions from a profile.
+
+        This is a basic implementation. More sophisticated analysis
+        could be added based on profile depth/preset.
+        """
+        rules: List[SuggestedRule] = []
+
+        for col in profile.columns:
+            # not_null suggestion
+            if col.null_rate == 0:
+                rules.append(SuggestedRule(
+                    name="not_null",
+                    params={"column": col.name},
+                    confidence=1.0,
+                    reason=f"Column {col.name} has no nulls",
+                ))
+            elif col.null_rate < 0.01:  # < 1% nulls
+                rules.append(SuggestedRule(
+                    name="not_null",
+                    params={"column": col.name},
+                    confidence=0.8,
+                    reason=f"Column {col.name} has very few nulls ({col.null_rate:.1%})",
+                ))
+
+            # unique suggestion
+            if col.uniqueness_ratio == 1.0 and col.distinct_count > 1:
+                rules.append(SuggestedRule(
+                    name="unique",
+                    params={"column": col.name},
+                    confidence=1.0,
+                    reason=f"Column {col.name} has all unique values",
+                ))
+            elif col.uniqueness_ratio > 0.99:
+                rules.append(SuggestedRule(
+                    name="unique",
+                    params={"column": col.name},
+                    confidence=0.7,
+                    reason=f"Column {col.name} is nearly unique ({col.uniqueness_ratio:.1%})",
+                ))
+
+            # dtype suggestion
+            rules.append(SuggestedRule(
+                name="dtype",
+                params={"column": col.name, "type": col.dtype},
+                confidence=1.0,
+                reason=f"Column {col.name} is {col.dtype}",
+            ))
+
+            # allowed_values for low cardinality
+            if col.is_low_cardinality and col.values:
+                rules.append(SuggestedRule(
+                    name="allowed_values",
+                    params={"column": col.name, "values": col.values},
+                    confidence=0.9,
+                    reason=f"Column {col.name} has {len(col.values)} distinct values",
+                ))
+
+            # range for numeric
+            if col.numeric and col.numeric.min is not None and col.numeric.max is not None:
+                rules.append(SuggestedRule(
+                    name="range",
+                    params={
+                        "column": col.name,
+                        "min": col.numeric.min,
+                        "max": col.numeric.max,
+                    },
+                    confidence=0.7,
+                    reason=f"Column {col.name} ranges from {col.numeric.min} to {col.numeric.max}",
+                ))
+
+        # min_rows suggestion
+        if profile.row_count > 0:
+            # Suggest minimum as 80% of current count (or 1 if small dataset)
+            min_rows = max(1, int(profile.row_count * 0.8))
+            rules.append(SuggestedRule(
+                name="min_rows",
+                params={"threshold": min_rows},
+                confidence=0.6,
+                reason=f"Dataset has {profile.row_count:,} rows",
+            ))
+
+        # Filter by confidence
+        filtered = [r for r in rules if r.confidence >= min_confidence]
+
+        return cls(filtered, source=profile.source_uri)