kontra-0.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. kontra/__init__.py +1871 -0
  2. kontra/api/__init__.py +22 -0
  3. kontra/api/compare.py +340 -0
  4. kontra/api/decorators.py +153 -0
  5. kontra/api/results.py +2121 -0
  6. kontra/api/rules.py +681 -0
  7. kontra/cli/__init__.py +0 -0
  8. kontra/cli/commands/__init__.py +1 -0
  9. kontra/cli/commands/config.py +153 -0
  10. kontra/cli/commands/diff.py +450 -0
  11. kontra/cli/commands/history.py +196 -0
  12. kontra/cli/commands/profile.py +289 -0
  13. kontra/cli/commands/validate.py +468 -0
  14. kontra/cli/constants.py +6 -0
  15. kontra/cli/main.py +48 -0
  16. kontra/cli/renderers.py +304 -0
  17. kontra/cli/utils.py +28 -0
  18. kontra/config/__init__.py +34 -0
  19. kontra/config/loader.py +127 -0
  20. kontra/config/models.py +49 -0
  21. kontra/config/settings.py +797 -0
  22. kontra/connectors/__init__.py +0 -0
  23. kontra/connectors/db_utils.py +251 -0
  24. kontra/connectors/detection.py +323 -0
  25. kontra/connectors/handle.py +368 -0
  26. kontra/connectors/postgres.py +127 -0
  27. kontra/connectors/sqlserver.py +226 -0
  28. kontra/engine/__init__.py +0 -0
  29. kontra/engine/backends/duckdb_session.py +227 -0
  30. kontra/engine/backends/duckdb_utils.py +18 -0
  31. kontra/engine/backends/polars_backend.py +47 -0
  32. kontra/engine/engine.py +1205 -0
  33. kontra/engine/executors/__init__.py +15 -0
  34. kontra/engine/executors/base.py +50 -0
  35. kontra/engine/executors/database_base.py +528 -0
  36. kontra/engine/executors/duckdb_sql.py +607 -0
  37. kontra/engine/executors/postgres_sql.py +162 -0
  38. kontra/engine/executors/registry.py +69 -0
  39. kontra/engine/executors/sqlserver_sql.py +163 -0
  40. kontra/engine/materializers/__init__.py +14 -0
  41. kontra/engine/materializers/base.py +42 -0
  42. kontra/engine/materializers/duckdb.py +110 -0
  43. kontra/engine/materializers/factory.py +22 -0
  44. kontra/engine/materializers/polars_connector.py +131 -0
  45. kontra/engine/materializers/postgres.py +157 -0
  46. kontra/engine/materializers/registry.py +138 -0
  47. kontra/engine/materializers/sqlserver.py +160 -0
  48. kontra/engine/result.py +15 -0
  49. kontra/engine/sql_utils.py +611 -0
  50. kontra/engine/sql_validator.py +609 -0
  51. kontra/engine/stats.py +194 -0
  52. kontra/engine/types.py +138 -0
  53. kontra/errors.py +533 -0
  54. kontra/logging.py +85 -0
  55. kontra/preplan/__init__.py +5 -0
  56. kontra/preplan/planner.py +253 -0
  57. kontra/preplan/postgres.py +179 -0
  58. kontra/preplan/sqlserver.py +191 -0
  59. kontra/preplan/types.py +24 -0
  60. kontra/probes/__init__.py +20 -0
  61. kontra/probes/compare.py +400 -0
  62. kontra/probes/relationship.py +283 -0
  63. kontra/reporters/__init__.py +0 -0
  64. kontra/reporters/json_reporter.py +190 -0
  65. kontra/reporters/rich_reporter.py +11 -0
  66. kontra/rules/__init__.py +35 -0
  67. kontra/rules/base.py +186 -0
  68. kontra/rules/builtin/__init__.py +40 -0
  69. kontra/rules/builtin/allowed_values.py +156 -0
  70. kontra/rules/builtin/compare.py +188 -0
  71. kontra/rules/builtin/conditional_not_null.py +213 -0
  72. kontra/rules/builtin/conditional_range.py +310 -0
  73. kontra/rules/builtin/contains.py +138 -0
  74. kontra/rules/builtin/custom_sql_check.py +182 -0
  75. kontra/rules/builtin/disallowed_values.py +140 -0
  76. kontra/rules/builtin/dtype.py +203 -0
  77. kontra/rules/builtin/ends_with.py +129 -0
  78. kontra/rules/builtin/freshness.py +240 -0
  79. kontra/rules/builtin/length.py +193 -0
  80. kontra/rules/builtin/max_rows.py +35 -0
  81. kontra/rules/builtin/min_rows.py +46 -0
  82. kontra/rules/builtin/not_null.py +121 -0
  83. kontra/rules/builtin/range.py +222 -0
  84. kontra/rules/builtin/regex.py +143 -0
  85. kontra/rules/builtin/starts_with.py +129 -0
  86. kontra/rules/builtin/unique.py +124 -0
  87. kontra/rules/condition_parser.py +203 -0
  88. kontra/rules/execution_plan.py +455 -0
  89. kontra/rules/factory.py +103 -0
  90. kontra/rules/predicates.py +25 -0
  91. kontra/rules/registry.py +24 -0
  92. kontra/rules/static_predicates.py +120 -0
  93. kontra/scout/__init__.py +9 -0
  94. kontra/scout/backends/__init__.py +17 -0
  95. kontra/scout/backends/base.py +111 -0
  96. kontra/scout/backends/duckdb_backend.py +359 -0
  97. kontra/scout/backends/postgres_backend.py +519 -0
  98. kontra/scout/backends/sqlserver_backend.py +577 -0
  99. kontra/scout/dtype_mapping.py +150 -0
  100. kontra/scout/patterns.py +69 -0
  101. kontra/scout/profiler.py +801 -0
  102. kontra/scout/reporters/__init__.py +39 -0
  103. kontra/scout/reporters/json_reporter.py +165 -0
  104. kontra/scout/reporters/markdown_reporter.py +152 -0
  105. kontra/scout/reporters/rich_reporter.py +144 -0
  106. kontra/scout/store.py +208 -0
  107. kontra/scout/suggest.py +200 -0
  108. kontra/scout/types.py +652 -0
  109. kontra/state/__init__.py +29 -0
  110. kontra/state/backends/__init__.py +79 -0
  111. kontra/state/backends/base.py +348 -0
  112. kontra/state/backends/local.py +480 -0
  113. kontra/state/backends/postgres.py +1010 -0
  114. kontra/state/backends/s3.py +543 -0
  115. kontra/state/backends/sqlserver.py +969 -0
  116. kontra/state/fingerprint.py +166 -0
  117. kontra/state/types.py +1061 -0
  118. kontra/version.py +1 -0
  119. kontra-0.5.2.dist-info/METADATA +122 -0
  120. kontra-0.5.2.dist-info/RECORD +124 -0
  121. kontra-0.5.2.dist-info/WHEEL +5 -0
  122. kontra-0.5.2.dist-info/entry_points.txt +2 -0
  123. kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
  124. kontra-0.5.2.dist-info/top_level.txt +1 -0
kontra/reporters/json_reporter.py ADDED
@@ -0,0 +1,190 @@
+ from __future__ import annotations
+
+ import json
+ from datetime import datetime, timezone
+ from typing import Any, Dict, Iterable, List, Optional
+
+ # --- Version handling (robust to early-boot states) ---------------------------
+ try:
+     from kontra.version import VERSION as _VERSION
+ except Exception:  # pragma: no cover
+     _VERSION = "0.0.0-dev"
+
+ SCHEMA_VERSION = "1.0"
+
+ # --- Optional JSON Schema validation (non-fatal if missing) -------------------
+ try:
+     import fastjsonschema  # type: ignore
+     _HAVE_VALIDATOR = True
+ except Exception:  # pragma: no cover
+     _HAVE_VALIDATOR = False
+
+ _VALIDATOR = None  # lazy-compiled validator
+
+
+ def _utc_now_iso() -> str:
+     """UTC timestamp in stable ISO 8601 format with trailing Z."""
+     return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
+
+
+ def _normalize_result(item: Dict[str, Any]) -> Dict[str, Any]:
+     passed = bool(item.get("passed", False))
+     msg = str(item.get("message", ""))
+
+     # Always compute from 'passed', ignore incoming 'severity' to keep outputs consistent.
+     severity = "INFO" if passed else "ERROR"
+
+     return {
+         "rule_id": str(item.get("rule_id", "")),
+         "passed": passed,
+         "message": msg,
+         "failed_count": int(item.get("failed_count", 0)),
+         "severity": severity,
+         "actions_executed": list(item.get("actions_executed", [])),
+     }
+
+
+
+ def _sorted_results(results: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     """Deterministic ordering by rule_id, then message for tie-breaks."""
+     normalized = [_normalize_result(r) for r in results]
+     return sorted(normalized, key=lambda r: (r["rule_id"], r["message"]))
+
+
+ def _derive_exec_seconds(summary: Dict[str, Any], stats: Optional[Dict[str, Any]]) -> float:
+     # Prefer summary if provided; else fall back to stats.run_meta.duration_ms_total
+     val = summary.get("execution_time_seconds")
+     if isinstance(val, (int, float)) and val:
+         return float(val)
+     if stats:
+         try:
+             ms = stats.get("run_meta", {}).get("duration_ms_total")
+             if isinstance(ms, (int, float)):
+                 return float(ms) / 1000.0
+         except Exception:
+             pass
+     return 0.0
+
+
+ def _derive_rows_evaluated(summary: Dict[str, Any], stats: Optional[Dict[str, Any]]) -> int:
+     # Prefer summary if provided; else fall back to stats.dataset.nrows
+     val = summary.get("rows_evaluated")
+     if isinstance(val, int) and val >= 0:
+         return int(val)
+     if stats:
+         try:
+             n = stats.get("dataset", {}).get("nrows")
+             if isinstance(n, int) and n >= 0:
+                 return int(n)
+         except Exception:
+             pass
+     return 0
+
+
+ def build_payload(
+     *,
+     dataset_name: str,
+     summary: Dict[str, Any],
+     results: List[Dict[str, Any]],
+     stats: Optional[Dict[str, Any]] = None,
+     quarantine: Optional[Dict[str, Any]] = None,
+     schema_version: str = SCHEMA_VERSION,
+     engine_version: Optional[str] = None,
+ ) -> Dict[str, Any]:
+     """
+     Construct the stable, versioned JSON document for CI/CD and machines.
+     This function is side-effect free and ideal for unit tests.
+     """
+     total = int(summary.get("total_rules", len(results)))
+     passed_count = int(summary.get("rules_passed", sum(1 for r in results if r.get("passed"))))
+     failed_count = int(summary.get("rules_failed", total - passed_count))
+
+     payload: Dict[str, Any] = {
+         "schema_version": str(schema_version),
+         "dataset_name": str(dataset_name),
+         "timestamp_utc": _utc_now_iso(),
+         "engine_version": str(engine_version or _VERSION),
+         "validation_passed": bool(summary.get("passed", failed_count == 0)),
+         "statistics": {
+             "execution_time_seconds": _derive_exec_seconds(summary, stats),
+             "rows_evaluated": _derive_rows_evaluated(summary, stats),
+             "rules_total": total,
+             "rules_passed": passed_count,
+             "rules_failed": failed_count,
+         },
+         "results": _sorted_results(results),
+     }
+
+     if quarantine:
+         payload["quarantine"] = {
+             "location": str(quarantine.get("location", "")),
+             "rows_quarantined": int(quarantine.get("rows_quarantined", 0)),
+         }
+
+     if stats is not None:
+         # Namespaced so the core schema remains stable as stats evolve
+         payload["stats"] = stats
+
+     return payload
+
+
+ def render_json(
+     *,
+     dataset_name: str,
+     summary: Dict[str, Any],
+     results: List[Dict[str, Any]],
+     stats: Optional[Dict[str, Any]] = None,
+     quarantine: Optional[Dict[str, Any]] = None,
+     validate: bool = False,
+ ) -> str:
+     """
+     Build (+ optionally validate) and dump as compact, deterministic JSON.
+     """
+     payload = build_payload(
+         dataset_name=dataset_name,
+         summary=summary,
+         results=results,
+         stats=stats,
+         quarantine=quarantine,
+     )
+
+     if validate and _HAVE_VALIDATOR:
+         _validate_against_local_schema(payload)
+
+     # Deterministic string: stable key order & separators
+     return json.dumps(payload, sort_keys=True, separators=(",", ":"), ensure_ascii=False)
+
+
+ # --- Optional local schema validation ----------------------------------------
+
+
+ def _load_local_schema() -> Optional[Dict[str, Any]]:
+     """
+     Load the local JSON Schema if bundled. Silently returns None if absent.
+     """
+     try:
+         from importlib import resources
+         from importlib.resources import files
+
+         schema_pkg = "schemas"  # repository-level schema package
+         path = files(schema_pkg).joinpath("validation_output.schema.json")
+         with resources.as_file(path) as p:
+             with open(p, "r", encoding="utf-8") as f:
+                 return json.load(f)
+     except Exception:
+         return None
+
+
+ def _validate_against_local_schema(payload: Dict[str, Any]) -> None:
+     global _VALIDATOR
+     if not _HAVE_VALIDATOR:
+         return
+     if _VALIDATOR is None:
+         schema = _load_local_schema()
+         if not schema:
+             return  # schema not bundled; skip validation
+         _VALIDATOR = fastjsonschema.compile(schema)  # type: ignore
+     _VALIDATOR(payload)  # type: ignore
+
+
+ __all__ = ["build_payload", "render_json", "SCHEMA_VERSION"]
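For orientation, a minimal sketch of driving this reporter from a test or CI helper. The import path assumes the hunk above is kontra/reporters/json_reporter.py (the +190 entry in the file list), and the summary/results dicts are purely illustrative.

from kontra.reporters.json_reporter import build_payload, render_json

# Illustrative inputs shaped like the dicts the functions above expect.
summary = {"total_rules": 2, "rules_passed": 1, "rules_failed": 1, "passed": False}
results = [
    {"rule_id": "not_null:email", "passed": False, "message": "email has NULLs", "failed_count": 3},
    {"rule_id": "unique:id", "passed": True, "message": "Passed", "failed_count": 0},
]

payload = build_payload(dataset_name="orders", summary=summary, results=results)
assert payload["statistics"]["rules_failed"] == 1
assert payload["results"][0]["severity"] == "ERROR"  # severity is derived from 'passed', never taken from input

print(render_json(dataset_name="orders", summary=summary, results=results))  # compact, key-sorted JSON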
kontra/reporters/rich_reporter.py ADDED
@@ -0,0 +1,11 @@
+ from __future__ import annotations
+
+ from rich.console import Console
+
+ _console = Console()
+
+ def report_success(msg: str) -> None:
+     _console.print(f"[bold green]✅ {msg}[/bold green]")
+
+ def report_failure(msg: str) -> None:
+     _console.print(f"[bold red]❌ {msg}[/bold red]")
kontra/rules/__init__.py ADDED
@@ -0,0 +1,35 @@
+ # src/kontra/rules/__init__.py
+ """
+ Kontra rules module - Rule definitions and execution planning.
+
+ Public API:
+     - BaseRule: Abstract base class for custom rules
+     - RuleFactory: Creates rule instances from contract specs
+     - RuleExecutionPlan: Plans and executes rule validation
+
+ Built-in rules are auto-registered when kontra.engine is imported.
+ """
+
+ from kontra.rules.base import BaseRule
+ from kontra.rules.factory import RuleFactory
+ from kontra.rules.execution_plan import RuleExecutionPlan, CompiledPlan
+ from kontra.rules.predicates import Predicate
+ from kontra.rules.registry import (
+     register_rule,
+     get_rule,
+     get_all_rule_names,
+ )
+
+ __all__ = [
+     # Base classes
+     "BaseRule",
+     "Predicate",
+     # Factory and planning
+     "RuleFactory",
+     "RuleExecutionPlan",
+     "CompiledPlan",
+     # Registry
+     "register_rule",
+     "get_rule",
+     "get_all_rule_names",
+ ]
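A short usage sketch of the surface exported above. It assumes the built-in rules have already been registered by importing kontra.engine (as the docstring notes); the printed names and the get_rule lookup are examples, not confirmed output.

import kontra.engine  # noqa: F401  -- importing the engine registers the built-in rules
from kontra.rules import get_all_rule_names, get_rule

print(sorted(get_all_rule_names()))  # e.g. ["allowed_values", "not_null", "unique", ...]
not_null_cls = get_rule("not_null")  # look up a registered rule class by name (assumed lookup behavior)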
kontra/rules/base.py ADDED
@@ -0,0 +1,186 @@
+ # src/kontra/rules/base.py
+ from abc import ABC, abstractmethod
+ from typing import Any, Dict, Set
+ import polars as pl
+
+ class BaseRule(ABC):
+     """
+     Abstract base class for all validation rules.
+     """
+
+     name: str
+     params: Dict[str, Any]
+
+     def __init__(self, name: str, params: Dict[str, Any]):
+         self.name = name
+         self.params = params
+         # rule_id is set by the factory (based on id/name/column)
+         self.rule_id: str = name
+         # severity is set by the factory (from contract spec)
+         self.severity: str = "blocking"
+         # context is set by the factory (from contract spec)
+         # Consumer-defined metadata, ignored by validation
+         self.context: Dict[str, Any] = {}
+
+     def __str__(self) -> str:
+         return f"{self.name}({self.params})"
+
+     def __repr__(self) -> str:
+         return str(self)
+
+     @abstractmethod
+     def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
+         """Executes validation on a Polars DataFrame and returns a result dict."""
+         ...
+
+     # NEW: rules can declare columns they need even if not vectorizable
+     def required_columns(self) -> Set[str]:
+         """
+         Columns this rule requires to run `validate()`.
+         Default: none. Override in dataset/column rules that read specific columns.
+         """
+         return set()
+
+     def _get_required_param(self, key: str, param_type: type = str) -> Any:
+         """
+         Get a required parameter, raising a clear error if missing or wrong type.
+
+         Args:
+             key: Parameter name
+             param_type: Expected type (default: str)
+
+         Returns:
+             The parameter value
+
+         Raises:
+             ValueError: If parameter is missing or has wrong type
+         """
+         if key not in self.params:
+             raise ValueError(
+                 f"Rule '{self.name}' requires parameter '{key}' but it was not provided"
+             )
+         value = self.params[key]
+         if not isinstance(value, param_type):
+             raise ValueError(
+                 f"Rule '{self.name}' parameter '{key}' must be {param_type.__name__}, "
+                 f"got {type(value).__name__}"
+             )
+         return value
+
+     def _get_optional_param(self, key: str, default: Any = None) -> Any:
+         """
+         Get an optional parameter with a default value.
+
+         Args:
+             key: Parameter name
+             default: Default value if not provided
+
+         Returns:
+             The parameter value or default
+         """
+         return self.params.get(key, default)
+
+     def _failures(self, df: pl.DataFrame, mask: pl.Series, message: str) -> Dict[str, Any]:
+         """Utility to summarize failing rows."""
+         failed_count = mask.sum()
+         return {
+             "rule_id": getattr(self, "rule_id", self.name),
+             "passed": failed_count == 0,
+             "failed_count": int(failed_count),
+             "message": message if failed_count > 0 else "Passed",
+         }
+
+     def _check_columns(self, df: pl.DataFrame, columns: Set[str]) -> Dict[str, Any] | None:
+         """
+         Check if required columns exist in the DataFrame.
+
+         Returns a failure result dict if any columns are missing, None if all exist.
+         This allows rules to fail gracefully instead of raising exceptions.
+
+         Args:
+             df: The DataFrame to check
+             columns: Set of required column names
+
+         Returns:
+             Failure result dict if columns missing, None if all present
+         """
+         if not columns:
+             return None
+
+         available = set(df.columns)
+         missing = columns - available
+
+         if not missing:
+             return None
+
+         # Build helpful error message
+         missing_list = sorted(missing)
+         available_list = sorted(available)
+
+         if len(missing_list) == 1:
+             msg = f"Column '{missing_list[0]}' not found"
+         else:
+             msg = f"Columns not found: {', '.join(missing_list)}"
+
+         # Check if data might be nested (single column that looks like a wrapper)
+         nested_hint = ""
+         if len(available) == 1 and len(missing) > 0:
+             nested_hint = ". Data may be nested - Kontra requires flat tabular data"
+
+         from kontra.state.types import FailureMode
+
+         return {
+             "rule_id": getattr(self, "rule_id", self.name),
+             "passed": False,
+             "failed_count": df.height,  # All rows fail if column missing
+             "message": f"{msg}{nested_hint}",
+             "failure_mode": str(FailureMode.CONFIG_ERROR),  # Mark as config issue, not data issue
+             "details": {
+                 "missing_columns": missing_list,
+                 "available_columns": available_list[:20],  # Limit for readability
+             },
+         }
+
+     def to_sql_filter(self, dialect: str = "postgres") -> str | None:
+         """
+         Return a SQL WHERE clause that matches failing rows.
+
+         Used by sample_failures() to push filtering to the database instead of
+         loading the entire table. Returns None if the rule doesn't support SQL filters.
+
+         Args:
+             dialect: SQL dialect ("postgres", "mssql", "duckdb")
+
+         Returns:
+             SQL WHERE clause string (without "WHERE"), or None if not supported.
+
+         Example:
+             not_null rule returns: "email IS NULL"
+             range rule returns: "amount < 0 OR amount > 100 OR amount IS NULL"
+         """
+         return None
+
+     def to_sql_agg(self, dialect: str = "duckdb") -> str | None:
+         """
+         Return a SQL aggregate expression for counting violations.
+
+         This enables SQL pushdown for custom rules without modifying executors.
+         The executor wraps this as: {expr} AS "{rule_id}"
+
+         Args:
+             dialect: SQL dialect ("duckdb", "postgres", "mssql")
+
+         Returns:
+             SQL aggregate expression string, or None if not supported.
+
+         Example:
+             A "positive" rule checking col > 0:
+                 return 'SUM(CASE WHEN "amount" IS NULL OR "amount" <= 0 THEN 1 ELSE 0 END)'
+
+         Note:
+             - Use double quotes for column names: "column"
+             - Return the full aggregate expression (SUM, COUNT, etc.)
+             - Handle NULL appropriately (usually NULL = violation)
+             - For dialect differences, check the dialect parameter
+         """
+         return None
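To make the extension points concrete, here is a hedged sketch of a custom rule built on the BaseRule contract above. The rule name "positive", its column parameter, and the aggregate expression are illustrative (the SQL mirrors the to_sql_agg docstring example); only BaseRule, register_rule, and the helper methods come from the code shown.

from typing import Any, Dict, Set

import polars as pl

from kontra.rules.base import BaseRule
from kontra.rules.registry import register_rule


@register_rule("positive")  # hypothetical rule name
class PositiveRule(BaseRule):
    """Flags rows where the configured column is NULL or <= 0."""

    def __init__(self, name: str, params: Dict[str, Any]):
        super().__init__(name, params)
        self._get_required_param("column", str)

    def required_columns(self) -> Set[str]:
        return {self.params["column"]}

    def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
        column = self.params["column"]
        col_check = self._check_columns(df, {column})
        if col_check is not None:
            return col_check  # missing column -> graceful config-error result
        mask = (df[column] <= 0).fill_null(True)  # NULL counts as a violation
        return self._failures(df, mask, f"{column} must be > 0")

    def to_sql_agg(self, dialect: str = "duckdb") -> str | None:
        column = self.params["column"]
        return f'SUM(CASE WHEN "{column}" IS NULL OR "{column}" <= 0 THEN 1 ELSE 0 END)'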
kontra/rules/builtin/__init__.py ADDED
@@ -0,0 +1,40 @@
+ # Import all builtin rules to register them
+ from kontra.rules.builtin.not_null import NotNullRule
+ from kontra.rules.builtin.unique import UniqueRule
+ from kontra.rules.builtin.dtype import DtypeRule
+ from kontra.rules.builtin.range import RangeRule
+ from kontra.rules.builtin.allowed_values import AllowedValuesRule
+ from kontra.rules.builtin.disallowed_values import DisallowedValuesRule
+ from kontra.rules.builtin.regex import RegexRule
+ from kontra.rules.builtin.length import LengthRule
+ from kontra.rules.builtin.contains import ContainsRule
+ from kontra.rules.builtin.starts_with import StartsWithRule
+ from kontra.rules.builtin.ends_with import EndsWithRule
+ from kontra.rules.builtin.min_rows import MinRowsRule
+ from kontra.rules.builtin.max_rows import MaxRowsRule
+ from kontra.rules.builtin.freshness import FreshnessRule
+ from kontra.rules.builtin.custom_sql_check import CustomSQLCheck
+ from kontra.rules.builtin.compare import CompareRule
+ from kontra.rules.builtin.conditional_not_null import ConditionalNotNullRule
+ from kontra.rules.builtin.conditional_range import ConditionalRangeRule
+
+ __all__ = [
+     "NotNullRule",
+     "UniqueRule",
+     "DtypeRule",
+     "RangeRule",
+     "AllowedValuesRule",
+     "DisallowedValuesRule",
+     "RegexRule",
+     "LengthRule",
+     "ContainsRule",
+     "StartsWithRule",
+     "EndsWithRule",
+     "MinRowsRule",
+     "MaxRowsRule",
+     "FreshnessRule",
+     "CustomSQLCheck",
+     "CompareRule",
+     "ConditionalNotNullRule",
+     "ConditionalRangeRule",
+ ]
kontra/rules/builtin/allowed_values.py ADDED
@@ -0,0 +1,156 @@
+ from __future__ import annotations
+ from typing import Dict, Any, List, Optional, Sequence
+ import polars as pl
+
+ from kontra.rules.base import BaseRule
+ from kontra.rules.registry import register_rule
+ from kontra.rules.predicates import Predicate
+ from kontra.state.types import FailureMode
+
+ @register_rule("allowed_values")
+ class AllowedValuesRule(BaseRule):
+     def __init__(self, name: str, params: Dict[str, Any]):
+         super().__init__(name, params)
+         self._get_required_param("column", str)
+         if "values" not in self.params:
+             raise ValueError(
+                 f"Rule '{self.name}' requires parameter 'values' but it was not provided"
+             )
+
+     def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
+         column = self.params["column"]
+         values: Sequence[Any] = self.params["values"]
+
+         # Check column exists before accessing
+         col_check = self._check_columns(df, {column})
+         if col_check is not None:
+             return col_check
+
+         allowed_set = set(values)
+         # Check if NULL is explicitly allowed
+         null_allowed = None in allowed_set
+
+         # is_in returns NULL for NULL values, fill_null decides if NULL is violation
+         # If NULL is in allowed values, NULL should NOT be a violation (fill_null(False))
+         # If NULL is not allowed, NULL IS a violation (fill_null(True))
+         mask = (~df[column].is_in(list(values))).fill_null(not null_allowed)
+         res = super()._failures(df, mask, f"{column} contains disallowed values")
+         res["rule_id"] = self.rule_id
+
+         # Add detailed explanation for failures
+         if res["failed_count"] > 0:
+             res["failure_mode"] = str(FailureMode.NOVEL_CATEGORY)
+             res["details"] = self._explain_failure(df, column, allowed_set)
+
+         return res
+
+     def _explain_failure(self, df: pl.DataFrame, column: str, allowed: set) -> Dict[str, Any]:
+         """Generate detailed failure explanation."""
+         col = df[column]
+
+         # Find unexpected values and their counts
+         unexpected = (
+             df.filter(~col.is_in(list(allowed)) & col.is_not_null())
+             .group_by(column)
+             .agg(pl.len().alias("count"))
+             .sort("count", descending=True)
+             .head(10)  # Top 10 unexpected values
+         )
+
+         unexpected_values: List[Dict[str, Any]] = []
+         for row in unexpected.iter_rows(named=True):
+             val = row[column]
+             count = row["count"]
+             unexpected_values.append({
+                 "value": val,
+                 "count": count,
+             })
+
+         return {
+             "expected": sorted([str(v) for v in allowed]),
+             "unexpected_values": unexpected_values,
+             "suggestion": self._suggest_fix(unexpected_values, allowed) if unexpected_values else None,
+         }
+
+     def _suggest_fix(self, unexpected: List[Dict[str, Any]], allowed: set) -> str:
+         """Suggest how to fix the validation failure."""
+         if not unexpected:
+             return ""
+
+         top_unexpected = unexpected[0]
+         val = top_unexpected["value"]
+         count = top_unexpected["count"]
+
+         # Simple suggestions
+         if count > 100:
+             return f"Consider adding '{val}' to allowed values (found in {count:,} rows)"
+
+         return f"Found {len(unexpected)} unexpected value(s)"
+
+     def compile_predicate(self) -> Optional[Predicate]:
+         column = self.params["column"]
+         values: Sequence[Any] = self.params["values"]
+         # Check if NULL is explicitly allowed
+         null_allowed = None in set(values)
+         # If NULL is allowed, don't treat NULL as violation
+         expr = (~pl.col(column).is_in(values)).fill_null(not null_allowed)
+         return Predicate(
+             rule_id=self.rule_id,
+             expr=expr,
+             message=f"{column} contains disallowed values",
+             columns={column},
+         )
+
+     def to_sql_spec(self) -> Optional[Dict[str, Any]]:
+         """Generate SQL pushdown specification."""
+         column = self.params.get("column")
+         values = self.params.get("values")
+
+         if not column or values is None:
+             return None
+
+         return {
+             "kind": "allowed_values",
+             "rule_id": self.rule_id,
+             "column": column,
+             "values": list(values),
+         }
+
+     def to_sql_filter(self, dialect: str = "postgres") -> str | None:
+         column = self.params["column"]
+         values: Sequence[Any] = self.params["values"]
+
+         col = f'"{column}"'
+
+         # Check if NULL is explicitly allowed
+         null_allowed = None in set(values)
+
+         # Build IN list with proper quoting (exclude None)
+         quoted_values = []
+         for v in values:
+             if v is None:
+                 continue  # NULL handled separately
+             elif isinstance(v, str):
+                 # Escape single quotes
+                 escaped = v.replace("'", "''")
+                 quoted_values.append(f"'{escaped}'")
+             elif isinstance(v, bool):
+                 quoted_values.append("TRUE" if v else "FALSE")
+             else:
+                 quoted_values.append(str(v))
+
+         if quoted_values:
+             in_list = ", ".join(quoted_values)
+             if null_allowed:
+                 # NULL is allowed, only non-null disallowed values are violations
+                 return f"{col} NOT IN ({in_list}) AND {col} IS NOT NULL"
+             else:
+                 # NULL is not allowed, both disallowed values AND NULL are violations
+                 return f"{col} NOT IN ({in_list}) OR {col} IS NULL"
+         else:
+             # Only NULL in allowed values (no other values) - everything non-null fails
+             if null_allowed:
+                 return f"{col} IS NOT NULL"
+             else:
+                 # Empty allowed list, no NULL - everything fails (always true filter)
+                 return "1=1"
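And a quick, assumed usage of the rule above against an in-memory frame, showing both the Polars path and the generated SQL filter; the data and the expected counts are illustrative only.

import polars as pl

from kontra.rules.builtin.allowed_values import AllowedValuesRule

df = pl.DataFrame({"status": ["active", "inactive", "pending", None]})
rule = AllowedValuesRule("allowed_values", {"column": "status", "values": ["active", "inactive"]})

res = rule.validate(df)
print(res["passed"], res["failed_count"])  # False 2  -- "pending" plus the NULL (NULL is not in the allowed values)

print(rule.to_sql_filter("postgres"))
# "status" NOT IN ('active', 'inactive') OR "status" IS NULL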