kontra 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed: 124
  1. kontra/__init__.py +1871 -0
  2. kontra/api/__init__.py +22 -0
  3. kontra/api/compare.py +340 -0
  4. kontra/api/decorators.py +153 -0
  5. kontra/api/results.py +2121 -0
  6. kontra/api/rules.py +681 -0
  7. kontra/cli/__init__.py +0 -0
  8. kontra/cli/commands/__init__.py +1 -0
  9. kontra/cli/commands/config.py +153 -0
  10. kontra/cli/commands/diff.py +450 -0
  11. kontra/cli/commands/history.py +196 -0
  12. kontra/cli/commands/profile.py +289 -0
  13. kontra/cli/commands/validate.py +468 -0
  14. kontra/cli/constants.py +6 -0
  15. kontra/cli/main.py +48 -0
  16. kontra/cli/renderers.py +304 -0
  17. kontra/cli/utils.py +28 -0
  18. kontra/config/__init__.py +34 -0
  19. kontra/config/loader.py +127 -0
  20. kontra/config/models.py +49 -0
  21. kontra/config/settings.py +797 -0
  22. kontra/connectors/__init__.py +0 -0
  23. kontra/connectors/db_utils.py +251 -0
  24. kontra/connectors/detection.py +323 -0
  25. kontra/connectors/handle.py +368 -0
  26. kontra/connectors/postgres.py +127 -0
  27. kontra/connectors/sqlserver.py +226 -0
  28. kontra/engine/__init__.py +0 -0
  29. kontra/engine/backends/duckdb_session.py +227 -0
  30. kontra/engine/backends/duckdb_utils.py +18 -0
  31. kontra/engine/backends/polars_backend.py +47 -0
  32. kontra/engine/engine.py +1205 -0
  33. kontra/engine/executors/__init__.py +15 -0
  34. kontra/engine/executors/base.py +50 -0
  35. kontra/engine/executors/database_base.py +528 -0
  36. kontra/engine/executors/duckdb_sql.py +607 -0
  37. kontra/engine/executors/postgres_sql.py +162 -0
  38. kontra/engine/executors/registry.py +69 -0
  39. kontra/engine/executors/sqlserver_sql.py +163 -0
  40. kontra/engine/materializers/__init__.py +14 -0
  41. kontra/engine/materializers/base.py +42 -0
  42. kontra/engine/materializers/duckdb.py +110 -0
  43. kontra/engine/materializers/factory.py +22 -0
  44. kontra/engine/materializers/polars_connector.py +131 -0
  45. kontra/engine/materializers/postgres.py +157 -0
  46. kontra/engine/materializers/registry.py +138 -0
  47. kontra/engine/materializers/sqlserver.py +160 -0
  48. kontra/engine/result.py +15 -0
  49. kontra/engine/sql_utils.py +611 -0
  50. kontra/engine/sql_validator.py +609 -0
  51. kontra/engine/stats.py +194 -0
  52. kontra/engine/types.py +138 -0
  53. kontra/errors.py +533 -0
  54. kontra/logging.py +85 -0
  55. kontra/preplan/__init__.py +5 -0
  56. kontra/preplan/planner.py +253 -0
  57. kontra/preplan/postgres.py +179 -0
  58. kontra/preplan/sqlserver.py +191 -0
  59. kontra/preplan/types.py +24 -0
  60. kontra/probes/__init__.py +20 -0
  61. kontra/probes/compare.py +400 -0
  62. kontra/probes/relationship.py +283 -0
  63. kontra/reporters/__init__.py +0 -0
  64. kontra/reporters/json_reporter.py +190 -0
  65. kontra/reporters/rich_reporter.py +11 -0
  66. kontra/rules/__init__.py +35 -0
  67. kontra/rules/base.py +186 -0
  68. kontra/rules/builtin/__init__.py +40 -0
  69. kontra/rules/builtin/allowed_values.py +156 -0
  70. kontra/rules/builtin/compare.py +188 -0
  71. kontra/rules/builtin/conditional_not_null.py +213 -0
  72. kontra/rules/builtin/conditional_range.py +310 -0
  73. kontra/rules/builtin/contains.py +138 -0
  74. kontra/rules/builtin/custom_sql_check.py +182 -0
  75. kontra/rules/builtin/disallowed_values.py +140 -0
  76. kontra/rules/builtin/dtype.py +203 -0
  77. kontra/rules/builtin/ends_with.py +129 -0
  78. kontra/rules/builtin/freshness.py +240 -0
  79. kontra/rules/builtin/length.py +193 -0
  80. kontra/rules/builtin/max_rows.py +35 -0
  81. kontra/rules/builtin/min_rows.py +46 -0
  82. kontra/rules/builtin/not_null.py +121 -0
  83. kontra/rules/builtin/range.py +222 -0
  84. kontra/rules/builtin/regex.py +143 -0
  85. kontra/rules/builtin/starts_with.py +129 -0
  86. kontra/rules/builtin/unique.py +124 -0
  87. kontra/rules/condition_parser.py +203 -0
  88. kontra/rules/execution_plan.py +455 -0
  89. kontra/rules/factory.py +103 -0
  90. kontra/rules/predicates.py +25 -0
  91. kontra/rules/registry.py +24 -0
  92. kontra/rules/static_predicates.py +120 -0
  93. kontra/scout/__init__.py +9 -0
  94. kontra/scout/backends/__init__.py +17 -0
  95. kontra/scout/backends/base.py +111 -0
  96. kontra/scout/backends/duckdb_backend.py +359 -0
  97. kontra/scout/backends/postgres_backend.py +519 -0
  98. kontra/scout/backends/sqlserver_backend.py +577 -0
  99. kontra/scout/dtype_mapping.py +150 -0
  100. kontra/scout/patterns.py +69 -0
  101. kontra/scout/profiler.py +801 -0
  102. kontra/scout/reporters/__init__.py +39 -0
  103. kontra/scout/reporters/json_reporter.py +165 -0
  104. kontra/scout/reporters/markdown_reporter.py +152 -0
  105. kontra/scout/reporters/rich_reporter.py +144 -0
  106. kontra/scout/store.py +208 -0
  107. kontra/scout/suggest.py +200 -0
  108. kontra/scout/types.py +652 -0
  109. kontra/state/__init__.py +29 -0
  110. kontra/state/backends/__init__.py +79 -0
  111. kontra/state/backends/base.py +348 -0
  112. kontra/state/backends/local.py +480 -0
  113. kontra/state/backends/postgres.py +1010 -0
  114. kontra/state/backends/s3.py +543 -0
  115. kontra/state/backends/sqlserver.py +969 -0
  116. kontra/state/fingerprint.py +166 -0
  117. kontra/state/types.py +1061 -0
  118. kontra/version.py +1 -0
  119. kontra-0.5.2.dist-info/METADATA +122 -0
  120. kontra-0.5.2.dist-info/RECORD +124 -0
  121. kontra-0.5.2.dist-info/WHEEL +5 -0
  122. kontra-0.5.2.dist-info/entry_points.txt +2 -0
  123. kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
  124. kontra-0.5.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,203 @@
1
# src/kontra/rules/condition_parser.py
"""
Simple condition parser for when expressions.

Parses expressions like:
- status == 'shipped'
- amount > 0
- is_active == true
- category != 'test'

Returns (column_name, operator, typed_value).

Safety:
- No eval() or exec()
- Regex-based parsing only
- Whitelist of supported operators
"""
from __future__ import annotations

import re
from typing import Any, Tuple

# Regex pattern for condition expressions.
# Matches: column_name operator value
# Examples: status == 'shipped', amount > 100, is_active == true
CONDITION_PATTERN = re.compile(
    r"^\s*"                      # Leading whitespace
    r"([a-zA-Z_][a-zA-Z0-9_]*)"  # Column name (Python-style identifier)
    r"\s*"                       # Whitespace
    r"(==|!=|>=|<=|>|<)"         # Operator (two-char ops listed first so they win over >/<)
    r"\s*"                       # Whitespace
    r"(.+?)"                     # Value (non-greedy, so the trailing \s*$ trims whitespace)
    r"\s*$"                      # Trailing whitespace
)

# Operators accepted by parse_condition(); mirrors the regex alternation above.
SUPPORTED_OPERATORS = {"==", "!=", ">", ">=", "<", "<="}
37
+
38
+
39
class ConditionParseError(ValueError):
    """Raised when a condition expression cannot be parsed."""
43
+
44
+
45
def parse_condition(expr: str) -> Tuple[str, str, Any]:
    """
    Parse a condition expression into (column, operator, typed_value).

    Args:
        expr: Condition expression (e.g., "status == 'shipped'")

    Returns:
        Tuple of (column_name, operator, typed_value)

    Raises:
        ConditionParseError: If the expression cannot be parsed

    Examples:
        >>> parse_condition("status == 'shipped'")
        ('status', '==', 'shipped')

        >>> parse_condition("amount > 100")
        ('amount', '>', 100)

        >>> parse_condition("is_active == true")
        ('is_active', '==', True)
    """
    if not expr or not isinstance(expr, str):
        raise ConditionParseError(f"Invalid condition expression: {expr!r}")

    # Check for unsupported compound expressions (AND, OR), case-insensitive.
    # BUGFIX: blank out quoted string literals first so that AND/OR appearing
    # *inside* a value (e.g. status == 'FISH AND CHIPS') is not a false positive.
    unquoted = re.sub(r"'[^']*'|\"[^\"]*\"", "", expr)
    expr_upper = unquoted.upper()
    if " AND " in expr_upper or " OR " in expr_upper:
        raise ConditionParseError(
            f"Compound expressions (AND/OR) are not supported: {expr!r}. "
            f"Use multiple rules or custom_sql_check for complex conditions."
        )

    match = CONDITION_PATTERN.match(expr)
    if not match:
        raise ConditionParseError(
            f"Cannot parse condition: {expr!r}. "
            f"Expected format: column op value (e.g., status == 'shipped')"
        )

    column, operator, value_str = match.groups()

    # Defensive: the regex alternation already restricts operators, but keep the
    # whitelist check in case the pattern and the set ever drift apart.
    if operator not in SUPPORTED_OPERATORS:
        raise ConditionParseError(
            f"Unsupported operator '{operator}' in condition: {expr!r}. "
            f"Supported: {', '.join(sorted(SUPPORTED_OPERATORS))}"
        )

    try:
        typed_value = _parse_value(value_str)
    except ValueError as e:
        # Re-wrap so callers only have to catch ConditionParseError.
        raise ConditionParseError(
            f"Cannot parse value in condition: {expr!r}. {e}"
        ) from e

    return column, operator, typed_value
103
+
104
+
105
def _parse_value(value_str: str) -> Any:
    """
    Parse a value string into a typed Python value.

    Supported value types:
    - Strings: 'value' or "value"
    - Booleans: true, false (case-insensitive)
    - Null: null (case-insensitive)
    - Numbers: 123, 123.45, -42

    Args:
        value_str: String representation of the value

    Returns:
        Typed Python value

    Raises:
        ValueError: If the value cannot be parsed
    """
    val = value_str.strip()

    # Empty value
    if not val:
        raise ValueError("Empty value")

    # String literals: 'value' or "value".
    # BUGFIX: require at least two characters so a lone quote character is not
    # misread as an empty string literal (startswith/endswith both match a
    # single quote). NOTE: escape sequences inside the literal are NOT
    # processed; the inner text is returned verbatim.
    if len(val) >= 2 and (
        (val.startswith("'") and val.endswith("'"))
        or (val.startswith('"') and val.endswith('"'))
    ):
        return val[1:-1]

    # Boolean: true/false (case-insensitive)
    val_lower = val.lower()
    if val_lower == "true":
        return True
    if val_lower == "false":
        return False

    # Null: null (case-insensitive)
    if val_lower == "null":
        return None

    # Try parsing as number: integer first (so "100" -> int, not float)
    try:
        return int(val)
    except ValueError:
        pass

    # Float
    try:
        return float(val)
    except ValueError:
        pass

    raise ValueError(f"Cannot parse value: {val!r}")
162
+
163
+
164
def condition_to_sql(column: str, operator: str, value: Any, dialect: str = "duckdb") -> str:
    """
    Convert a parsed condition to SQL WHERE clause fragment.

    Args:
        column: Column name
        operator: Comparison operator
        value: Typed value
        dialect: SQL dialect (duckdb, postgres, sqlserver)

    Returns:
        SQL WHERE clause fragment
    """
    from kontra.engine.sql_utils import esc_ident, lit_value

    col_sql = esc_ident(column, dialect)

    # NULL needs IS NULL / IS NOT NULL rather than = / <>; ordering
    # comparisons against NULL are meaningless, so they fall back to IS NULL.
    if value is None:
        if operator == "!=":
            return f"{col_sql} IS NOT NULL"
        return f"{col_sql} IS NULL"

    # Translate Python comparison spellings to SQL; the remaining operators
    # (>, >=, <, <=) are identical in both and pass through unchanged.
    translations = {"==": "=", "!=": "<>"}
    sql_op = translations.get(operator, operator)

    return f"{col_sql} {sql_op} {lit_value(value, dialect)}"
@@ -0,0 +1,455 @@
1
+ # src/kontra/rules/execution_plan.py
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Iterable, List, Dict, Any, Optional, Set
6
+
7
+ import polars as pl
8
+
9
+ from kontra.rules.base import BaseRule
10
+ from kontra.rules.predicates import Predicate
11
+ from kontra.logging import get_logger, log_exception
12
+
13
+ _logger = get_logger(__name__)
14
+
15
+
16
+ # --------------------------------------------------------------------------- #
17
+ # Planning Artifact
18
+ # --------------------------------------------------------------------------- #
19
+
20
@dataclass
class CompiledPlan:
    """
    Output of planning/compilation.

    Attributes
    ----------
    predicates
        Vectorizable rule predicates (Polars expressions). These can be run in
        a single, columnar pass (df.select([...])) and summarized cheaply.

    fallback_rules
        Rules that couldn't be vectorized. They will be executed individually
        via rule.validate(df). We still include their required columns in
        `required_cols` to enable projection.

    required_cols
        Union of all columns required by `predicates` and `fallback_rules`.
        The engine can hand this list to the materializer for true projection.

    sql_rules
        Tiny, backend-agnostic specs for rules that can be evaluated as
        single-row SQL aggregates (e.g., DuckDB). Polars ignores these; they
        are consumed by a SQL executor if present.
    """
    predicates: List[Predicate]       # vectorizable Polars predicates
    fallback_rules: List[BaseRule]    # executed one-by-one via rule.validate(df)
    required_cols: List[str]          # sorted union (see compile()) for deterministic projection
    sql_rules: List[Dict[str, Any]]   # optional specs consumed by SQL executors
49
+
50
+
51
+ # --------------------------------------------------------------------------- #
52
+ # Planner
53
+ # --------------------------------------------------------------------------- #
54
+
55
class RuleExecutionPlan:
    """
    Builds and executes a plan for the given rules.

    Design goals
    ------------
    - Deterministic: same inputs → same outputs
    - Lean: compilation discovers vectorizable work + required columns
    - Extensible: optional `sql_rules` for SQL backends (Polars behavior unchanged)
    """

    def __init__(self, rules: List[BaseRule]):
        # Rules are kept in caller order; compile()/execute_compiled() preserve
        # that order within the predicate and fallback groups.
        self.rules = rules

    def __str__(self) -> str:
        # Multi-line, human-readable listing of the plan's rules.
        if not self.rules:
            return "RuleExecutionPlan(rules=[])"
        rules_list = [repr(r) for r in self.rules]
        rules_str = ",\n ".join(rules_list)
        return f"RuleExecutionPlan(rules=[\n {rules_str}\n])"

    def __repr__(self) -> str:
        return f"RuleExecutionPlan(rules={self.rules})"

    # --------------------------- Public API -----------------------------------

    def compile(self) -> CompiledPlan:
        """
        Compile rules into:
        - vectorizable predicates (Polars)
        - fallback rule list
        - required column set (for projection)
        - sql_rules (for optional SQL executor consumption)
        """
        predicates: List[Predicate] = []
        fallbacks: List[BaseRule] = []
        sql_rules: List[Dict[str, Any]] = []

        for rule in self.rules:
            # 1) Try vectorization (Polars). A rule that cannot produce a
            #    predicate becomes a fallback, executed via rule.validate(df).
            pred = _try_compile_predicate(rule)
            if pred is None:
                fallbacks.append(rule)
            else:
                _validate_predicate(pred)
                # The predicate must answer for the same rule it came from;
                # a mismatch indicates a buggy rule implementation.
                if pred.rule_id != rule.rule_id:
                    raise ValueError(
                        f"Predicate.rule_id '{pred.rule_id}' does not match "
                        f"rule.rule_id '{rule.rule_id}'."
                    )
                predicates.append(pred)

            # 2) Optionally generate a SQL spec (non-fatal if inapplicable).
            #    Note: a rule may appear BOTH as a predicate and a sql_rule;
            #    without_ids() removes the SQL-handled subset later.
            spec = _maybe_rule_sql_spec(rule)
            if spec:
                sql_rules.append(spec)

        # 3) Derive required columns for projection (predicates + fallbacks).
        #    Sorted for determinism.
        cols_pred = _collect_required_columns(predicates)
        cols_fb = _extract_columns_from_rules(fallbacks)
        required_cols = sorted(cols_pred | cols_fb)

        return CompiledPlan(
            predicates=predicates,
            fallback_rules=fallbacks,
            required_cols=required_cols,
            sql_rules=sql_rules,
        )

    def execute_compiled(self, df: pl.DataFrame, compiled: CompiledPlan) -> List[Dict[str, Any]]:
        """
        Execute the compiled plan using Polars:
        - vectorized pass for predicates
        - individual validation for fallback rules

        Returns a list of per-rule result dicts (rule_id, passed, failed_count,
        message, execution_source, severity, ...), predicates first, then
        fallbacks.
        """
        # Build rule_id -> severity mapping for predicates
        rule_severity_map = self._build_severity_map()
        available_cols = set(df.columns)

        vec_results: List[Dict[str, Any]] = []
        if compiled.predicates:
            # Separate predicates into those with all columns present vs missing columns
            valid_predicates: List[Predicate] = []
            missing_col_results: List[Dict[str, Any]] = []

            for p in compiled.predicates:
                missing = p.columns - available_cols
                if missing:
                    # Column(s) not found - generate failure result instead of
                    # letting the vectorized select blow up.
                    missing_list = sorted(missing)
                    if len(missing_list) == 1:
                        msg = f"Column '{missing_list[0]}' not found"
                    else:
                        msg = f"Columns not found: {', '.join(missing_list)}"

                    # Hint if data might be nested (single column available, multiple expected)
                    if len(available_cols) == 1:
                        msg += ". Data may be nested - Kontra requires flat tabular data"

                    # Imported lazily; presumably avoids an import cycle with
                    # kontra.state — confirm before hoisting to module level.
                    from kontra.state.types import FailureMode
                    missing_col_results.append({
                        "rule_id": p.rule_id,
                        "passed": False,
                        "failed_count": df.height,
                        "message": msg,
                        "execution_source": "polars",
                        "severity": rule_severity_map.get(p.rule_id, "blocking"),
                        "failure_mode": str(FailureMode.CONFIG_ERROR),
                        "details": {
                            "missing_columns": missing_list,
                            # Capped at 20 to keep result payloads small.
                            "available_columns": sorted(available_cols)[:20],
                        },
                    })
                else:
                    valid_predicates.append(p)

            # Execute valid predicates in a single vectorized pass.
            # NOTE(review): aliasing by rule_id assumes rule_ids are unique
            # within a plan — duplicate ids would collide in the select.
            if valid_predicates:
                counts_df = df.select([p.expr.sum().alias(p.rule_id) for p in valid_predicates])
                counts = counts_df.row(0, named=True)
                for p in valid_predicates:
                    # p.expr is a failure indicator, so its sum is the failed-row count.
                    failed_count = int(counts[p.rule_id])
                    passed = failed_count == 0
                    vec_results.append(
                        {
                            "rule_id": p.rule_id,
                            "passed": passed,
                            "failed_count": failed_count,
                            "message": "Passed" if passed else p.message,
                            "execution_source": "polars",
                            "severity": rule_severity_map.get(p.rule_id, "blocking"),
                        }
                    )

            # Add missing column results
            vec_results.extend(missing_col_results)

        fb_results: List[Dict[str, Any]] = []
        for r in compiled.fallback_rules:
            try:
                result = r.validate(df)
                result["execution_source"] = "polars"
                result["severity"] = getattr(r, "severity", "blocking")
                fb_results.append(result)
            except Exception as e:
                # A crashing rule is reported as a failure covering all rows,
                # not propagated — one bad rule must not abort the whole run.
                fb_results.append(
                    {
                        "rule_id": getattr(r, "rule_id", r.name),
                        "passed": False,
                        "failed_count": int(df.height),
                        "message": f"Rule execution failed: {e}",
                        "execution_source": "polars",
                        "severity": getattr(r, "severity", "blocking"),
                    }
                )

        # Deterministic order: predicates first, then fallbacks
        return vec_results + fb_results

    def _build_severity_map(self) -> Dict[str, str]:
        """Build a mapping from rule_id to severity for all rules."""
        return {
            getattr(r, "rule_id", r.name): getattr(r, "severity", "blocking")
            for r in self.rules
        }

    def execute(self, df: pl.DataFrame) -> List[Dict[str, Any]]:
        """Compile and execute in one step (Polars-only path)."""
        compiled = self.compile()
        return self.execute_compiled(df, compiled)

    def summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Aggregate pass/fail counts for reporters."""
        total = len(results)
        failed = sum(1 for r in results if not r.get("passed", False))

        # Count failures by severity
        blocking_failures = 0
        warning_failures = 0
        info_failures = 0

        for r in results:
            if not r.get("passed", False):
                severity = r.get("severity", "blocking")
                if severity == "blocking":
                    blocking_failures += 1
                elif severity == "warning":
                    warning_failures += 1
                elif severity == "info":
                    info_failures += 1

        # Validation passes if no blocking failures
        # (warnings and info are reported but don't fail the pipeline)
        passed = blocking_failures == 0

        return {
            "total_rules": total,
            "rules_failed": failed,
            "rules_passed": total - failed,
            "passed": passed,
            "blocking_failures": blocking_failures,
            "warning_failures": warning_failures,
            "info_failures": info_failures,
        }

    # ------------------------ Hybrid/Residual Helpers -------------------------

    def without_ids(self, compiled: CompiledPlan, handled_ids: Set[str]) -> CompiledPlan:
        """
        Return a new CompiledPlan with any rules whose rule_id is in `handled_ids` removed.

        Used by the hybrid path: a SQL executor handles a subset of rules; the
        remainder (residual) still needs accurate `required_cols` so projection
        works for Polars.
        """
        resid_preds = [p for p in compiled.predicates if p.rule_id not in handled_ids]
        resid_fallbacks = [
            r for r in compiled.fallback_rules
            if getattr(r, "rule_id", r.name) not in handled_ids
        ]

        # Recompute required columns from the residual rules only, so
        # projection does not fetch columns the SQL side already covered.
        cols_pred = _collect_required_columns(resid_preds)
        cols_fb = _extract_columns_from_rules(resid_fallbacks)
        required_cols = sorted(cols_pred | cols_fb)

        # sql_rules are irrelevant for the residual Polars pass
        return CompiledPlan(
            predicates=resid_preds,
            fallback_rules=resid_fallbacks,
            required_cols=required_cols,
            sql_rules=[],
        )

    def required_cols_for(self, compiled: CompiledPlan) -> List[str]:
        """Expose the computed required columns for a given compiled plan."""
        return list(compiled.required_cols)
291
+
292
+
293
+ # --------------------------------------------------------------------------- #
294
+ # Helpers
295
+ # --------------------------------------------------------------------------- #
296
+
297
def _try_compile_predicate(rule: BaseRule) -> Optional[Predicate]:
    """
    Obtain a rule's vectorizable Predicate, when it offers one.

    A rule with no `compile_predicate` attribute, one whose call returns a
    falsy value, or one that raises is treated as a fallback (returns None).
    """
    compile_fn = getattr(rule, "compile_predicate", None)
    if compile_fn is None:
        return None
    try:
        predicate = compile_fn()
    except Exception as e:
        log_exception(_logger, f"compile_predicate failed for {getattr(rule, 'name', '?')}", e)
        return None
    return predicate or None
312
+
313
+
314
def _collect_required_columns(preds: Iterable[Predicate]) -> Set[str]:
    """Union the required columns declared by each predicate."""
    return {col for p in preds for col in p.columns}
320
+
321
+
322
def _extract_columns_from_rules(rules: Iterable[BaseRule]) -> Set[str]:
    """
    Extract required columns from fallback rules.

    Prefers the rule's explicit required_columns() declaration. When that is
    empty, falls back to inspecting the conventional 'column'/'columns'
    entries of the rule's params.
    """
    collected: Set[str] = set()
    for rule in rules:
        try:
            declared = rule.required_columns() or set()
            if not declared:
                # Heuristic path: infer from common param names.
                params = getattr(rule, "params", {}) or {}
                single = params.get("column")
                many = params.get("columns")
                if isinstance(single, str) and single:
                    declared.add(single)
                if isinstance(many, (list, tuple)):
                    declared.update(c for c in many if isinstance(c, str))
            collected |= declared
        except Exception as e:
            # Conservative: skip here; a genuinely broken rule will surface
            # its error when validate() runs.
            log_exception(_logger, f"Could not extract columns for rule {getattr(rule, 'name', '?')}", e)
    return collected
348
+
349
+
350
def _validate_predicate(pred: Predicate) -> None:
    """Type/shape checks for a Predicate returned by a rule."""
    if not isinstance(pred, Predicate):
        raise TypeError("compile_predicate() must return a Predicate instance")
    if not isinstance(pred.expr, pl.Expr):
        raise TypeError("Predicate.expr must be a Polars Expr")
    id_ok = isinstance(pred.rule_id, str) and bool(pred.rule_id)
    if not id_ok:
        raise ValueError("Predicate.rule_id must be a non-empty string")
    if not isinstance(pred.columns, set):
        raise TypeError("Predicate.columns must be a set[str]")
360
+
361
+
362
def _maybe_rule_sql_spec(rule: BaseRule) -> Optional[Dict[str, Any]]:
    """
    Return a tiny, backend-agnostic spec for SQL-capable rules.

    Supported rules:
    - not_null(column)
    - unique(column)
    - min_rows(threshold)
    - max_rows(threshold)
    - allowed_values(column, values)
    - Any custom rule implementing to_sql_agg()

    Notes
    -----
    - If a rule provides `to_sql_spec()`, that takes precedence.
    - If a rule provides `to_sql_agg()`, use it for custom SQL pushdown.
    - We normalize namespaced rule names, e.g. "DATASET:not_null" → "not_null".
    - For min/max rows, accept both `value` and `threshold` to match existing contracts.
    - Not all executors support all rules (DuckDB: 3, PostgreSQL: 5).
    """
    rid = getattr(rule, "rule_id", None)
    if not isinstance(rid, str):
        return None

    # Priority 1: the rule supplies a complete spec (full control).
    spec_fn = getattr(rule, "to_sql_spec", None)
    if callable(spec_fn):
        try:
            provided = spec_fn()
        except Exception as e:
            log_exception(_logger, f"to_sql_spec failed for {getattr(rule, 'name', '?')}", e)
        else:
            if provided:
                return provided

    # Priority 2: the rule supplies per-dialect SQL aggregates; this lets
    # custom rules get SQL pushdown without modifying the executors.
    agg_fn = getattr(rule, "to_sql_agg", None)
    if callable(agg_fn):
        try:
            # Collect every dialect up front; each executor picks the one it needs.
            dialect_aggs = {d: agg_fn(d) for d in ("duckdb", "postgres", "mssql")}
            if any(dialect_aggs.values()):
                return {
                    "kind": "custom_agg",
                    "rule_id": rid,
                    "sql_agg": dialect_aggs,
                }
        except Exception as e:
            log_exception(_logger, f"to_sql_agg failed for {getattr(rule, 'name', '?')}", e)

    # Priority 3: recognize built-in rules by their (namespace-stripped) name.
    raw_name = getattr(rule, "name", None)
    name = raw_name.split(":")[-1] if isinstance(raw_name, str) else raw_name
    params: Dict[str, Any] = getattr(rule, "params", {}) or {}

    if not (name and isinstance(params, dict)):
        return None

    # Single-column checks share an identical spec shape.
    if name in ("not_null", "unique"):
        col = params.get("column")
        if isinstance(col, str) and col:
            return {"kind": name, "rule_id": rid, "column": col}

    # Row-count checks: accept both `value` and `threshold` param spellings.
    if name in ("min_rows", "max_rows"):
        thr = params.get("value", params.get("threshold"))
        if isinstance(thr, int):
            return {"kind": name, "rule_id": rid, "threshold": int(thr)}

    if name == "allowed_values":
        col = params.get("column")
        values = params.get("values", [])
        if isinstance(col, str) and col and values:
            return {"kind": "allowed_values", "rule_id": rid, "column": col, "values": list(values)}

    return None