kontra-0.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. kontra/__init__.py +1871 -0
  2. kontra/api/__init__.py +22 -0
  3. kontra/api/compare.py +340 -0
  4. kontra/api/decorators.py +153 -0
  5. kontra/api/results.py +2121 -0
  6. kontra/api/rules.py +681 -0
  7. kontra/cli/__init__.py +0 -0
  8. kontra/cli/commands/__init__.py +1 -0
  9. kontra/cli/commands/config.py +153 -0
  10. kontra/cli/commands/diff.py +450 -0
  11. kontra/cli/commands/history.py +196 -0
  12. kontra/cli/commands/profile.py +289 -0
  13. kontra/cli/commands/validate.py +468 -0
  14. kontra/cli/constants.py +6 -0
  15. kontra/cli/main.py +48 -0
  16. kontra/cli/renderers.py +304 -0
  17. kontra/cli/utils.py +28 -0
  18. kontra/config/__init__.py +34 -0
  19. kontra/config/loader.py +127 -0
  20. kontra/config/models.py +49 -0
  21. kontra/config/settings.py +797 -0
  22. kontra/connectors/__init__.py +0 -0
  23. kontra/connectors/db_utils.py +251 -0
  24. kontra/connectors/detection.py +323 -0
  25. kontra/connectors/handle.py +368 -0
  26. kontra/connectors/postgres.py +127 -0
  27. kontra/connectors/sqlserver.py +226 -0
  28. kontra/engine/__init__.py +0 -0
  29. kontra/engine/backends/duckdb_session.py +227 -0
  30. kontra/engine/backends/duckdb_utils.py +18 -0
  31. kontra/engine/backends/polars_backend.py +47 -0
  32. kontra/engine/engine.py +1205 -0
  33. kontra/engine/executors/__init__.py +15 -0
  34. kontra/engine/executors/base.py +50 -0
  35. kontra/engine/executors/database_base.py +528 -0
  36. kontra/engine/executors/duckdb_sql.py +607 -0
  37. kontra/engine/executors/postgres_sql.py +162 -0
  38. kontra/engine/executors/registry.py +69 -0
  39. kontra/engine/executors/sqlserver_sql.py +163 -0
  40. kontra/engine/materializers/__init__.py +14 -0
  41. kontra/engine/materializers/base.py +42 -0
  42. kontra/engine/materializers/duckdb.py +110 -0
  43. kontra/engine/materializers/factory.py +22 -0
  44. kontra/engine/materializers/polars_connector.py +131 -0
  45. kontra/engine/materializers/postgres.py +157 -0
  46. kontra/engine/materializers/registry.py +138 -0
  47. kontra/engine/materializers/sqlserver.py +160 -0
  48. kontra/engine/result.py +15 -0
  49. kontra/engine/sql_utils.py +611 -0
  50. kontra/engine/sql_validator.py +609 -0
  51. kontra/engine/stats.py +194 -0
  52. kontra/engine/types.py +138 -0
  53. kontra/errors.py +533 -0
  54. kontra/logging.py +85 -0
  55. kontra/preplan/__init__.py +5 -0
  56. kontra/preplan/planner.py +253 -0
  57. kontra/preplan/postgres.py +179 -0
  58. kontra/preplan/sqlserver.py +191 -0
  59. kontra/preplan/types.py +24 -0
  60. kontra/probes/__init__.py +20 -0
  61. kontra/probes/compare.py +400 -0
  62. kontra/probes/relationship.py +283 -0
  63. kontra/reporters/__init__.py +0 -0
  64. kontra/reporters/json_reporter.py +190 -0
  65. kontra/reporters/rich_reporter.py +11 -0
  66. kontra/rules/__init__.py +35 -0
  67. kontra/rules/base.py +186 -0
  68. kontra/rules/builtin/__init__.py +40 -0
  69. kontra/rules/builtin/allowed_values.py +156 -0
  70. kontra/rules/builtin/compare.py +188 -0
  71. kontra/rules/builtin/conditional_not_null.py +213 -0
  72. kontra/rules/builtin/conditional_range.py +310 -0
  73. kontra/rules/builtin/contains.py +138 -0
  74. kontra/rules/builtin/custom_sql_check.py +182 -0
  75. kontra/rules/builtin/disallowed_values.py +140 -0
  76. kontra/rules/builtin/dtype.py +203 -0
  77. kontra/rules/builtin/ends_with.py +129 -0
  78. kontra/rules/builtin/freshness.py +240 -0
  79. kontra/rules/builtin/length.py +193 -0
  80. kontra/rules/builtin/max_rows.py +35 -0
  81. kontra/rules/builtin/min_rows.py +46 -0
  82. kontra/rules/builtin/not_null.py +121 -0
  83. kontra/rules/builtin/range.py +222 -0
  84. kontra/rules/builtin/regex.py +143 -0
  85. kontra/rules/builtin/starts_with.py +129 -0
  86. kontra/rules/builtin/unique.py +124 -0
  87. kontra/rules/condition_parser.py +203 -0
  88. kontra/rules/execution_plan.py +455 -0
  89. kontra/rules/factory.py +103 -0
  90. kontra/rules/predicates.py +25 -0
  91. kontra/rules/registry.py +24 -0
  92. kontra/rules/static_predicates.py +120 -0
  93. kontra/scout/__init__.py +9 -0
  94. kontra/scout/backends/__init__.py +17 -0
  95. kontra/scout/backends/base.py +111 -0
  96. kontra/scout/backends/duckdb_backend.py +359 -0
  97. kontra/scout/backends/postgres_backend.py +519 -0
  98. kontra/scout/backends/sqlserver_backend.py +577 -0
  99. kontra/scout/dtype_mapping.py +150 -0
  100. kontra/scout/patterns.py +69 -0
  101. kontra/scout/profiler.py +801 -0
  102. kontra/scout/reporters/__init__.py +39 -0
  103. kontra/scout/reporters/json_reporter.py +165 -0
  104. kontra/scout/reporters/markdown_reporter.py +152 -0
  105. kontra/scout/reporters/rich_reporter.py +144 -0
  106. kontra/scout/store.py +208 -0
  107. kontra/scout/suggest.py +200 -0
  108. kontra/scout/types.py +652 -0
  109. kontra/state/__init__.py +29 -0
  110. kontra/state/backends/__init__.py +79 -0
  111. kontra/state/backends/base.py +348 -0
  112. kontra/state/backends/local.py +480 -0
  113. kontra/state/backends/postgres.py +1010 -0
  114. kontra/state/backends/s3.py +543 -0
  115. kontra/state/backends/sqlserver.py +969 -0
  116. kontra/state/fingerprint.py +166 -0
  117. kontra/state/types.py +1061 -0
  118. kontra/version.py +1 -0
  119. kontra-0.5.2.dist-info/METADATA +122 -0
  120. kontra-0.5.2.dist-info/RECORD +124 -0
  121. kontra-0.5.2.dist-info/WHEEL +5 -0
  122. kontra-0.5.2.dist-info/entry_points.txt +2 -0
  123. kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
  124. kontra-0.5.2.dist-info/top_level.txt +1 -0
kontra/api/results.py ADDED
@@ -0,0 +1,2121 @@
1
+ # src/kontra/api/results.py
2
+ """
3
+ Public API result types for Kontra.
4
+
5
+ These classes wrap the internal state/result types with a cleaner interface
6
+ for the public Python API.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from dataclasses import dataclass, field
13
+ from datetime import datetime
14
+ from pathlib import Path
15
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
16
+
17
+ import yaml
18
+
19
+
20
+ # --- Unique rule sampling helpers (shared by multiple methods) ---
21
+
22
+
23
+ def _is_unique_rule(rule: Any) -> bool:
24
+ """Check if a rule is a unique rule."""
25
+ return getattr(rule, "name", None) == "unique"
26
+
27
+
28
+ def _filter_samples_polars(
29
+ source: Any, # pl.DataFrame or pl.LazyFrame
30
+ rule: Any,
31
+ predicate: Any,
32
+ n: int,
33
+ ) -> Any: # pl.DataFrame
34
+ """
35
+ Filter samples with special handling for unique rule.
36
+
37
+ Works with both DataFrame and LazyFrame sources. Adds _row_index,
38
+ and for unique rules adds _duplicate_count sorted by worst offenders.
39
+
40
+ Args:
41
+ source: Polars DataFrame or LazyFrame
42
+ rule: Rule object (used to detect unique rule)
43
+ predicate: Polars expression for filtering
44
+ n: Maximum rows to return
45
+
46
+ Returns:
47
+ Polars DataFrame with filtered samples
48
+ """
49
+ import polars as pl
50
+
51
+ # Convert DataFrame to LazyFrame if needed
52
+ if isinstance(source, pl.DataFrame):
53
+ lf = source.lazy()
54
+ else:
55
+ lf = source
56
+
57
+ # Add row index
58
+ lf = lf.with_row_index("_row_index")
59
+
60
+ # Special case: unique rule - add duplicate count, sort by worst offenders
61
+ if _is_unique_rule(rule):
62
+ column = rule.params.get("column")
63
+ return (
64
+ lf.with_columns(
65
+ pl.col(column).count().over(column).alias("_duplicate_count")
66
+ )
67
+ .filter(predicate)
68
+ .sort("_duplicate_count", descending=True)
69
+ .head(n)
70
+ .collect()
71
+ )
72
+ else:
73
+ return lf.filter(predicate).head(n).collect()
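For illustration, a minimal Polars sketch of the pipeline this helper builds for a unique rule, on toy data and with a hypothetical failure predicate standing in for the rule's own sample predicate:

import polars as pl

df = pl.DataFrame({
    "user_id": [1, 1, 2, 3, 3, 3],
    "email":   ["a", "b", "c", "d", "e", "f"],
})

samples = (
    df.lazy()
    .with_row_index("_row_index")
    .with_columns(pl.col("user_id").count().over("user_id").alias("_duplicate_count"))
    .filter(pl.col("user_id").is_duplicated())   # hypothetical predicate for unique-rule failures
    .sort("_duplicate_count", descending=True)
    .head(3)
    .collect()
)
# -> the three user_id=3 rows, each carrying _row_index and _duplicate_count=3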
74
+
75
+
76
+ def _build_unique_sample_query_sql(
77
+ table: str,
78
+ column: str,
79
+ n: int,
80
+ dialect: str,
81
+ ) -> str:
82
+ """
83
+ Build SQL query for sampling unique rule violations.
84
+
85
+ Returns query that finds duplicate values, orders by worst offenders,
86
+ and includes _duplicate_count and _row_index.
87
+
88
+ Args:
89
+ table: Fully qualified table name
90
+ column: Column being checked for uniqueness
91
+ n: Maximum rows to return
92
+ dialect: SQL dialect ("postgres", "mssql")
93
+
94
+ Returns:
95
+ SQL query string
96
+ """
97
+ col = f'"{column}"'
98
+
99
+ if dialect == "mssql":
100
+ return f"""
101
+ SELECT t.*, dup._duplicate_count,
102
+ ROW_NUMBER() OVER (ORDER BY dup._duplicate_count DESC) - 1 AS _row_index
103
+ FROM {table} t
104
+ JOIN (
105
+ SELECT {col}, COUNT(*) as _duplicate_count
106
+ FROM {table}
107
+ GROUP BY {col}
108
+ HAVING COUNT(*) > 1
109
+ ) dup ON t.{col} = dup.{col}
110
+ ORDER BY dup._duplicate_count DESC
111
+ OFFSET 0 ROWS FETCH FIRST {n} ROWS ONLY
112
+ """
113
+ else:
114
+ return f"""
115
+ SELECT t.*, dup._duplicate_count,
116
+ ROW_NUMBER() OVER (ORDER BY dup._duplicate_count DESC) - 1 AS _row_index
117
+ FROM {table} t
118
+ JOIN (
119
+ SELECT {col}, COUNT(*) as _duplicate_count
120
+ FROM {table}
121
+ GROUP BY {col}
122
+ HAVING COUNT(*) > 1
123
+ ) dup ON t.{col} = dup.{col}
124
+ ORDER BY dup._duplicate_count DESC
125
+ LIMIT {n}
126
+ """
127
+
128
+
129
+ # --- End unique rule helpers ---
130
+
131
+
132
+ class FailureSamples:
133
+ """
134
+ Collection of sample rows that failed a validation rule.
135
+
136
+ This class wraps a list of failing rows with serialization methods.
137
+ It's iterable and indexable like a list.
138
+
139
+ Properties:
140
+ rule_id: The rule ID these samples are from
141
+ count: Number of samples in this collection
142
+
143
+ Methods:
144
+ to_dict(): Convert to list of dicts
145
+ to_json(): Convert to JSON string
146
+ to_llm(): Token-optimized format for LLM context
147
+ """
148
+
149
+ def __init__(self, samples: List[Dict[str, Any]], rule_id: str):
150
+ self._samples = samples
151
+ self.rule_id = rule_id
152
+
153
+ def __repr__(self) -> str:
154
+ return f"FailureSamples({self.rule_id}, {len(self._samples)} rows)"
155
+
156
+ def __len__(self) -> int:
157
+ return len(self._samples)
158
+
159
+ def __iter__(self) -> Iterator[Dict[str, Any]]:
160
+ return iter(self._samples)
161
+
162
+ def __getitem__(self, index: int) -> Dict[str, Any]:
163
+ return self._samples[index]
164
+
165
+ def __bool__(self) -> bool:
166
+ return len(self._samples) > 0
167
+
168
+ @property
169
+ def count(self) -> int:
170
+ """Number of sample rows."""
171
+ return len(self._samples)
172
+
173
+ def to_dict(self) -> List[Dict[str, Any]]:
174
+ """Convert to list of dicts."""
175
+ return self._samples
176
+
177
+ def to_json(self, indent: Optional[int] = None) -> str:
178
+ """Convert to JSON string."""
179
+ return json.dumps(self._samples, indent=indent, default=str)
180
+
181
+ def to_llm(self) -> str:
182
+ """
183
+ Token-optimized format for LLM context.
184
+
185
+ Example output:
186
+ SAMPLES: COL:email:not_null (2 rows)
187
+ [0] row=1: id=2, email=None, status=active
188
+ [1] row=3: id=4, email=None, status=active
189
+
190
+ For unique rule:
191
+ SAMPLES: COL:user_id:unique (2 rows)
192
+ [0] row=5, dupes=3: user_id=123, name=Alice
193
+ [1] row=8, dupes=3: user_id=123, name=Bob
194
+ """
195
+ if not self._samples:
196
+ return f"SAMPLES: {self.rule_id} (0 rows)"
197
+
198
+ lines = [f"SAMPLES: {self.rule_id} ({len(self._samples)} rows)"]
199
+
200
+ for i, row in enumerate(self._samples[:10]): # Limit to 10 for token efficiency
201
+ # Extract special columns for prefix
202
+ row_idx = row.get("_row_index")
203
+ dup_count = row.get("_duplicate_count")
204
+
205
+ # Build prefix with metadata
206
+ prefix_parts = []
207
+ if row_idx is not None:
208
+ prefix_parts.append(f"row={row_idx}")
209
+ if dup_count is not None:
210
+ prefix_parts.append(f"dupes={dup_count}")
211
+ prefix = ", ".join(prefix_parts) + ": " if prefix_parts else ""
212
+
213
+ # Format remaining columns as compact key=value pairs
214
+ parts = []
215
+ for k, v in row.items():
216
+ # Skip special columns (already in prefix)
217
+ if k in ("_row_index", "_duplicate_count"):
218
+ continue
219
+ if v is None:
220
+ parts.append(f"{k}=None")
221
+ elif isinstance(v, str) and len(v) > 20:
222
+ parts.append(f"{k}={v[:20]}...")
223
+ else:
224
+ parts.append(f"{k}={v}")
225
+ lines.append(f"[{i}] {prefix}" + ", ".join(parts))
226
+
227
+ if len(self._samples) > 10:
228
+ lines.append(f"... +{len(self._samples) - 10} more rows")
229
+
230
+ return "\n".join(lines)
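A small usage sketch; instances normally come back from ValidationResult.sample_failures() rather than being built by hand, but constructing one directly shows the serialization formats:

from kontra.api.results import FailureSamples

rows = [
    {"_row_index": 1, "id": 2, "email": None, "status": "active"},
    {"_row_index": 3, "id": 4, "email": None, "status": "active"},
]
fs = FailureSamples(rows, rule_id="COL:email:not_null")

print(fs.count, bool(fs))     # 2 True
print(fs.to_json(indent=2))   # JSON array of the two row dicts
print(fs.to_llm())
# SAMPLES: COL:email:not_null (2 rows)
# [0] row=1: id=2, email=None, status=active
# [1] row=3: id=4, email=None, status=active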
231
+
232
+
233
+ class SampleReason:
234
+ """Constants for why samples may be unavailable."""
235
+
236
+ UNAVAILABLE_METADATA = "unavailable_from_metadata" # Preplan tier - knows existence, not location
237
+ UNAVAILABLE_PASSED = "rule_passed" # No failures to sample
238
+ UNAVAILABLE_UNSUPPORTED = "rule_unsupported" # dtype, min_rows, etc. - no row-level samples
239
+ TRUNCATED_BUDGET = "budget_exhausted" # Global budget hit
240
+ TRUNCATED_LIMIT = "per_rule_limit" # Per-rule cap hit
241
+
242
+
243
+ @dataclass
244
+ class RuleResult:
245
+ """
246
+ Result for a single validation rule.
247
+
248
+ Properties:
249
+ rule_id: Unique identifier (e.g., "COL:user_id:not_null")
250
+ name: Rule type name (e.g., "not_null")
251
+ passed: Whether the rule passed
252
+ failed_count: Number of failing rows
253
+ violation_rate: Fraction of rows that failed (0.0-1.0), or None if passed
254
+ message: Human-readable result message
255
+ severity: "blocking" | "warning" | "info"
256
+ source: Measurement source ("metadata", "sql", "polars")
257
+ column: Column name if applicable
258
+ context: Consumer-defined metadata (owner, tags, fix_hint, etc.)
259
+ annotations: List of annotations on this rule (opt-in, loaded via get_run_with_annotations)
260
+ severity_weight: User-defined numeric weight (None if unconfigured)
261
+
262
+ Sampling properties:
263
+ samples: List of sample failing rows, or None if unavailable
264
+ samples_source: Where samples came from ("sql", "polars"), or None
265
+ samples_reason: Why samples unavailable (see SampleReason)
266
+ samples_truncated: True if more samples exist but were cut off
267
+ """
268
+
269
+ rule_id: str
270
+ name: str
271
+ passed: bool
272
+ failed_count: int
273
+ message: str
274
+ severity: str = "blocking"
275
+ source: str = "polars"
276
+ column: Optional[str] = None
277
+ details: Optional[Dict[str, Any]] = None
278
+ context: Optional[Dict[str, Any]] = None
279
+ failure_mode: Optional[str] = None # Semantic failure type (e.g., "config_error", "null_values")
280
+
281
+ # Sampling fields (eager sampling)
282
+ samples: Optional[List[Dict[str, Any]]] = None
283
+ samples_source: Optional[str] = None
284
+ samples_reason: Optional[str] = None
285
+ samples_truncated: bool = False
286
+
287
+ # Annotations (opt-in, loaded via get_run_with_annotations)
288
+ annotations: Optional[List[Dict[str, Any]]] = None
289
+
290
+ # LLM juice: user-defined severity weight (None if unconfigured)
291
+ severity_weight: Optional[float] = None
292
+
293
+ # For violation_rate computation (populated during result creation)
294
+ _total_rows: Optional[int] = field(default=None, repr=False)
295
+
296
+ @property
297
+ def violation_rate(self) -> Optional[float]:
298
+ """
299
+ Fraction of rows that failed this rule.
300
+
301
+ Returns:
302
+ Float 0.0-1.0, or None if:
303
+ - Rule passed (failed_count == 0)
304
+ - total_rows is 0 or unknown
305
+
306
+ Example:
307
+ for rule in result.rules:
308
+ if rule.violation_rate:
309
+ print(f"{rule.rule_id}: {rule.violation_rate:.2%} of rows failed")
310
+ """
311
+ if self.passed or self.failed_count == 0:
312
+ return None
313
+ if self._total_rows is None or self._total_rows == 0:
314
+ return None
315
+ return self.failed_count / self._total_rows
316
+
317
+ def __repr__(self) -> str:
318
+ status = "PASS" if self.passed else "FAIL"
319
+ if self.failed_count > 0:
320
+ return f"RuleResult({self.rule_id}) {status} - {self.failed_count:,} failures"
321
+ return f"RuleResult({self.rule_id}) {status}"
322
+
323
+ @classmethod
324
+ def from_dict(cls, d: Dict[str, Any]) -> "RuleResult":
325
+ """Create from engine result dict."""
326
+ rule_id = d.get("rule_id", "")
327
+
328
+ # Extract column from rule_id if present
329
+ column = None
330
+ if rule_id.startswith("COL:"):
331
+ parts = rule_id.split(":")
332
+ if len(parts) >= 2:
333
+ column = parts[1]
334
+
335
+ # Extract rule name
336
+ name = d.get("rule_name", d.get("name", ""))
337
+ if not name and ":" in rule_id:
338
+ name = rule_id.split(":")[-1]
339
+
340
+ return cls(
341
+ rule_id=rule_id,
342
+ name=name,
343
+ passed=d.get("passed", False),
344
+ failed_count=d.get("failed_count", 0),
345
+ message=d.get("message", ""),
346
+ severity=d.get("severity", "blocking"),
347
+ source=d.get("execution_source", d.get("source", "polars")),
348
+ column=column,
349
+ details=d.get("details"),
350
+ context=d.get("context"),
351
+ failure_mode=d.get("failure_mode"),
352
+ # Sampling fields
353
+ samples=d.get("samples"),
354
+ samples_source=d.get("samples_source"),
355
+ samples_reason=d.get("samples_reason"),
356
+ samples_truncated=d.get("samples_truncated", False),
357
+ # LLM juice
358
+ severity_weight=d.get("severity_weight"),
359
+ )
360
+
361
+ def to_dict(self) -> Dict[str, Any]:
362
+ """Convert to dictionary."""
363
+ d = {
364
+ "rule_id": self.rule_id,
365
+ "name": self.name,
366
+ "passed": self.passed,
367
+ "failed_count": self.failed_count,
368
+ "message": self.message,
369
+ "severity": self.severity,
370
+ "source": self.source,
371
+ }
372
+ # Include violation_rate if available
373
+ if self.violation_rate is not None:
374
+ d["violation_rate"] = self.violation_rate
375
+ if self.column:
376
+ d["column"] = self.column
377
+ if self.details:
378
+ d["details"] = self.details
379
+ if self.context:
380
+ d["context"] = self.context
381
+ if self.failure_mode:
382
+ d["failure_mode"] = self.failure_mode
383
+
384
+ # Sampling fields - always include for clarity
385
+ d["samples"] = self.samples # None = unavailable, [] = none found
386
+ if self.samples_source:
387
+ d["samples_source"] = self.samples_source
388
+ if self.samples_reason:
389
+ d["samples_reason"] = self.samples_reason
390
+ if self.samples_truncated:
391
+ d["samples_truncated"] = self.samples_truncated
392
+
393
+ # Annotations (opt-in)
394
+ if self.annotations is not None:
395
+ d["annotations"] = self.annotations
396
+
397
+ # LLM juice (only include if configured)
398
+ if self.severity_weight is not None:
399
+ d["severity_weight"] = self.severity_weight
400
+
401
+ return d
402
+
403
+ def to_llm(self) -> str:
404
+ """Token-optimized format for LLM context."""
405
+ status = "PASS" if self.passed else "FAIL"
406
+ parts = [f"{self.rule_id}: {status}"]
407
+
408
+ if self.failed_count > 0:
409
+ parts.append(f"({self.failed_count:,} failures)")
410
+
411
+ # Include violation_rate for failed rules (LLM juice)
412
+ if self.violation_rate is not None:
413
+ parts.append(f"[{self.violation_rate:.1%}]")
414
+
415
+ # Include severity weight if configured (LLM juice)
416
+ if self.severity_weight is not None:
417
+ parts.append(f"[w={self.severity_weight}]")
418
+
419
+ # Add samples if available
420
+ if self.samples:
421
+ parts.append(f"\n Samples ({len(self.samples)}):")
422
+ for i, row in enumerate(self.samples[:5]):
423
+ # Extract metadata
424
+ row_idx = row.get("_row_index")
425
+ dup_count = row.get("_duplicate_count")
426
+ prefix_parts = []
427
+ if row_idx is not None:
428
+ prefix_parts.append(f"row={row_idx}")
429
+ if dup_count is not None:
430
+ prefix_parts.append(f"dupes={dup_count}")
431
+ prefix = ", ".join(prefix_parts) + ": " if prefix_parts else ""
432
+
433
+ # Format data columns
434
+ data_parts = []
435
+ for k, v in row.items():
436
+ if k in ("_row_index", "_duplicate_count"):
437
+ continue
438
+ if v is None:
439
+ data_parts.append(f"{k}=None")
440
+ elif isinstance(v, str) and len(v) > 15:
441
+ data_parts.append(f"{k}={v[:15]}...")
442
+ else:
443
+ data_parts.append(f"{k}={v}")
444
+ parts.append(f" [{i}] {prefix}" + ", ".join(data_parts[:5]))
445
+ if len(self.samples) > 5:
446
+ parts.append(f" ... +{len(self.samples) - 5} more")
447
+ elif self.samples_reason:
448
+ parts.append(f"\n Samples: {self.samples_reason}")
449
+
450
+ # Add annotations if available
451
+ if self.annotations:
452
+ parts.append(f"\n Annotations ({len(self.annotations)}):")
453
+ for ann in self.annotations[:3]:
454
+ ann_type = ann.get("annotation_type", "note")
455
+ actor = ann.get("actor_id", "unknown")
456
+ summary = ann.get("summary", "")[:40]
457
+ if len(ann.get("summary", "")) > 40:
458
+ summary += "..."
459
+ parts.append(f' [{ann_type}] by {actor}: "{summary}"')
460
+ if len(self.annotations) > 3:
461
+ parts.append(f" ... +{len(self.annotations) - 3} more")
462
+
463
+ return " ".join(parts[:2]) + "".join(parts[2:])
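A minimal sketch of building a RuleResult from an engine-style dict; in normal use from_engine_result() does this and also sets _total_rows so violation_rate can be computed:

from kontra.api.results import RuleResult

rr = RuleResult.from_dict({
    "rule_id": "COL:email:not_null",
    "passed": False,
    "failed_count": 523,
    "message": "523 null values found",
    "severity": "blocking",
})
rr._total_rows = 1_000             # normally populated by from_engine_result()

print(rr)                          # RuleResult(COL:email:not_null) FAIL - 523 failures
print(rr.column, rr.name)          # email not_null  (parsed from the rule_id)
print(f"{rr.violation_rate:.1%}")  # 52.3%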
464
+
465
+
466
+ @dataclass
467
+ class ValidationResult:
468
+ """
469
+ Result of a validation run.
470
+
471
+ Properties:
472
+ passed: True if all blocking rules passed
473
+ dataset: Dataset name/path
474
+ total_rows: Number of rows in the validated dataset
475
+ total_rules: Total number of rules evaluated
476
+ passed_count: Number of rules that passed
477
+ failed_count: Number of blocking rules that failed
478
+ warning_count: Number of warning rules that failed
479
+ rules: List of RuleResult objects
480
+ blocking_failures: List of failed blocking rules
481
+ warnings: List of failed warning rules
482
+ quality_score: Deterministic score 0.0-1.0 (None if weights unconfigured)
483
+ data: The validated DataFrame (if loaded), None if preplan/pushdown handled everything
484
+ stats: Optional statistics dict
485
+ annotations: List of run-level annotations (opt-in, loaded via get_run_with_annotations)
486
+
487
+ Methods:
488
+ sample_failures(rule_id, n=5): Get sample of failing rows for a rule
489
+
490
+ Note:
491
+ Each RuleResult has a `violation_rate` property for per-rule failure rates.
492
+ """
493
+
494
+ passed: bool
495
+ dataset: str
496
+ total_rows: int
497
+ total_rules: int
498
+ passed_count: int
499
+ failed_count: int
500
+ warning_count: int
501
+ rules: List[RuleResult]
502
+ stats: Optional[Dict[str, Any]] = None
503
+ annotations: Optional[List[Dict[str, Any]]] = None # Run-level annotations (opt-in)
504
+ _raw: Optional[Dict[str, Any]] = field(default=None, repr=False)
505
+ # For sample_failures() - lazy evaluation
506
+ _data_source: Optional[Any] = field(default=None, repr=False)
507
+ _rule_objects: Optional[List[Any]] = field(default=None, repr=False)
508
+ # Loaded data (if Polars execution occurred)
509
+ _data: Optional[Any] = field(default=None, repr=False)
510
+
511
+ @property
512
+ def data(self) -> Optional["pl.DataFrame"]:
513
+ """
514
+ The validated DataFrame, if data was loaded.
515
+
516
+ Returns the Polars DataFrame that was validated when:
517
+ - Polars execution occurred (residual rules needed data)
518
+ - A DataFrame was passed directly to validate()
519
+
520
+ Returns None when:
521
+ - All rules were resolved by preplan/pushdown (no data loaded)
522
+ - Data source was a file path and wasn't materialized
523
+
524
+ Example:
525
+ result = kontra.validate("data.parquet", rules=[...], preplan="off", pushdown="off")
526
+ if result.passed and result.data is not None:
527
+ # Use the already-loaded data
528
+ process(result.data)
529
+ """
530
+ return self._data
531
+
532
+ def __repr__(self) -> str:
533
+ status = "PASSED" if self.passed else "FAILED"
534
+ parts = [f"ValidationResult({self.dataset}) {status}"]
535
+ parts.append(f" Total: {self.total_rules} rules | Passed: {self.passed_count} | Failed: {self.failed_count}")
536
+ if self.warning_count > 0:
537
+ parts.append(f" Warnings: {self.warning_count}")
538
+ if not self.passed:
539
+ blocking = [r.rule_id for r in self.blocking_failures[:3]]
540
+ if blocking:
541
+ parts.append(f" Blocking: {', '.join(blocking)}")
542
+ if len(self.blocking_failures) > 3:
543
+ parts.append(f" ... and {len(self.blocking_failures) - 3} more")
544
+ return "\n".join(parts)
545
+
546
+ @property
547
+ def blocking_failures(self) -> List[RuleResult]:
548
+ """Get all failed blocking rules."""
549
+ return [r for r in self.rules if not r.passed and r.severity == "blocking"]
550
+
551
+ @property
552
+ def warnings(self) -> List[RuleResult]:
553
+ """Get all failed warning rules."""
554
+ return [r for r in self.rules if not r.passed and r.severity == "warning"]
555
+
556
+ @property
557
+ def quality_score(self) -> Optional[float]:
558
+ """
559
+ Deterministic quality score derived from violation data.
560
+
561
+ Formula:
562
+ quality_score = 1.0 - weighted_violation_rate
563
+ weighted_violation_rate = Σ(failed_count * severity_weight) / (total_rows * Σ(weights))
564
+
565
+ Returns:
566
+ Float 0.0-1.0, or None if:
567
+ - severity_weights not configured
568
+ - total_rows is 0
569
+ - No rules have weights
570
+
571
+ Note:
572
+ This score is pure data - Kontra never interprets it as "good" or "bad".
573
+ Consumers/agents use it for trend reasoning.
574
+ """
575
+ # Check if any rule has a weight (weights configured)
576
+ rules_with_weights = [r for r in self.rules if r.severity_weight is not None]
577
+ if not rules_with_weights:
578
+ return None
579
+
580
+ # Avoid division by zero
581
+ if self.total_rows == 0:
582
+ return None
583
+
584
+ # Calculate weighted violation sum
585
+ weighted_violations = sum(
586
+ r.failed_count * r.severity_weight
587
+ for r in rules_with_weights
588
+ )
589
+
590
+ # Calculate total possible weighted violations
591
+ # (if every row failed every rule)
592
+ total_weight = sum(r.severity_weight for r in rules_with_weights)
593
+ max_weighted_violations = self.total_rows * total_weight
594
+
595
+ if max_weighted_violations == 0:
596
+ return 1.0 # No possible violations
597
+
598
+ # Quality = 1 - violation_rate
599
+ violation_rate = weighted_violations / max_weighted_violations
600
+ return max(0.0, min(1.0, 1.0 - violation_rate))
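For example, with total_rows = 1000 and three weighted rules (a blocking rule with weight 1.0 and 523 failures, a warning rule with weight 0.5 and 12 failures, and a passing blocking rule with weight 1.0), weighted_violations = 523*1.0 + 12*0.5 + 0*1.0 = 529, max_weighted_violations = 1000 * (1.0 + 0.5 + 1.0) = 2500, and quality_score = 1 - 529/2500 = 0.7884.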
601
+
602
+ @classmethod
603
+ def from_engine_result(
604
+ cls,
605
+ result: Dict[str, Any],
606
+ dataset: str = "unknown",
607
+ data_source: Optional[Any] = None,
608
+ rule_objects: Optional[List[Any]] = None,
609
+ sample: int = 5,
610
+ sample_budget: int = 50,
611
+ sample_columns: Optional[Union[List[str], str]] = None,
612
+ severity_weights: Optional[Dict[str, float]] = None,
613
+ data: Optional[Any] = None,
614
+ ) -> "ValidationResult":
615
+ """Create from ValidationEngine.run() result dict.
616
+
617
+ Args:
618
+ result: Engine result dict
619
+ dataset: Dataset name (fallback)
620
+ data_source: Original data source for lazy sample_failures()
621
+ rule_objects: Rule objects for sample_failures() predicates
622
+ sample: Per-rule sample cap (0 to disable)
623
+ sample_budget: Global sample cap across all rules
624
+ sample_columns: Columns to include in samples (None=all, list=specific, "relevant"=rule columns)
625
+ severity_weights: User-defined severity weights from config (None if unconfigured)
626
+ data: Loaded DataFrame (if Polars execution occurred)
627
+ """
628
+ summary = result.get("summary", {})
629
+ results_list = result.get("results", [])
630
+
631
+ # Convert raw results to RuleResult objects
632
+ rules = [RuleResult.from_dict(r) for r in results_list]
633
+
634
+ # Populate context from rule objects if available
635
+ if rule_objects is not None:
636
+ context_map = {
637
+ getattr(r, "rule_id", r.name): getattr(r, "context", {})
638
+ for r in rule_objects
639
+ }
640
+ for rule_result in rules:
641
+ ctx = context_map.get(rule_result.rule_id)
642
+ if ctx:
643
+ rule_result.context = ctx
644
+
645
+ # Populate severity weights from config (LLM juice)
646
+ if severity_weights is not None:
647
+ for rule_result in rules:
648
+ weight = severity_weights.get(rule_result.severity)
649
+ if weight is not None:
650
+ rule_result.severity_weight = weight
651
+
652
+ # Calculate counts
653
+ total = summary.get("total_rules", len(rules))
654
+ passed_count = summary.get("rules_passed", sum(1 for r in rules if r.passed))
655
+
656
+ # Count by severity
657
+ blocking_failed = sum(1 for r in rules if not r.passed and r.severity == "blocking")
658
+ warning_failed = sum(1 for r in rules if not r.passed and r.severity == "warning")
659
+
660
+ # Extract total_rows from summary
661
+ total_rows = summary.get("total_rows", 0)
662
+
663
+ # Populate _total_rows on each rule for violation_rate property
664
+ for rule_result in rules:
665
+ rule_result._total_rows = total_rows
666
+
667
+ # Create instance first (need it for sampling methods)
668
+ instance = cls(
669
+ passed=summary.get("passed", blocking_failed == 0),
670
+ dataset=summary.get("dataset_name", dataset),
671
+ total_rows=total_rows,
672
+ total_rules=total,
673
+ passed_count=passed_count,
674
+ failed_count=blocking_failed,
675
+ warning_count=warning_failed,
676
+ rules=rules,
677
+ stats=result.get("stats"),
678
+ _raw=result,
679
+ _data_source=data_source,
680
+ _rule_objects=rule_objects,
681
+ _data=data,
682
+ )
683
+
684
+ # Perform eager sampling if enabled
685
+ if sample > 0 and rule_objects is not None:
686
+ instance._perform_eager_sampling(sample, sample_budget, rule_objects, sample_columns)
687
+
688
+ return instance
689
+
690
+ def _perform_eager_sampling(
691
+ self,
692
+ per_rule_cap: int,
693
+ global_budget: int,
694
+ rule_objects: List[Any],
695
+ sample_columns: Optional[Union[List[str], str]] = None,
696
+ ) -> None:
697
+ """
698
+ Populate samples for each rule (eager sampling).
699
+
700
+ Uses batched SQL queries when possible (1 query for all rules).
701
+ Falls back to per-rule Polars sampling when SQL not supported.
702
+
703
+ Args:
704
+ per_rule_cap: Max samples per rule
705
+ global_budget: Total samples across all rules
706
+ rule_objects: Rule objects for predicates
707
+ sample_columns: Columns to include (None=all, list=specific, "relevant"=rule columns)
708
+ """
709
+ import polars as pl
710
+
711
+ # Build rule_id -> rule_object map
712
+ rule_map = {getattr(r, "rule_id", None): r for r in rule_objects}
713
+
714
+ # Sort rules by failed_count descending (worst offenders first)
715
+ sorted_rules = sorted(
716
+ self.rules,
717
+ key=lambda r: r.failed_count if not r.passed else 0,
718
+ reverse=True,
719
+ )
720
+
721
+ remaining_budget = global_budget
722
+
723
+ # Phase 1: Collect rules that can use SQL batching
724
+ sql_rules: List[Tuple[Any, Any, int, Optional[List[str]]]] = [] # (rule_result, rule_obj, n, columns)
725
+ polars_rules: List[Tuple[Any, Any, Any, int, Optional[List[str]]]] = [] # (rule_result, rule_obj, predicate, n, columns)
726
+
727
+ for rule_result in sorted_rules:
728
+ # Handle passing rules
729
+ if rule_result.passed:
730
+ rule_result.samples = []
731
+ rule_result.samples_reason = SampleReason.UNAVAILABLE_PASSED
732
+ continue
733
+
734
+ # Check budget
735
+ if remaining_budget <= 0:
736
+ rule_result.samples = None
737
+ rule_result.samples_reason = SampleReason.TRUNCATED_BUDGET
738
+ rule_result.samples_truncated = True
739
+ continue
740
+
741
+ # Get corresponding rule object
742
+ rule_obj = rule_map.get(rule_result.rule_id)
743
+ if rule_obj is None:
744
+ rule_result.samples = None
745
+ rule_result.samples_reason = SampleReason.UNAVAILABLE_UNSUPPORTED
746
+ continue
747
+
748
+ # Check if rule was resolved via metadata (preplan)
749
+ if rule_result.source == "metadata":
750
+ if not self._can_sample_source():
751
+ rule_result.samples = None
752
+ rule_result.samples_reason = SampleReason.UNAVAILABLE_METADATA
753
+ continue
754
+
755
+ # Calculate samples to get for this rule
756
+ n = min(per_rule_cap, remaining_budget)
757
+ remaining_budget -= n # Reserve budget
758
+
759
+ # Determine columns to include
760
+ cols_to_include = self._resolve_sample_columns(sample_columns, rule_obj)
761
+
762
+ # Check if rule supports SQL filter (for batching)
763
+ # Note: to_sql_filter can return None to indicate no SQL support for this dialect
764
+ sql_filter_available = False
765
+ if hasattr(rule_obj, "to_sql_filter"):
766
+ # Probe to see if SQL filter is actually available
767
+ sql_filter = rule_obj.to_sql_filter("duckdb")
768
+ if sql_filter is not None:
769
+ sql_rules.append((rule_result, rule_obj, n, cols_to_include))
770
+ sql_filter_available = True
771
+
772
+ # Fall back to sample_predicate/compile_predicate if SQL not available
773
+ if not sql_filter_available:
774
+ # Check sample_predicate() first (for rules like unique that have
775
+ # different counting vs sampling semantics)
776
+ pred_obj = None
777
+ if hasattr(rule_obj, "sample_predicate"):
778
+ pred_obj = rule_obj.sample_predicate()
779
+ if pred_obj is None and hasattr(rule_obj, "compile_predicate"):
780
+ pred_obj = rule_obj.compile_predicate()
781
+
782
+ if pred_obj is not None:
783
+ polars_rules.append((rule_result, rule_obj, pred_obj.expr, n, cols_to_include))
784
+ else:
785
+ rule_result.samples = None
786
+ rule_result.samples_reason = SampleReason.UNAVAILABLE_UNSUPPORTED
787
+
788
+ # Phase 2: Execute batched SQL sampling if applicable
789
+ if sql_rules:
790
+ self._execute_batched_sql_sampling(sql_rules, per_rule_cap)
791
+
792
+ # Phase 3: Execute per-rule Polars sampling for remaining rules
793
+ for rule_result, rule_obj, predicate, n, cols_to_include in polars_rules:
794
+ try:
795
+ samples, samples_source = self._collect_samples_for_rule(rule_obj, predicate, n, cols_to_include)
796
+ rule_result.samples = samples
797
+ rule_result.samples_source = samples_source
798
+
799
+ if len(samples) == per_rule_cap and rule_result.failed_count > per_rule_cap:
800
+ rule_result.samples_truncated = True
801
+ rule_result.samples_reason = SampleReason.TRUNCATED_LIMIT
802
+
803
+ except Exception as e:
804
+ rule_result.samples = None
805
+ rule_result.samples_reason = f"error: {str(e)[:50]}"
806
+
807
+ def _execute_batched_sql_sampling(
808
+ self,
809
+ sql_rules: List[Tuple[Any, Any, int, Optional[List[str]]]],
810
+ per_rule_cap: int,
811
+ ) -> None:
812
+ """
813
+ Execute batched SQL sampling for rules that support to_sql_filter().
814
+
815
+ Builds a single UNION ALL query for all rules, executes once,
816
+ and distributes results.
817
+ """
818
+ import polars as pl
819
+
820
+ source = self._data_source
821
+ if source is None:
822
+ return
823
+
824
+ # Determine dialect and source type
825
+ dialect = "duckdb"
826
+ is_parquet = False
827
+ is_database = False
828
+ db_conn = None
829
+ db_table = None
830
+ parquet_path = None
831
+
832
+ is_s3 = False
833
+
834
+ if isinstance(source, str):
835
+ lower = source.lower()
836
+ if source.startswith("s3://") or source.startswith("http://") or source.startswith("https://"):
837
+ # Remote file - use DuckDB for efficient sampling
838
+ is_parquet = True
839
+ is_s3 = True
840
+ parquet_path = source
841
+ elif lower.endswith(".parquet") or lower.endswith(".csv"):
842
+ # Local file - use Polars (more efficient for local)
843
+ is_parquet = False # Will fall through to Polars path
844
+ elif hasattr(source, "scheme"):
845
+ handle = source
846
+ if handle.scheme in ("postgres", "postgresql"):
847
+ is_database = True
848
+ dialect = "postgres"
849
+ db_conn = getattr(handle, "external_conn", None)
850
+ db_table = getattr(handle, "table_ref", None) or f'"{handle.schema}"."{handle.path}"'
851
+ elif handle.scheme == "mssql":
852
+ is_database = True
853
+ dialect = "mssql"
854
+ db_conn = getattr(handle, "external_conn", None)
855
+ db_table = getattr(handle, "table_ref", None) or f"[{handle.schema}].[{handle.path}]"
856
+ elif handle.scheme == "s3":
857
+ # S3 via handle - use DuckDB
858
+ is_parquet = True
859
+ is_s3 = True
860
+ parquet_path = handle.uri
861
+ elif handle.scheme in ("", "file"):
862
+ # Local file via handle - use Polars
863
+ is_parquet = False
864
+
865
+ # Build list of (rule_id, sql_filter, limit, columns) for batching
866
+ rules_to_sample: List[Tuple[str, str, int, Optional[List[str]]]] = []
867
+
868
+ for rule_result, rule_obj, n, cols_to_include in sql_rules:
869
+ sql_filter = rule_obj.to_sql_filter(dialect)
870
+ if sql_filter:
871
+ rules_to_sample.append((rule_result.rule_id, sql_filter, n, cols_to_include))
872
+ else:
873
+ # Rule doesn't support this dialect, mark as unsupported
874
+ rule_result.samples = None
875
+ rule_result.samples_reason = SampleReason.UNAVAILABLE_UNSUPPORTED
876
+
877
+ if not rules_to_sample:
878
+ return
879
+
880
+ # Execute batched query
881
+ try:
882
+ if is_parquet and is_s3 and parquet_path:
883
+ # S3/remote: Use DuckDB batched sampling (much faster than Polars)
884
+ results = self._batch_sample_parquet_duckdb(parquet_path, rules_to_sample)
885
+ samples_source = "sql"
886
+ elif is_database and db_conn and db_table:
887
+ results = self._batch_sample_db(db_conn, db_table, rules_to_sample, dialect)
888
+ samples_source = "sql"
889
+ else:
890
+ # Fall back to per-rule sampling
891
+ results = {}
892
+ samples_source = "polars"
893
+ for rule_result, rule_obj, n, cols_to_include in sql_rules:
894
+ if hasattr(rule_obj, "compile_predicate"):
895
+ pred_obj = rule_obj.compile_predicate()
896
+ if pred_obj:
897
+ try:
898
+ samples, src = self._collect_samples_for_rule(rule_obj, pred_obj.expr, n, cols_to_include)
899
+ results[rule_result.rule_id] = samples
900
+ samples_source = src
901
+ except Exception:
902
+ results[rule_result.rule_id] = []
903
+
904
+ # Distribute results to rule_result objects
905
+ for rule_result, rule_obj, n, cols_to_include in sql_rules:
906
+ samples = results.get(rule_result.rule_id, [])
907
+ rule_result.samples = samples
908
+ rule_result.samples_source = samples_source
909
+
910
+ if len(samples) == n and rule_result.failed_count > n:
911
+ rule_result.samples_truncated = True
912
+ rule_result.samples_reason = SampleReason.TRUNCATED_LIMIT
913
+
914
+ except Exception as e:
915
+ # Batched sampling failed, mark all rules
916
+ for rule_result, rule_obj, n, cols_to_include in sql_rules:
917
+ rule_result.samples = None
918
+ rule_result.samples_reason = f"error: {str(e)[:50]}"
919
+
920
+ def _can_sample_source(self) -> bool:
921
+ """
922
+ Check if the data source supports sampling.
923
+
924
+ File-based sources (Parquet, CSV, S3) can always be sampled.
925
+ Database sources need a live connection.
926
+
927
+ Returns:
928
+ True if sampling is possible, False otherwise.
929
+ """
930
+ import polars as pl
931
+
932
+ source = self._data_source
933
+
934
+ if source is None:
935
+ return False
936
+
937
+ # DataFrame - always sampleable
938
+ if isinstance(source, pl.DataFrame):
939
+ return True
940
+
941
+ # String path - file based, always sampleable
942
+ if isinstance(source, str):
943
+ return True
944
+
945
+ # DatasetHandle - check scheme and connection
946
+ if hasattr(source, "scheme"):
947
+ scheme = getattr(source, "scheme", None)
948
+
949
+ # File-based schemes - always sampleable
950
+ if scheme in (None, "file") or (hasattr(source, "uri") and source.uri):
951
+ uri = getattr(source, "uri", "")
952
+ if uri.lower().endswith((".parquet", ".csv")) or uri.startswith("s3://"):
953
+ return True
954
+
955
+ # BYOC or database with connection - check if connection exists
956
+ if hasattr(source, "external_conn") and source.external_conn is not None:
957
+ return True
958
+
959
+ # Database without connection - can't sample
960
+ if scheme in ("postgres", "postgresql", "mssql"):
961
+ return False
962
+
963
+ return True # Default to sampleable
964
+
965
+ def _resolve_sample_columns(
966
+ self,
967
+ sample_columns: Optional[Union[List[str], str]],
968
+ rule_obj: Any,
969
+ ) -> Optional[List[str]]:
970
+ """
971
+ Resolve sample_columns to a list of column names.
972
+
973
+ Args:
974
+ sample_columns: None (all), list of names, or "relevant"
975
+ rule_obj: Rule object for "relevant" mode
976
+
977
+ Returns:
978
+ List of column names to include, or None for all columns
979
+ """
980
+ if sample_columns is None:
981
+ return None
982
+
983
+ if isinstance(sample_columns, list):
984
+ return sample_columns
985
+
986
+ if sample_columns == "relevant":
987
+ # Get columns from rule's required_columns() if available
988
+ cols = set()
989
+ if hasattr(rule_obj, "required_columns"):
990
+ cols.update(rule_obj.required_columns())
991
+
992
+ # Also check params for column names (required_columns() may be incomplete)
993
+ if hasattr(rule_obj, "params"):
994
+ params = rule_obj.params
995
+ if "column" in params:
996
+ cols.add(params["column"])
997
+ if "left" in params:
998
+ cols.add(params["left"])
999
+ if "right" in params:
1000
+ cols.add(params["right"])
1001
+ if "when_column" in params:
1002
+ cols.add(params["when_column"])
1003
+
1004
+ return list(cols) if cols else None
1005
+
1006
+ # Unknown value - return all columns
1007
+ return None
1008
+
1009
+ def _collect_samples_for_rule(
1010
+ self,
1011
+ rule_obj: Any,
1012
+ predicate: Any,
1013
+ n: int,
1014
+ columns: Optional[List[str]] = None,
1015
+ ) -> Tuple[List[Dict[str, Any]], str]:
1016
+ """
1017
+ Collect sample rows for a single rule.
1018
+
1019
+ Uses the existing sampling infrastructure (SQL pushdown, Parquet predicate, etc.)
1020
+
1021
+ Args:
1022
+ rule_obj: Rule object
1023
+ predicate: Polars expression for filtering
1024
+ n: Number of samples to collect
1025
+ columns: Columns to include (None = all)
1026
+
1027
+ Returns:
1028
+ Tuple of (samples list, source string "sql" or "polars")
1029
+ """
1030
+ import polars as pl
1031
+
1032
+ source = self._data_source
1033
+
1034
+ if source is None:
1035
+ return [], "polars"
1036
+
1037
+ # Reuse existing loading/filtering logic
1038
+ df, load_source = self._load_data_for_sampling(rule_obj, n)
1039
+
1040
+ # If data was already filtered by SQL/DuckDB, just apply column projection
1041
+ if load_source == "sql":
1042
+ result_df = df.head(n)
1043
+ return self._apply_column_projection(result_df, columns), "sql"
1044
+
1045
+ # For Polars path, filter with predicate (unique rule handled by helper)
1046
+ result_df = _filter_samples_polars(df, rule_obj, predicate, n)
1047
+
1048
+ # For unique rule, always include _duplicate_count in projection
1049
+ if _is_unique_rule(rule_obj) and columns is not None:
1050
+ columns = list(columns) + ["_duplicate_count"]
1051
+
1052
+ return self._apply_column_projection(result_df, columns), "polars"
1053
+
1054
+ def _apply_column_projection(
1055
+ self,
1056
+ df: Any,
1057
+ columns: Optional[List[str]],
1058
+ ) -> List[Dict[str, Any]]:
1059
+ """
1060
+ Apply column projection to a DataFrame before converting to dicts.
1061
+
1062
+ Always includes _row_index if present.
1063
+
1064
+ Args:
1065
+ df: Polars DataFrame
1066
+ columns: Columns to include (None = all)
1067
+
1068
+ Returns:
1069
+ List of row dicts
1070
+ """
1071
+ if columns is None:
1072
+ return df.to_dicts()
1073
+
1074
+ # Always include _row_index and _duplicate_count if present
1075
+ cols_to_select = set(columns)
1076
+ if "_row_index" in df.columns:
1077
+ cols_to_select.add("_row_index")
1078
+ if "_duplicate_count" in df.columns:
1079
+ cols_to_select.add("_duplicate_count")
1080
+
1081
+ # Only select columns that exist in the DataFrame
1082
+ available_cols = set(df.columns)
1083
+ cols_to_select = cols_to_select & available_cols
1084
+
1085
+ if not cols_to_select:
1086
+ return df.to_dicts()
1087
+
1088
+ return df.select(sorted(cols_to_select)).to_dicts()
1089
+
1090
+ def to_dict(self) -> Dict[str, Any]:
1091
+ """Convert to dictionary."""
1092
+ d = {
1093
+ "passed": self.passed,
1094
+ "dataset": self.dataset,
1095
+ "total_rows": self.total_rows,
1096
+ "total_rules": self.total_rules,
1097
+ "passed_count": self.passed_count,
1098
+ "failed_count": self.failed_count,
1099
+ "warning_count": self.warning_count,
1100
+ "rules": [r.to_dict() for r in self.rules],
1101
+ "stats": self.stats,
1102
+ }
1103
+ if self.annotations is not None:
1104
+ d["annotations"] = self.annotations
1105
+ # LLM juice (only include if configured)
1106
+ if self.quality_score is not None:
1107
+ d["quality_score"] = self.quality_score
1108
+ return d
1109
+
1110
+ def to_json(self, indent: Optional[int] = None) -> str:
1111
+ """Convert to JSON string."""
1112
+ return json.dumps(self.to_dict(), indent=indent, default=str)
1113
+
1114
+ def to_llm(self) -> str:
1115
+ """
1116
+ Token-optimized format for LLM context.
1117
+
1118
+ Example output:
1119
+ VALIDATION: my_contract FAILED (1000 rows)
1120
+ BLOCKING: COL:email:not_null (523 nulls), COL:status:allowed_values (12 invalid)
1121
+ WARNING: COL:age:range (3 out of bounds)
1122
+ PASSED: 15 rules
1123
+ """
1124
+ lines = []
1125
+
1126
+ status = "PASSED" if self.passed else "FAILED"
1127
+ rows_str = f" ({self.total_rows:,} rows)" if self.total_rows > 0 else ""
1128
+ score_str = f" [score={self.quality_score:.2f}]" if self.quality_score is not None else ""
1129
+ lines.append(f"VALIDATION: {self.dataset} {status}{rows_str}{score_str}")
1130
+
1131
+ # Blocking failures
1132
+ blocking = self.blocking_failures
1133
+ if blocking:
1134
+ parts = []
1135
+ for r in blocking[:5]:
1136
+ count = f"({r.failed_count:,})" if r.failed_count > 0 else ""
1137
+ parts.append(f"{r.rule_id} {count}".strip())
1138
+ line = "BLOCKING: " + ", ".join(parts)
1139
+ if len(blocking) > 5:
1140
+ line += f" ... +{len(blocking) - 5} more"
1141
+ lines.append(line)
1142
+
1143
+ # Warnings
1144
+ warnings = self.warnings
1145
+ if warnings:
1146
+ parts = []
1147
+ for r in warnings[:5]:
1148
+ count = f"({r.failed_count:,})" if r.failed_count > 0 else ""
1149
+ parts.append(f"{r.rule_id} {count}".strip())
1150
+ line = "WARNING: " + ", ".join(parts)
1151
+ if len(warnings) > 5:
1152
+ line += f" ... +{len(warnings) - 5} more"
1153
+ lines.append(line)
1154
+
1155
+ # Passed summary
1156
+ lines.append(f"PASSED: {self.passed_count} rules")
1157
+
1158
+ # Run-level annotations
1159
+ if self.annotations:
1160
+ lines.append(f"ANNOTATIONS ({len(self.annotations)}):")
1161
+ for ann in self.annotations[:3]:
1162
+ ann_type = ann.get("annotation_type", "note")
1163
+ actor = ann.get("actor_id", "unknown")
1164
+ summary = ann.get("summary", "")[:50]
1165
+ if len(ann.get("summary", "")) > 50:
1166
+ summary += "..."
1167
+ lines.append(f' [{ann_type}] by {actor}: "{summary}"')
1168
+ if len(self.annotations) > 3:
1169
+ lines.append(f" ... +{len(self.annotations) - 3} more")
1170
+
1171
+ return "\n".join(lines)
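A short sketch of assembling a compact LLM context from these methods, assuming a ValidationResult named result:

context_lines = [result.to_llm()]
for r in result.blocking_failures:
    context_lines.append(r.to_llm())
llm_context = "\n".join(context_lines)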
1172
+
1173
+ def sample_failures(
1174
+ self,
1175
+ rule_id: str,
1176
+ n: int = 5,
1177
+ *,
1178
+ upgrade_tier: bool = False,
1179
+ ) -> FailureSamples:
1180
+ """
1181
+ Get a sample of rows that failed a specific rule.
1182
+
1183
+ If eager sampling is enabled (default), this returns cached samples.
1184
+ Otherwise, it lazily re-queries the data source.
1185
+
1186
+ Args:
1187
+ rule_id: The rule ID to get failures for (e.g., "COL:email:not_null")
1188
+ n: Number of sample rows to return (default: 5, max: 100)
1189
+ upgrade_tier: If True, re-execute rules resolved via metadata
1190
+ tier to get actual samples. Required for preplan rules.
1191
+
1192
+ Returns:
1193
+ FailureSamples: Collection of failing rows with "_row_index" field.
1194
+ Supports to_dict(), to_json(), to_llm() methods.
1195
+ Empty if the rule passed (no failures).
1196
+
1197
+ Raises:
1198
+ ValueError: If rule_id not found or rule doesn't support row-level samples
1199
+ RuntimeError: If data source is unavailable for re-query,
1200
+ or if samples unavailable from metadata tier without upgrade_tier=True
1201
+
1202
+ Example:
1203
+ result = kontra.validate("data.parquet", contract)
1204
+ if not result.passed:
1205
+ samples = result.sample_failures("COL:email:not_null", n=5)
1206
+ for row in samples:
1207
+ print(f"Row {row['_row_index']}: {row}")
1208
+
1209
+ # For metadata-tier rules, use upgrade_tier to get samples:
1210
+ samples = result.sample_failures("COL:id:not_null", upgrade_tier=True)
1211
+ """
1212
+ import polars as pl
1213
+
1214
+ # Cap n at 100
1215
+ n = min(n, 100)
1216
+
1217
+ # Find the rule result
1218
+ rule_result = None
1219
+ for r in self.rules:
1220
+ if r.rule_id == rule_id:
1221
+ rule_result = r
1222
+ break
1223
+
1224
+ if rule_result is None:
1225
+ raise ValueError(f"Rule not found: {rule_id}")
1226
+
1227
+ # If rule passed, return empty FailureSamples
1228
+ if rule_result.passed:
1229
+ return FailureSamples([], rule_id)
1230
+
1231
+ # Check for cached samples first
1232
+ if rule_result.samples is not None:
1233
+ # Have cached samples - return if n <= cached, else fetch more
1234
+ if len(rule_result.samples) >= n:
1235
+ return FailureSamples(rule_result.samples[:n], rule_id)
1236
+ # Need more samples than cached - fall through to lazy path
1237
+
1238
+ # Handle unavailable samples
1239
+ if rule_result.samples_reason == SampleReason.UNAVAILABLE_METADATA:
1240
+ if not upgrade_tier:
1241
+ raise RuntimeError(
1242
+ f"Samples unavailable for {rule_id}: rule was resolved via metadata tier. "
1243
+ "Use upgrade_tier=True to re-execute and get samples."
1244
+ )
1245
+ # Fall through to lazy path for tier upgrade
1246
+
1247
+ elif rule_result.samples_reason == SampleReason.UNAVAILABLE_UNSUPPORTED:
1248
+ raise ValueError(
1249
+ f"Rule '{rule_result.name}' does not support row-level samples. "
1250
+ "Dataset-level rules (min_rows, max_rows, freshness, etc.) "
1251
+ "cannot identify specific failing rows."
1252
+ )
1253
+
1254
+ # Find the rule object to get the failure predicate
1255
+ if self._rule_objects is None:
1256
+ raise RuntimeError(
1257
+ "sample_failures() requires rule objects. "
1258
+ "This may happen if ValidationResult was created manually."
1259
+ )
1260
+
1261
+ rule_obj = None
1262
+ for r in self._rule_objects:
1263
+ if getattr(r, "rule_id", None) == rule_id:
1264
+ rule_obj = r
1265
+ break
1266
+
1267
+ if rule_obj is None:
1268
+ raise ValueError(f"Rule object not found for: {rule_id}")
1269
+
1270
+ # Get the failure predicate
1271
+ # Check sample_predicate() first (used by rules like unique that have
1272
+ # different counting vs sampling semantics), then fall back to compile_predicate()
1273
+ predicate = None
1274
+ if hasattr(rule_obj, "sample_predicate"):
1275
+ pred_obj = rule_obj.sample_predicate()
1276
+ if pred_obj is not None:
1277
+ predicate = pred_obj.expr
1278
+ if predicate is None and hasattr(rule_obj, "compile_predicate"):
1279
+ pred_obj = rule_obj.compile_predicate()
1280
+ if pred_obj is not None:
1281
+ predicate = pred_obj.expr
1282
+
1283
+ if predicate is None:
1284
+ raise ValueError(
1285
+ f"Rule '{rule_obj.name}' does not support row-level samples. "
1286
+ "Dataset-level rules (min_rows, max_rows, freshness, etc.) "
1287
+ "cannot identify specific failing rows."
1288
+ )
1289
+
1290
+ # Load the data
1291
+ if self._data_source is None:
1292
+ raise RuntimeError(
1293
+ "sample_failures() requires data source reference. "
1294
+ "This may happen if ValidationResult was created manually "
1295
+ "or the data source is no longer available."
1296
+ )
1297
+
1298
+ # Load data based on source type
1299
+ # Try SQL pushdown for database sources
1300
+ df, load_source = self._load_data_for_sampling(rule_obj, n)
1301
+
1302
+ # For non-database sources (or if SQL filter wasn't available),
1303
+ # we need to filter with Polars
1304
+ if load_source != "sql":
1305
+ # Filter to failing rows, add index, limit (unique rule handled by helper)
1306
+ try:
1307
+ failing = _filter_samples_polars(df, rule_obj, predicate, n).to_dicts()
1308
+ except Exception as e:
1309
+ raise RuntimeError(f"Failed to query failing rows: {e}") from e
1310
+ else:
1311
+ # SQL pushdown already applied filter and added row index
1312
+ failing = df.head(n).to_dicts()
1313
+
1314
+ return FailureSamples(failing, rule_id)
1315
+
1316
+ def _load_data_for_sampling(
1317
+ self, rule: Any = None, n: int = 5
1318
+ ) -> Tuple["pl.DataFrame", str]:
1319
+ """
1320
+ Load data from the stored data source for sample_failures().
1321
+
1322
+ For database sources with rules that support SQL filters,
1323
+ pushes the filter to SQL for performance.
1324
+
1325
+ Returns:
1326
+ Tuple of (DataFrame, source) where source is "sql" or "polars"
1327
+ """
1328
+ import polars as pl
1329
+
1330
+ source = self._data_source
1331
+
1332
+ if source is None:
1333
+ raise RuntimeError("No data source available")
1334
+
1335
+ # String path/URI
1336
+ if isinstance(source, str):
1337
+ # Try to load as file with predicate pushdown for Parquet
1338
+ if source.lower().endswith(".parquet") or source.startswith("s3://"):
1339
+ return self._load_parquet_with_filter(source, rule, n)
1340
+ elif source.lower().endswith(".csv"):
1341
+ return pl.read_csv(source), "polars"
1342
+ else:
1343
+ # Try parquet first, then CSV
1344
+ try:
1345
+ return self._load_parquet_with_filter(source, rule, n)
1346
+ except Exception:
1347
+ try:
1348
+ return pl.read_csv(source), "polars"
1349
+ except Exception:
1350
+ raise RuntimeError(f"Cannot load data from: {source}")
1351
+
1352
+ # Polars DataFrame (was passed directly)
1353
+ if isinstance(source, pl.DataFrame):
1354
+ return source, "polars"
1355
+
1356
+ # DatasetHandle (BYOC or parsed URI)
1357
+ if hasattr(source, "scheme") and hasattr(source, "uri"):
1358
+ # It's a DatasetHandle
1359
+ handle = source
1360
+
1361
+ # Check for BYOC (external connection)
1362
+ if handle.scheme == "byoc" or hasattr(handle, "external_conn"):
1363
+ conn = getattr(handle, "external_conn", None)
1364
+ if conn is None:
1365
+ raise RuntimeError(
1366
+ "Database connection is closed. "
1367
+ "For BYOC, keep the connection open until done with sample_failures()."
1368
+ )
1369
+ table = getattr(handle, "table_ref", None) or handle.path
1370
+ return self._query_db_with_filter(conn, table, rule, n, "postgres"), "sql"
1371
+
1372
+ elif handle.scheme in ("postgres", "postgresql"):
1373
+ # PostgreSQL via URI
1374
+ if hasattr(handle, "external_conn") and handle.external_conn:
1375
+ conn = handle.external_conn
1376
+ else:
1377
+ raise RuntimeError(
1378
+ "Database connection is not available. "
1379
+ "For URI-based connections, sample_failures() requires re-connection."
1380
+ )
1381
+ table = getattr(handle, "table_ref", None) or handle.path
1382
+ return self._query_db_with_filter(conn, table, rule, n, "postgres"), "sql"
1383
+
1384
+ elif handle.scheme == "mssql":
1385
+ # SQL Server
1386
+ if hasattr(handle, "external_conn") and handle.external_conn:
1387
+ conn = handle.external_conn
1388
+ else:
1389
+ raise RuntimeError(
1390
+ "Database connection is not available."
1391
+ )
1392
+ table = getattr(handle, "table_ref", None) or handle.path
1393
+ return self._query_db_with_filter(conn, table, rule, n, "mssql"), "sql"
1394
+
1395
+ elif handle.scheme in ("file", None) or (handle.uri and not handle.scheme):
1396
+ # File-based
1397
+ uri = handle.uri
1398
+ if uri.lower().endswith(".parquet"):
1399
+ return self._load_parquet_with_filter(uri, rule, n)
1400
+ elif uri.lower().endswith(".csv"):
1401
+ return pl.read_csv(uri), "polars"
1402
+ else:
1403
+ return self._load_parquet_with_filter(uri, rule, n)
1404
+
1405
+ raise RuntimeError(f"Unsupported data source type: {type(source)}")
1406
+
1407
+ def _query_db_with_filter(
1408
+ self,
1409
+ conn: Any,
1410
+ table: str,
1411
+ rule: Any,
1412
+ n: int,
1413
+ dialect: str,
1414
+ ) -> "pl.DataFrame":
1415
+ """
1416
+ Query database with SQL filter if rule supports it.
1417
+
1418
+ Uses the rule's to_sql_filter() method to push the filter to SQL,
1419
+ avoiding loading the entire table.
1420
+ """
1421
+ import polars as pl
1422
+
1423
+ sql_filter = None
1424
+
1425
+ # Special case: unique rule needs subquery with table name
1426
+ if _is_unique_rule(rule):
1427
+ column = rule.params.get("column")
1428
+ if column:
1429
+ query = _build_unique_sample_query_sql(table, column, n, dialect)
1430
+ return pl.read_database(query, conn)
1431
+
1432
+ if rule is not None and hasattr(rule, "to_sql_filter"):
1433
+ sql_filter = rule.to_sql_filter(dialect)
1434
+
1435
+ if sql_filter:
1436
+ # Build query with filter and row number
1437
+ # ROW_NUMBER() gives us the original row index
1438
+ if dialect == "mssql":
1439
+ # SQL Server syntax
1440
+ query = f"""
1441
+ SELECT *, ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) - 1 AS _row_index
1442
+ FROM {table}
1443
+ WHERE {sql_filter}
1444
+ ORDER BY (SELECT NULL)
1445
+ OFFSET 0 ROWS FETCH FIRST {n} ROWS ONLY
1446
+ """
1447
+ else:
1448
+ # PostgreSQL / DuckDB syntax
1449
+ query = f"""
1450
+ SELECT *, ROW_NUMBER() OVER () - 1 AS _row_index
1451
+ FROM {table}
1452
+ WHERE {sql_filter}
1453
+ LIMIT {n}
1454
+ """
1455
+ return pl.read_database(query, conn)
1456
+ else:
1457
+ # Fall back to loading all data (rule doesn't support SQL filter)
1458
+ return pl.read_database(f"SELECT * FROM {table}", conn)
1459
+
1460
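For a rule that can express itself as a WHERE predicate, the PostgreSQL/DuckDB branch above reduces sampling to a single filtered query. A sketch of the generated SQL, assuming a hypothetical filter `email IS NULL` on a table `users` with n=5:

    sql_filter = "email IS NULL"   # illustrative result of rule.to_sql_filter("postgres")
    table, n = "users", 5
    query = f"""
        SELECT *, ROW_NUMBER() OVER () - 1 AS _row_index
        FROM {table}
        WHERE {sql_filter}
        LIMIT {n}
    """
    # pl.read_database(query, conn) then returns at most 5 failing rows,
    # each carrying a 0-based _row_index within the sample.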
+ def _load_parquet_with_filter(
1461
+ self,
1462
+ path: str,
1463
+ rule: Any,
1464
+ n: int,
1465
+ ) -> Tuple["pl.DataFrame", str]:
1466
+ """
1467
+ Load Parquet file with predicate pushdown for performance.
1468
+
1469
+ - S3/remote files: Uses DuckDB SQL pushdown (doesn't download whole file)
1470
+ - Local files: Reads with Polars read_parquet; the caller applies the filter
1471
+
1472
+ Returns:
1473
+ Tuple of (DataFrame, source) where source is "sql" or "polars"
1474
+ """
1475
+ import polars as pl
1476
+
1477
+ # Check if this is a remote file (S3, HTTP)
1478
+ is_remote = path.startswith(("s3://", "http://", "https://"))
1479
+
1480
+ # For remote files, use DuckDB SQL pushdown (much faster - doesn't download whole file)
1481
+ if is_remote and rule is not None and hasattr(rule, "to_sql_filter"):
1482
+ sql_filter = rule.to_sql_filter("duckdb")
1483
+ if sql_filter:
1484
+ try:
1485
+ return self._query_parquet_with_duckdb(path, sql_filter, n), "sql"
1486
+ except Exception:
1487
+ pass # Fall through to Polars
1488
+
1489
+ # For local files, just return the raw data - caller will filter
1490
+ # (Don't filter here to avoid double-filtering in _collect_samples_for_rule)
1491
+ return pl.read_parquet(path), "polars"
1492
+
1493
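The remote/local split is decided purely by URI prefix; a runnable sketch of the same check:

    # Mirrors the is_remote test above: only s3/http(s) URIs take the DuckDB pushdown path.
    for path in ("s3://bucket/events.parquet", "https://host/events.parquet", "events.parquet"):
        is_remote = path.startswith(("s3://", "http://", "https://"))
        print(path, "-> DuckDB pushdown" if is_remote else "-> local read_parquet, caller filters")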
+ def _query_parquet_with_duckdb(
1494
+ self,
1495
+ path: str,
1496
+ sql_filter: str,
1497
+ n: int,
1498
+ columns: Optional[List[str]] = None,
1499
+ ) -> "pl.DataFrame":
1500
+ """
1501
+ Query Parquet file using DuckDB with SQL filter.
1502
+
1503
+ Much faster than Polars for S3 files because DuckDB pushes
1504
+ the filter and LIMIT to the row group level.
1505
+ """
1506
+ import duckdb
1507
+ import polars as pl
1508
+
1509
+ con = duckdb.connect()
1510
+
1511
+ # Configure S3 if needed
1512
+ if path.startswith("s3://"):
1513
+ import os
1514
+ con.execute("INSTALL httpfs; LOAD httpfs;")
1515
+ if os.environ.get("AWS_ACCESS_KEY_ID"):
1516
+ con.execute(f"SET s3_access_key_id='{os.environ['AWS_ACCESS_KEY_ID']}';")
1517
+ if os.environ.get("AWS_SECRET_ACCESS_KEY"):
1518
+ con.execute(f"SET s3_secret_access_key='{os.environ['AWS_SECRET_ACCESS_KEY']}';")
1519
+ if os.environ.get("AWS_ENDPOINT_URL"):
1520
+ endpoint = os.environ["AWS_ENDPOINT_URL"].replace("http://", "").replace("https://", "")
1521
+ con.execute(f"SET s3_endpoint='{endpoint}';")
1522
+ con.execute("SET s3_use_ssl=false;")
1523
+ con.execute("SET s3_url_style='path';")
1524
+ if os.environ.get("AWS_REGION"):
1525
+ con.execute(f"SET s3_region='{os.environ['AWS_REGION']}';")
1526
+
1527
+ # Escape path for SQL
1528
+ escaped_path = path.replace("'", "''")
1529
+
1530
+ # Build column list (with projection if specified)
1531
+ if columns:
1532
+ col_list = ", ".join(f'"{c}"' for c in columns)
1533
+ else:
1534
+ col_list = "*"
1535
+
1536
+ # Build query with filter and row number
1537
+ query = f"""
1538
+ SELECT {col_list}, ROW_NUMBER() OVER () - 1 AS _row_index
1539
+ FROM read_parquet('{escaped_path}')
1540
+ WHERE {sql_filter}
1541
+ LIMIT {n}
1542
+ """
1543
+
1544
+ # Close the connection even if the query fails (mirrors _batch_sample_parquet_duckdb).
+ try:
+ result = con.execute(query).pl()
1545
+ finally:
+ con.close()
1546
+ return result
1547
+
1548
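The same query shape can be exercised standalone against a local Parquet file; a minimal sketch (file name and predicate are illustrative) assuming the duckdb and polars packages are installed:

    import duckdb

    con = duckdb.connect()
    try:
        # Same SELECT shape as above, with read_parquet pointed at a hypothetical local file.
        df = con.execute(
            """
            SELECT *, ROW_NUMBER() OVER () - 1 AS _row_index
            FROM read_parquet('events.parquet')
            WHERE amount < 0
            LIMIT 10
            """
        ).pl()
    finally:
        con.close()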
+ def _batch_sample_parquet_duckdb(
1549
+ self,
1550
+ path: str,
1551
+ rules_to_sample: List[Tuple[str, str, int, Optional[List[str]]]],
1552
+ ) -> Dict[str, List[Dict[str, Any]]]:
1553
+ """
1554
+ Batch sample multiple rules from Parquet using a single DuckDB query.
1555
+
1556
+ Args:
1557
+ path: Parquet file path or S3 URI
1558
+ rules_to_sample: List of (rule_id, sql_filter, limit, columns)
1559
+
1560
+ Returns:
1561
+ Dict mapping rule_id to list of sample dicts
1562
+ """
1563
+ import duckdb
1564
+ import polars as pl
1565
+
1566
+ if not rules_to_sample:
1567
+ return {}
1568
+
1569
+ con = duckdb.connect()
1570
+
1571
+ # Configure S3 if needed
1572
+ if path.startswith("s3://"):
1573
+ import os
1574
+ con.execute("INSTALL httpfs; LOAD httpfs;")
1575
+ if os.environ.get("AWS_ACCESS_KEY_ID"):
1576
+ con.execute(f"SET s3_access_key_id='{os.environ['AWS_ACCESS_KEY_ID']}';")
1577
+ if os.environ.get("AWS_SECRET_ACCESS_KEY"):
1578
+ con.execute(f"SET s3_secret_access_key='{os.environ['AWS_SECRET_ACCESS_KEY']}';")
1579
+ if os.environ.get("AWS_ENDPOINT_URL"):
1580
+ endpoint = os.environ["AWS_ENDPOINT_URL"].replace("http://", "").replace("https://", "")
1581
+ con.execute(f"SET s3_endpoint='{endpoint}';")
1582
+ con.execute("SET s3_use_ssl=false;")
1583
+ con.execute("SET s3_url_style='path';")
1584
+ if os.environ.get("AWS_REGION"):
1585
+ con.execute(f"SET s3_region='{os.environ['AWS_REGION']}';")
1586
+
1587
+ escaped_path = path.replace("'", "''")
1588
+
1589
+ # Collect all columns needed across all rules
1590
+ all_columns: set = set()
1591
+ for rule_id, sql_filter, limit, columns in rules_to_sample:
1592
+ if columns:
1593
+ all_columns.update(columns)
1594
+
1595
+ # If any rule needs all columns, use *
1596
+ needs_all = any(cols is None for _, _, _, cols in rules_to_sample)
1597
+
1598
+ if needs_all or not all_columns:
1599
+ col_list = "*"
1600
+ else:
1601
+ col_list = ", ".join(f'"{c}"' for c in sorted(all_columns))
1602
+
1603
+ # Build UNION ALL query - one subquery per rule (wrapped in parens for DuckDB)
1604
+ subqueries = []
1605
+ for rule_id, sql_filter, limit, columns in rules_to_sample:
1606
+ escaped_rule_id = rule_id.replace("'", "''")
1607
+ subquery = f"""(
1608
+ SELECT '{escaped_rule_id}' AS _rule_id, {col_list}, ROW_NUMBER() OVER () - 1 AS _row_index
1609
+ FROM read_parquet('{escaped_path}')
1610
+ WHERE {sql_filter}
1611
+ LIMIT {limit}
1612
+ )"""
1613
+ subqueries.append(subquery)
1614
+
1615
+ query = " UNION ALL ".join(subqueries)
1616
+
1617
+ try:
1618
+ result_df = con.execute(query).pl()
1619
+ finally:
1620
+ con.close()
1621
+
1622
+ # Distribute results to each rule
1623
+ results: Dict[str, List[Dict[str, Any]]] = {rule_id: [] for rule_id, _, _, _ in rules_to_sample}
1624
+
1625
+ for row in result_df.to_dicts():
1626
+ rule_id = row.pop("_rule_id")
1627
+ if rule_id in results:
1628
+ results[rule_id].append(row)
1629
+
1630
+ return results
1631
+
1632
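Each entry in rules_to_sample is a (rule_id, sql_filter, limit, columns) tuple, with columns=None meaning "project everything". A sketch of a call with two hypothetical rules, assuming `sampler` is an instance of the enclosing class:

    rules_to_sample = [
        ("not_null:email", "email IS NULL", 10, ["email"]),
        ("range:amount", "amount < 0 OR amount > 10000", 10, None),  # None -> SELECT *
    ]
    # One UNION ALL round trip; rows come back tagged with _rule_id and are fanned out into
    # {"not_null:email": [...], "range:amount": [...]}.
    samples = sampler._batch_sample_parquet_duckdb("s3://bucket/orders.parquet", rules_to_sample)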
+ def _batch_sample_db(
1633
+ self,
1634
+ conn: Any,
1635
+ table: str,
1636
+ rules_to_sample: List[Tuple[str, str, int, Optional[List[str]]]],
1637
+ dialect: str,
1638
+ ) -> Dict[str, List[Dict[str, Any]]]:
1639
+ """
1640
+ Batch sample multiple rules from database using a single query.
1641
+
1642
+ Args:
1643
+ conn: Database connection
1644
+ table: Table name (with schema if needed)
1645
+ rules_to_sample: List of (rule_id, sql_filter, limit, columns)
1646
+ dialect: "postgres" or "mssql"
1647
+
1648
+ Returns:
1649
+ Dict mapping rule_id to list of sample dicts
1650
+ """
1651
+ import polars as pl
1652
+
1653
+ if not rules_to_sample:
1654
+ return {}
1655
+
1656
+ # Collect all columns needed across all rules
1657
+ all_columns: set = set()
1658
+ for rule_id, sql_filter, limit, columns in rules_to_sample:
1659
+ if columns:
1660
+ all_columns.update(columns)
1661
+
1662
+ needs_all = any(cols is None for _, _, _, cols in rules_to_sample)
1663
+
1664
+ if dialect == "mssql":
1665
+ # SQL Server syntax
1666
+ if needs_all or not all_columns:
1667
+ col_list = "*"
1668
+ else:
1669
+ col_list = ", ".join(f"[{c}]" for c in sorted(all_columns))
1670
+
1671
+ subqueries = []
1672
+ for rule_id, sql_filter, limit, columns in rules_to_sample:
1673
+ escaped_rule_id = rule_id.replace("'", "''")
1674
+ subquery = f"""(
1675
+ SELECT TOP {limit} '{escaped_rule_id}' AS _rule_id, {col_list},
1676
+ ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) - 1 AS _row_index
1677
+ FROM {table}
1678
+ WHERE {sql_filter}
1679
+ )"""
1680
+ subqueries.append(subquery)
1681
+ else:
1682
+ # PostgreSQL syntax
1683
+ if needs_all or not all_columns:
1684
+ col_list = "*"
1685
+ else:
1686
+ col_list = ", ".join(f'"{c}"' for c in sorted(all_columns))
1687
+
1688
+ subqueries = []
1689
+ for rule_id, sql_filter, limit, columns in rules_to_sample:
1690
+ escaped_rule_id = rule_id.replace("'", "''")
1691
+ subquery = f"""(
1692
+ SELECT '{escaped_rule_id}' AS _rule_id, {col_list},
1693
+ ROW_NUMBER() OVER () - 1 AS _row_index
1694
+ FROM {table}
1695
+ WHERE {sql_filter}
1696
+ LIMIT {limit}
1697
+ )"""
1698
+ subqueries.append(subquery)
1699
+
1700
+ query = " UNION ALL ".join(subqueries)
1701
+
1702
+ result_df = pl.read_database(query, conn)
1703
+
1704
+ # Distribute results to each rule
1705
+ results: Dict[str, List[Dict[str, Any]]] = {rule_id: [] for rule_id, _, _, _ in rules_to_sample}
1706
+
1707
+ for row in result_df.to_dicts():
1708
+ rule_id = row.pop("_rule_id")
1709
+ if rule_id in results:
1710
+ results[rule_id].append(row)
1711
+
1712
+ return results
1713
+
1714
+
1715
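The only real difference between the two dialect branches is how the per-rule limit is expressed; the per-rule subqueries look roughly like this (table, filter, and limit are illustrative):

    mssql_subquery = """(
        SELECT TOP 10 'range:amount' AS _rule_id, *,
               ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) - 1 AS _row_index
        FROM dbo.orders
        WHERE amount < 0
    )"""
    postgres_subquery = """(
        SELECT 'range:amount' AS _rule_id, *,
               ROW_NUMBER() OVER () - 1 AS _row_index
        FROM orders
        WHERE amount < 0
        LIMIT 10
    )"""
    # Subqueries are then joined with " UNION ALL " and executed in a single read_database call.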
+ @dataclass
1716
+ class DryRunResult:
1717
+ """
1718
+ Result of a dry run (contract validation without execution).
1719
+
1720
+ Properties:
1721
+ valid: Whether the contract is syntactically valid
1722
+ rules_count: Number of rules that would run
1723
+ columns_needed: Columns the contract requires
1724
+ contract_name: Name of the contract (if any)
1725
+ errors: List of errors found during validation
1726
+ datasource: Datasource from contract
1727
+ """
1728
+
1729
+ valid: bool
1730
+ rules_count: int
1731
+ columns_needed: List[str]
1732
+ contract_name: Optional[str] = None
1733
+ datasource: Optional[str] = None
1734
+ errors: List[str] = field(default_factory=list)
1735
+
1736
+ def __repr__(self) -> str:
1737
+ status = "VALID" if self.valid else "INVALID"
1738
+ parts = [f"DryRunResult({self.contract_name or 'inline'}) {status}"]
1739
+ if self.valid:
1740
+ parts.append(f" Rules: {self.rules_count}, Columns: {len(self.columns_needed)}")
1741
+ if self.columns_needed:
1742
+ cols = ", ".join(self.columns_needed[:5])
1743
+ if len(self.columns_needed) > 5:
1744
+ cols += f" ... +{len(self.columns_needed) - 5} more"
1745
+ parts.append(f" Needs: {cols}")
1746
+ else:
1747
+ for err in self.errors[:3]:
1748
+ parts.append(f" ERROR: {err}")
1749
+ if len(self.errors) > 3:
1750
+ parts.append(f" ... +{len(self.errors) - 3} more errors")
1751
+ return "\n".join(parts)
1752
+
1753
+ def to_dict(self) -> Dict[str, Any]:
1754
+ """Convert to dictionary."""
1755
+ return {
1756
+ "valid": self.valid,
1757
+ "rules_count": self.rules_count,
1758
+ "columns_needed": self.columns_needed,
1759
+ "contract_name": self.contract_name,
1760
+ "datasource": self.datasource,
1761
+ "errors": self.errors,
1762
+ }
1763
+
1764
+ def to_json(self, indent: Optional[int] = None) -> str:
1765
+ """Convert to JSON string."""
1766
+ return json.dumps(self.to_dict(), indent=indent, default=str)
1767
+
1768
+ def to_llm(self) -> str:
1769
+ """Token-optimized format for LLM context."""
1770
+ if self.valid:
1771
+ cols = ",".join(self.columns_needed[:10])
1772
+ if len(self.columns_needed) > 10:
1773
+ cols += f"...+{len(self.columns_needed) - 10}"
1774
+ return f"DRYRUN: {self.contract_name or 'inline'} VALID rules={self.rules_count} cols=[{cols}]"
1775
+ else:
1776
+ errs = "; ".join(self.errors[:3])
1777
+ return f"DRYRUN: {self.contract_name or 'inline'} INVALID errors=[{errs}]"
1778
+
1779
+
1780
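DryRunResult is a plain dataclass, so it can be constructed directly to see how the two renderings differ; the field values here are illustrative:

    result = DryRunResult(
        valid=True,
        rules_count=4,
        columns_needed=["id", "email", "amount"],
        contract_name="orders_contract",
    )
    print(result)           # multi-line, human-oriented repr
    print(result.to_llm())  # DRYRUN: orders_contract VALID rules=4 cols=[id,email,amount]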
+ @dataclass
1781
+ class Diff:
1782
+ """
1783
+ Diff between two validation runs.
1784
+
1785
+ Properties:
1786
+ has_changes: Whether there are any changes
1787
+ improved: Fewer failures than before
1788
+ regressed: More failures than before
1789
+ before: Summary of before run
1790
+ after: Summary of after run
1791
+ new_failures: Rules that started failing
1792
+ resolved: Rules that stopped failing
1793
+ count_changes: Rules where failure count changed
1794
+ """
1795
+
1796
+ has_changes: bool
1797
+ improved: bool
1798
+ regressed: bool
1799
+ before: Dict[str, Any]
1800
+ after: Dict[str, Any]
1801
+ new_failures: List[Dict[str, Any]]
1802
+ resolved: List[Dict[str, Any]]
1803
+ regressions: List[Dict[str, Any]]
1804
+ improvements: List[Dict[str, Any]]
1805
+ _state_diff: Optional[Any] = field(default=None, repr=False)
1806
+
1807
+ def __repr__(self) -> str:
1808
+ if self.regressed:
1809
+ status = "REGRESSED"
1810
+ elif self.improved:
1811
+ status = "IMPROVED"
1812
+ else:
1813
+ status = "NO CHANGE"
1814
+
1815
+ contract = self.after.get("contract_name", "unknown")
1816
+ before_date = self.before.get("run_at", "")[:10]
1817
+ after_date = self.after.get("run_at", "")[:10]
1818
+
1819
+ parts = [f"Diff({contract}) {status}"]
1820
+ parts.append(f" {before_date} -> {after_date}")
1821
+ if self.new_failures:
1822
+ parts.append(f" New failures: {len(self.new_failures)}")
1823
+ if self.resolved:
1824
+ parts.append(f" Resolved: {len(self.resolved)}")
1825
+ return "\n".join(parts)
1826
+
1827
+ @property
1828
+ def count_changes(self) -> List[Dict[str, Any]]:
1829
+ """Rules where failure count changed (both regressions and improvements)."""
1830
+ return self.regressions + self.improvements
1831
+
1832
+ @classmethod
1833
+ def from_state_diff(cls, state_diff: "StateDiff") -> "Diff":
1834
+ """Create from internal StateDiff object."""
1835
+ return cls(
1836
+ has_changes=state_diff.has_regressions or state_diff.has_improvements,
1837
+ improved=state_diff.has_improvements and not state_diff.has_regressions,
1838
+ regressed=state_diff.has_regressions,
1839
+ before={
1840
+ "run_at": state_diff.before.run_at.isoformat(),
1841
+ "passed": state_diff.before.summary.passed,
1842
+ "total_rules": state_diff.before.summary.total_rules,
1843
+ "failed_count": state_diff.before.summary.failed_rules,
1844
+ "contract_name": state_diff.before.contract_name,
1845
+ },
1846
+ after={
1847
+ "run_at": state_diff.after.run_at.isoformat(),
1848
+ "passed": state_diff.after.summary.passed,
1849
+ "total_rules": state_diff.after.summary.total_rules,
1850
+ "failed_count": state_diff.after.summary.failed_rules,
1851
+ "contract_name": state_diff.after.contract_name,
1852
+ },
1853
+ new_failures=[rd.to_dict() for rd in state_diff.new_failures],
1854
+ resolved=[rd.to_dict() for rd in state_diff.resolved],
1855
+ regressions=[rd.to_dict() for rd in state_diff.regressions],
1856
+ improvements=[rd.to_dict() for rd in state_diff.improvements],
1857
+ _state_diff=state_diff,
1858
+ )
1859
+
1860
+ def to_dict(self) -> Dict[str, Any]:
1861
+ """Convert to dictionary."""
1862
+ return {
1863
+ "has_changes": self.has_changes,
1864
+ "improved": self.improved,
1865
+ "regressed": self.regressed,
1866
+ "before": self.before,
1867
+ "after": self.after,
1868
+ "new_failures": self.new_failures,
1869
+ "resolved": self.resolved,
1870
+ "regressions": self.regressions,
1871
+ "improvements": self.improvements,
1872
+ }
1873
+
1874
+ def to_json(self, indent: Optional[int] = None) -> str:
1875
+ """Convert to JSON string."""
1876
+ return json.dumps(self.to_dict(), indent=indent, default=str)
1877
+
1878
+ def to_llm(self) -> str:
1879
+ """Token-optimized format for LLM context."""
1880
+ if self._state_diff is not None:
1881
+ return self._state_diff.to_llm()
1882
+
1883
+ # Fallback if no state_diff
1884
+ lines = []
1885
+ contract = self.after.get("contract_name", "unknown")
1886
+
1887
+ if self.regressed:
1888
+ status = "REGRESSION"
1889
+ elif self.improved:
1890
+ status = "IMPROVED"
1891
+ else:
1892
+ status = "NO_CHANGE"
1893
+
1894
+ lines.append(f"DIFF: {contract} {status}")
1895
+ lines.append(f"{self.before.get('run_at', '')[:10]} -> {self.after.get('run_at', '')[:10]}")
1896
+
1897
+ if self.new_failures:
1898
+ lines.append(f"NEW_FAILURES: {len(self.new_failures)}")
1899
+ for nf in self.new_failures[:3]:
1900
+ lines.append(f" - {nf.get('rule_id', '')}")
1901
+
1902
+ if self.resolved:
1903
+ lines.append(f"RESOLVED: {len(self.resolved)}")
1904
+
1905
+ return "\n".join(lines)
1906
+
1907
+
1908
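Diff is likewise a plain dataclass; when built without a backing StateDiff, to_llm() falls back to the compact format assembled above. A sketch with illustrative dates and rule ids:

    diff = Diff(
        has_changes=True,
        improved=False,
        regressed=True,
        before={"run_at": "2024-05-01T00:00:00", "contract_name": "orders_contract"},
        after={"run_at": "2024-05-02T00:00:00", "contract_name": "orders_contract"},
        new_failures=[{"rule_id": "not_null:email"}],
        resolved=[],
        regressions=[],
        improvements=[],
    )
    print(diff.to_llm())
    # DIFF: orders_contract REGRESSION
    # 2024-05-01 -> 2024-05-02
    # NEW_FAILURES: 1
    #   - not_null:email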
+ @dataclass
1909
+ class SuggestedRule:
1910
+ """A suggested validation rule from profile analysis."""
1911
+
1912
+ name: str
1913
+ params: Dict[str, Any]
1914
+ confidence: float
1915
+ reason: str
1916
+
1917
+ def __repr__(self) -> str:
1918
+ return f"SuggestedRule({self.name}, confidence={self.confidence:.2f})"
1919
+
1920
+ def to_dict(self) -> Dict[str, Any]:
1921
+ """Convert to rule dict format (for inline rules)."""
1922
+ return {
1923
+ "name": self.name,
1924
+ "params": self.params,
1925
+ }
1926
+
1927
+ def to_full_dict(self) -> Dict[str, Any]:
1928
+ """Convert to dict including metadata."""
1929
+ return {
1930
+ "name": self.name,
1931
+ "params": self.params,
1932
+ "confidence": self.confidence,
1933
+ "reason": self.reason,
1934
+ }
1935
+
1936
+
1937
+ class Suggestions:
1938
+ """
1939
+ Collection of suggested validation rules from profile analysis.
1940
+
1941
+ Methods:
1942
+ to_yaml(): Export as YAML contract
1943
+ to_json(): Export as JSON
1944
+ to_dict(): Export as list of rule dicts (for inline rules)
1945
+ save(path): Save to file
1946
+ filter(min_confidence=None, name=None): Filter suggestions
1947
+ """
1948
+
1949
+ def __init__(
1950
+ self,
1951
+ rules: List[SuggestedRule],
1952
+ source: str = "unknown",
1953
+ ):
1954
+ self._rules = rules
1955
+ self.source = source
1956
+
1957
+ def __repr__(self) -> str:
1958
+ return f"Suggestions({len(self._rules)} rules from {self.source})"
1959
+
1960
+ def __len__(self) -> int:
1961
+ return len(self._rules)
1962
+
1963
+ def __iter__(self) -> Iterator[SuggestedRule]:
1964
+ return iter(self._rules)
1965
+
1966
+ def __getitem__(self, index: int) -> SuggestedRule:
1967
+ return self._rules[index]
1968
+
1969
+ def filter(
1970
+ self,
1971
+ min_confidence: Optional[float] = None,
1972
+ name: Optional[str] = None,
1973
+ ) -> "Suggestions":
1974
+ """
1975
+ Filter suggestions by criteria.
1976
+
1977
+ Args:
1978
+ min_confidence: Minimum confidence score (0.0-1.0)
1979
+ name: Filter by rule name
1980
+
1981
+ Returns:
1982
+ New Suggestions with filtered rules
1983
+ """
1984
+ filtered = self._rules
1985
+
1986
+ if min_confidence is not None:
1987
+ filtered = [r for r in filtered if r.confidence >= min_confidence]
1988
+
1989
+ if name is not None:
1990
+ filtered = [r for r in filtered if r.name == name]
1991
+
1992
+ return Suggestions(filtered, self.source)
1993
+
1994
+ def to_dict(self) -> List[Dict[str, Any]]:
1995
+ """Convert to list of rule dicts (usable with kontra.validate(rules=...))."""
1996
+ return [r.to_dict() for r in self._rules]
1997
+
1998
+ def to_json(self, indent: Optional[int] = None) -> str:
1999
+ """Convert to JSON string."""
2000
+ return json.dumps(self.to_dict(), indent=indent, default=str)
2001
+
2002
+ def to_yaml(self, contract_name: str = "suggested_contract") -> str:
2003
+ """
2004
+ Convert to YAML contract format.
2005
+
2006
+ Args:
2007
+ contract_name: Name for the contract
2008
+
2009
+ Returns:
2010
+ YAML string
2011
+ """
2012
+ contract = {
2013
+ "name": contract_name,
2014
+ "dataset": self.source,
2015
+ "rules": self.to_dict(),
2016
+ }
2017
+ return yaml.dump(contract, default_flow_style=False, sort_keys=False)
2018
+
2019
+ def save(self, path: Union[str, Path]) -> None:
2020
+ """
2021
+ Save suggestions to file.
2022
+
2023
+ Args:
2024
+ path: Output path (YAML format)
2025
+ """
2026
+ path = Path(path)
2027
+ path.parent.mkdir(parents=True, exist_ok=True)
2028
+ path.write_text(self.to_yaml(contract_name=path.stem))
2029
+
2030
+ @classmethod
2031
+ def from_profile(
2032
+ cls,
2033
+ profile: "DatasetProfile",
2034
+ min_confidence: float = 0.5,
2035
+ ) -> "Suggestions":
2036
+ """
2037
+ Generate rule suggestions from a profile.
2038
+
2039
+ This is a basic implementation. More sophisticated analysis
2040
+ could be added based on profile depth/preset.
2041
+ """
2042
+ rules: List[SuggestedRule] = []
2043
+
2044
+ for col in profile.columns:
2045
+ # not_null suggestion
2046
+ if col.null_rate == 0:
2047
+ rules.append(SuggestedRule(
2048
+ name="not_null",
2049
+ params={"column": col.name},
2050
+ confidence=1.0,
2051
+ reason=f"Column {col.name} has no nulls",
2052
+ ))
2053
+ elif col.null_rate < 0.01: # < 1% nulls
2054
+ rules.append(SuggestedRule(
2055
+ name="not_null",
2056
+ params={"column": col.name},
2057
+ confidence=0.8,
2058
+ reason=f"Column {col.name} has very few nulls ({col.null_rate:.1%})",
2059
+ ))
2060
+
2061
+ # unique suggestion
2062
+ if col.uniqueness_ratio == 1.0 and col.distinct_count > 1:
2063
+ rules.append(SuggestedRule(
2064
+ name="unique",
2065
+ params={"column": col.name},
2066
+ confidence=1.0,
2067
+ reason=f"Column {col.name} has all unique values",
2068
+ ))
2069
+ elif col.uniqueness_ratio > 0.99:
2070
+ rules.append(SuggestedRule(
2071
+ name="unique",
2072
+ params={"column": col.name},
2073
+ confidence=0.7,
2074
+ reason=f"Column {col.name} is nearly unique ({col.uniqueness_ratio:.1%})",
2075
+ ))
2076
+
2077
+ # dtype suggestion
2078
+ rules.append(SuggestedRule(
2079
+ name="dtype",
2080
+ params={"column": col.name, "type": col.dtype},
2081
+ confidence=1.0,
2082
+ reason=f"Column {col.name} is {col.dtype}",
2083
+ ))
2084
+
2085
+ # allowed_values for low cardinality
2086
+ if col.is_low_cardinality and col.values:
2087
+ rules.append(SuggestedRule(
2088
+ name="allowed_values",
2089
+ params={"column": col.name, "values": col.values},
2090
+ confidence=0.9,
2091
+ reason=f"Column {col.name} has {len(col.values)} distinct values",
2092
+ ))
2093
+
2094
+ # range for numeric
2095
+ if col.numeric and col.numeric.min is not None and col.numeric.max is not None:
2096
+ rules.append(SuggestedRule(
2097
+ name="range",
2098
+ params={
2099
+ "column": col.name,
2100
+ "min": col.numeric.min,
2101
+ "max": col.numeric.max,
2102
+ },
2103
+ confidence=0.7,
2104
+ reason=f"Column {col.name} ranges from {col.numeric.min} to {col.numeric.max}",
2105
+ ))
2106
+
2107
+ # min_rows suggestion
2108
+ if profile.row_count > 0:
2109
+ # Suggest minimum as 80% of current count (or 1 if small dataset)
2110
+ min_rows = max(1, int(profile.row_count * 0.8))
2111
+ rules.append(SuggestedRule(
2112
+ name="min_rows",
2113
+ params={"threshold": min_rows},
2114
+ confidence=0.6,
2115
+ reason=f"Dataset has {profile.row_count:,} rows",
2116
+ ))
2117
+
2118
+ # Filter by confidence
2119
+ filtered = [r for r in rules if r.confidence >= min_confidence]
2120
+
2121
+ return cls(filtered, source=profile.source_uri)
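A short usage sketch tying the pieces together: build Suggestions (normally via from_profile), keep only high-confidence rules, and export them either as inline rules or as a YAML contract; the names and values are illustrative:

    suggestions = Suggestions(
        [
            SuggestedRule("not_null", {"column": "email"}, 1.0, "Column email has no nulls"),
            SuggestedRule("unique", {"column": "id"}, 0.7, "Column id is nearly unique (99.5%)"),
        ],
        source="orders.parquet",
    )
    strict = suggestions.filter(min_confidence=0.9)  # keeps only the not_null rule
    inline_rules = strict.to_dict()                  # [{"name": "not_null", "params": {"column": "email"}}]
    print(strict.to_yaml(contract_name="orders_contract"))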