kontra 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124) hide show
  1. kontra/__init__.py +1871 -0
  2. kontra/api/__init__.py +22 -0
  3. kontra/api/compare.py +340 -0
  4. kontra/api/decorators.py +153 -0
  5. kontra/api/results.py +2121 -0
  6. kontra/api/rules.py +681 -0
  7. kontra/cli/__init__.py +0 -0
  8. kontra/cli/commands/__init__.py +1 -0
  9. kontra/cli/commands/config.py +153 -0
  10. kontra/cli/commands/diff.py +450 -0
  11. kontra/cli/commands/history.py +196 -0
  12. kontra/cli/commands/profile.py +289 -0
  13. kontra/cli/commands/validate.py +468 -0
  14. kontra/cli/constants.py +6 -0
  15. kontra/cli/main.py +48 -0
  16. kontra/cli/renderers.py +304 -0
  17. kontra/cli/utils.py +28 -0
  18. kontra/config/__init__.py +34 -0
  19. kontra/config/loader.py +127 -0
  20. kontra/config/models.py +49 -0
  21. kontra/config/settings.py +797 -0
  22. kontra/connectors/__init__.py +0 -0
  23. kontra/connectors/db_utils.py +251 -0
  24. kontra/connectors/detection.py +323 -0
  25. kontra/connectors/handle.py +368 -0
  26. kontra/connectors/postgres.py +127 -0
  27. kontra/connectors/sqlserver.py +226 -0
  28. kontra/engine/__init__.py +0 -0
  29. kontra/engine/backends/duckdb_session.py +227 -0
  30. kontra/engine/backends/duckdb_utils.py +18 -0
  31. kontra/engine/backends/polars_backend.py +47 -0
  32. kontra/engine/engine.py +1205 -0
  33. kontra/engine/executors/__init__.py +15 -0
  34. kontra/engine/executors/base.py +50 -0
  35. kontra/engine/executors/database_base.py +528 -0
  36. kontra/engine/executors/duckdb_sql.py +607 -0
  37. kontra/engine/executors/postgres_sql.py +162 -0
  38. kontra/engine/executors/registry.py +69 -0
  39. kontra/engine/executors/sqlserver_sql.py +163 -0
  40. kontra/engine/materializers/__init__.py +14 -0
  41. kontra/engine/materializers/base.py +42 -0
  42. kontra/engine/materializers/duckdb.py +110 -0
  43. kontra/engine/materializers/factory.py +22 -0
  44. kontra/engine/materializers/polars_connector.py +131 -0
  45. kontra/engine/materializers/postgres.py +157 -0
  46. kontra/engine/materializers/registry.py +138 -0
  47. kontra/engine/materializers/sqlserver.py +160 -0
  48. kontra/engine/result.py +15 -0
  49. kontra/engine/sql_utils.py +611 -0
  50. kontra/engine/sql_validator.py +609 -0
  51. kontra/engine/stats.py +194 -0
  52. kontra/engine/types.py +138 -0
  53. kontra/errors.py +533 -0
  54. kontra/logging.py +85 -0
  55. kontra/preplan/__init__.py +5 -0
  56. kontra/preplan/planner.py +253 -0
  57. kontra/preplan/postgres.py +179 -0
  58. kontra/preplan/sqlserver.py +191 -0
  59. kontra/preplan/types.py +24 -0
  60. kontra/probes/__init__.py +20 -0
  61. kontra/probes/compare.py +400 -0
  62. kontra/probes/relationship.py +283 -0
  63. kontra/reporters/__init__.py +0 -0
  64. kontra/reporters/json_reporter.py +190 -0
  65. kontra/reporters/rich_reporter.py +11 -0
  66. kontra/rules/__init__.py +35 -0
  67. kontra/rules/base.py +186 -0
  68. kontra/rules/builtin/__init__.py +40 -0
  69. kontra/rules/builtin/allowed_values.py +156 -0
  70. kontra/rules/builtin/compare.py +188 -0
  71. kontra/rules/builtin/conditional_not_null.py +213 -0
  72. kontra/rules/builtin/conditional_range.py +310 -0
  73. kontra/rules/builtin/contains.py +138 -0
  74. kontra/rules/builtin/custom_sql_check.py +182 -0
  75. kontra/rules/builtin/disallowed_values.py +140 -0
  76. kontra/rules/builtin/dtype.py +203 -0
  77. kontra/rules/builtin/ends_with.py +129 -0
  78. kontra/rules/builtin/freshness.py +240 -0
  79. kontra/rules/builtin/length.py +193 -0
  80. kontra/rules/builtin/max_rows.py +35 -0
  81. kontra/rules/builtin/min_rows.py +46 -0
  82. kontra/rules/builtin/not_null.py +121 -0
  83. kontra/rules/builtin/range.py +222 -0
  84. kontra/rules/builtin/regex.py +143 -0
  85. kontra/rules/builtin/starts_with.py +129 -0
  86. kontra/rules/builtin/unique.py +124 -0
  87. kontra/rules/condition_parser.py +203 -0
  88. kontra/rules/execution_plan.py +455 -0
  89. kontra/rules/factory.py +103 -0
  90. kontra/rules/predicates.py +25 -0
  91. kontra/rules/registry.py +24 -0
  92. kontra/rules/static_predicates.py +120 -0
  93. kontra/scout/__init__.py +9 -0
  94. kontra/scout/backends/__init__.py +17 -0
  95. kontra/scout/backends/base.py +111 -0
  96. kontra/scout/backends/duckdb_backend.py +359 -0
  97. kontra/scout/backends/postgres_backend.py +519 -0
  98. kontra/scout/backends/sqlserver_backend.py +577 -0
  99. kontra/scout/dtype_mapping.py +150 -0
  100. kontra/scout/patterns.py +69 -0
  101. kontra/scout/profiler.py +801 -0
  102. kontra/scout/reporters/__init__.py +39 -0
  103. kontra/scout/reporters/json_reporter.py +165 -0
  104. kontra/scout/reporters/markdown_reporter.py +152 -0
  105. kontra/scout/reporters/rich_reporter.py +144 -0
  106. kontra/scout/store.py +208 -0
  107. kontra/scout/suggest.py +200 -0
  108. kontra/scout/types.py +652 -0
  109. kontra/state/__init__.py +29 -0
  110. kontra/state/backends/__init__.py +79 -0
  111. kontra/state/backends/base.py +348 -0
  112. kontra/state/backends/local.py +480 -0
  113. kontra/state/backends/postgres.py +1010 -0
  114. kontra/state/backends/s3.py +543 -0
  115. kontra/state/backends/sqlserver.py +969 -0
  116. kontra/state/fingerprint.py +166 -0
  117. kontra/state/types.py +1061 -0
  118. kontra/version.py +1 -0
  119. kontra-0.5.2.dist-info/METADATA +122 -0
  120. kontra-0.5.2.dist-info/RECORD +124 -0
  121. kontra-0.5.2.dist-info/WHEEL +5 -0
  122. kontra-0.5.2.dist-info/entry_points.txt +2 -0
  123. kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
  124. kontra-0.5.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,310 @@
1
+ # src/kontra/rules/builtin/conditional_range.py
2
+ """
3
+ Conditional range rule - Column must be within range when a condition is met.
4
+
5
+ Usage:
6
+ - name: conditional_range
7
+ params:
8
+ column: discount_percent
9
+ when: "customer_type == 'premium'"
10
+ min: 10
11
+ max: 50
12
+
13
+ Fails when:
14
+ - The `when` condition is TRUE AND (column is NULL OR column < min OR column > max)
15
+
16
+ Passes when:
17
+ - The `when` condition is FALSE (regardless of column value)
18
+ - The `when` condition is TRUE AND column is within [min, max]
19
+ """
20
+ from __future__ import annotations
21
+
22
+ from typing import Any, Dict, List, Optional, Set, Union
23
+
24
+ import polars as pl
25
+
26
+ from kontra.rules.base import BaseRule
27
+ from kontra.rules.registry import register_rule
28
+ from kontra.rules.predicates import Predicate
29
+ from kontra.rules.condition_parser import parse_condition, ConditionParseError
30
+ from kontra.state.types import FailureMode
31
+
32
+
33
# Map operators to Polars comparison methods.
# Keys are the comparison operator strings of a parsed `when` condition;
# values are the unbound `pl.Expr` comparison dunders, so a lookup result is
# called as `compare_fn(column_expr, literal_value)` and yields a `pl.Expr`.
POLARS_OP_MAP = {
    "==": pl.Expr.__eq__,
    "!=": pl.Expr.__ne__,
    ">": pl.Expr.__gt__,
    ">=": pl.Expr.__ge__,
    "<": pl.Expr.__lt__,
    "<=": pl.Expr.__le__,
}
42
+
43
+
44
@register_rule("conditional_range")
class ConditionalRangeRule(BaseRule):
    """
    Fails where column is outside range when a condition is met.

    params:
      - column: str (required) - Column to check range
      - when: str (required) - Condition expression (e.g., "status == 'active'")
      - min: numeric (optional) - Minimum allowed value (inclusive)
      - max: numeric (optional) - Maximum allowed value (inclusive)

    At least one of `min` or `max` must be provided.

    Condition syntax:
        column_name operator value

    Supported operators: ==, !=, >, >=, <, <=
    Supported values: 'string', 123, 123.45, true, false, null

    When the condition is TRUE:
      - NULL in column = failure (can't compare NULL)
      - Value outside [min, max] = failure

    When the condition is FALSE or cannot be evaluated (it compares a NULL
    value and therefore yields NULL), the row is not a failure.

    Examples:
      - name: conditional_range
        params:
          column: discount_percent
          when: "customer_type == 'premium'"
          min: 10
          max: 50
    """

    # Dialects without TRUE/FALSE literals; booleans are rendered as 1/0 there.
    # NOTE(review): "mssql" is accepted as an alias defensively - confirm which
    # dialect names callers actually pass.
    _BIT_BOOLEAN_DIALECTS = frozenset({"sqlserver", "mssql"})

    def __init__(self, name: str, params: Dict[str, Any]):
        """
        Validate the rule parameters and parse the `when` expression.

        Raises:
            ValueError: if neither `min` nor `max` is given, or if the `when`
                expression cannot be parsed.
            RuleParameterError: if both bounds are given and min > max.
        """
        super().__init__(name, params)
        # Validate parameters at construction time
        self._column = self._get_required_param("column", str)
        self._when_expr = self._get_required_param("when", str)
        self._min_val: Optional[Union[int, float]] = params.get("min")
        self._max_val: Optional[Union[int, float]] = params.get("max")

        # At least one bound must be provided
        if self._min_val is None and self._max_val is None:
            raise ValueError(
                "Rule 'conditional_range' requires at least one of 'min' or 'max'"
            )

        # Validate min <= max
        if (
            self._min_val is not None
            and self._max_val is not None
            and self._min_val > self._max_val
        ):
            from kontra.errors import RuleParameterError

            raise RuleParameterError(
                "conditional_range", "min/max",
                f"min ({self._min_val}) must be <= max ({self._max_val})"
            )

        # Parse the when expression at init time to fail early
        try:
            self._when_column, self._when_op, self._when_value = parse_condition(
                self._when_expr
            )
        except ConditionParseError as e:
            raise ValueError(
                f"Rule 'conditional_range' invalid 'when' expression: {e}"
            ) from e

    def required_columns(self) -> Set[str]:
        """Return the checked column and the column used by the `when` condition."""
        return {self._column, self._when_column}

    def _build_condition_expr(self) -> pl.Expr:
        """Build the Polars expression for the when condition."""
        when_col = pl.col(self._when_column)
        compare_fn = POLARS_OP_MAP[self._when_op]

        # Handle NULL value in condition
        if self._when_value is None:
            if self._when_op == "==":
                return when_col.is_null()
            if self._when_op == "!=":
                return when_col.is_not_null()
            # Other operators with NULL don't make sense; treat as always false
            return pl.lit(False)

        # Build comparison expression
        return compare_fn(when_col, self._when_value)

    def _build_range_violation_expr(self) -> pl.Expr:
        """Build expression for range violation (NULL or out of range)."""
        col = pl.col(self._column)

        if self._min_val is not None and self._max_val is not None:
            out_of_range = (col < self._min_val) | (col > self._max_val)
        elif self._min_val is not None:
            out_of_range = col < self._min_val
        else:
            out_of_range = col > self._max_val

        # NULL is a violation
        return col.is_null() | out_of_range

    def _build_failure_expr(self) -> pl.Expr:
        """
        Build the row-level failure expression shared by validate/compile_predicate.

        Failure = condition is TRUE AND (column is NULL OR outside range).
        A NULL condition (e.g. the `when` column is NULL) makes the conjunction
        NULL; it is normalised to False so the mask is strictly boolean and
        agrees with the SQL filter, where a NULL condition selects no row.
        """
        failure = self._build_condition_expr() & self._build_range_violation_expr()
        return failure.fill_null(False)

    def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
        """
        Evaluate the rule against `df` and return the result dict.

        The dict comes from the base-class `_failures` helper, extended with
        `rule_id` and, when at least one row fails, `failure_mode` and `details`.
        """
        # Check columns exist before accessing; a non-None result is returned as-is
        col_check = self._check_columns(df, {self._column, self._when_column})
        if col_check is not None:
            return col_check

        # Mask: True = failure
        mask = df.select(self._build_failure_expr().alias("_mask"))["_mask"]

        res = super()._failures(df, mask, self._build_message())
        res["rule_id"] = self.rule_id

        if res["failed_count"] > 0:
            res["failure_mode"] = str(FailureMode.CONDITIONAL_RANGE_VIOLATION)
            res["details"] = self._explain_failure(df, mask, res["failed_count"])

        return res

    def _build_message(self) -> str:
        """Build the failure message."""
        if self._min_val is not None and self._max_val is not None:
            return f"{self._column} outside range [{self._min_val}, {self._max_val}] when {self._when_expr}"
        if self._min_val is not None:
            return f"{self._column} below {self._min_val} when {self._when_expr}"
        return f"{self._column} above {self._max_val} when {self._when_expr}"

    def _explain_failure(
        self, df: pl.DataFrame, mask: pl.Series, failed_count: int
    ) -> Dict[str, Any]:
        """
        Generate detailed failure explanation.

        Args:
            df: The validated frame.
            mask: Boolean failure mask aligned with `df` (True = failing row).
            failed_count: Number of failing rows reported for `mask`.

        Returns:
            Dict with counts, the expected bounds, a per-violation-type
            breakdown among rows matching the condition, and up to five
            failing row positions.
        """
        total_rows = df.height
        failure_rate = failed_count / total_rows if total_rows > 0 else 0

        # Rows where the condition is TRUE. Filtering by the expression (rather
        # than a materialised mask) evaluates the condition once and lets a
        # literal condition broadcast to the frame height.
        conditional_df = df.filter(self._build_condition_expr())

        details: Dict[str, Any] = {
            "failed_count": failed_count,
            "failure_rate": round(failure_rate, 4),
            "total_rows": total_rows,
            "column": self._column,
            "when_condition": self._when_expr,
            "rows_matching_condition": int(conditional_df.height),
        }

        # Add expected bounds
        if self._min_val is not None:
            details["expected_min"] = self._min_val
        if self._max_val is not None:
            details["expected_max"] = self._max_val

        if conditional_df.height > 0:
            col = conditional_df[self._column]

            # Count violations by type
            null_count = col.null_count()
            if null_count > 0:
                details["null_count_when_condition"] = int(null_count)

            if self._min_val is not None:
                below_min = (col < self._min_val).sum()
                if below_min > 0:
                    details["below_min_count"] = int(below_min)

            if self._max_val is not None:
                above_max = (col > self._max_val).sum()
                if above_max > 0:
                    details["above_max_count"] = int(above_max)

        # Sample failing row positions (first 5); skipped for very large
        # failure sets, as in the original implementation.
        if 0 < failed_count <= 1000:
            positions: List[int] = mask.arg_true().head(5).to_list()
            if positions:
                details["sample_positions"] = positions

        return details

    def compile_predicate(self) -> Optional[Predicate]:
        """Return the failure predicate (expression, message, columns) for this rule."""
        return Predicate(
            rule_id=self.rule_id,
            expr=self._build_failure_expr(),
            message=self._build_message(),
            columns={self._column, self._when_column},
        )

    def to_sql_spec(self) -> Optional[Dict[str, Any]]:
        """Return SQL spec for SQL pushdown executors."""
        return {
            "kind": "conditional_range",
            "rule_id": self.rule_id,
            "column": self._column,
            "when_column": self._when_column,
            "when_op": self._when_op,
            "when_value": self._when_value,
            "min": self._min_val,
            "max": self._max_val,
        }

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Double-quote an identifier, doubling any embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    def to_sql_filter(self, dialect: str = "postgres") -> str | None:
        """
        Build a SQL WHERE fragment that selects the failing rows.

        Args:
            dialect: Target SQL dialect. Only affects how a boolean condition
                value is rendered (1/0 instead of TRUE/FALSE on SQL Server).

        Returns:
            "(<condition>) AND (<range violation>)", or None when the condition
            orders against NULL (<, <=, >, >=), which has no SQL equivalent.
        """
        col = self._quote_identifier(self._column)
        when_col = self._quote_identifier(self._when_column)

        # Map operators
        sql_op = {"==": "=", "!=": "<>"}.get(self._when_op, self._when_op)
        when_value = self._when_value

        # Format the when value
        if when_value is None:
            # Special handling for NULL comparison
            if sql_op == "=":
                condition = f"{when_col} IS NULL"
            elif sql_op == "<>":
                condition = f"{when_col} IS NOT NULL"
            else:
                return None  # Can't compare with NULL using < > etc.
        elif isinstance(when_value, str):
            escaped = when_value.replace("'", "''")
            condition = f"{when_col} {sql_op} '{escaped}'"
        elif isinstance(when_value, bool):
            # T-SQL has no TRUE/FALSE literals; bit columns compare against 1/0.
            if dialect in self._BIT_BOOLEAN_DIALECTS:
                val = "1" if when_value else "0"
            else:
                val = "TRUE" if when_value else "FALSE"
            condition = f"{when_col} {sql_op} {val}"
        else:
            condition = f"{when_col} {sql_op} {when_value}"

        # Build range violation part
        range_parts = [f"{col} IS NULL"]
        if self._min_val is not None:
            range_parts.append(f"{col} < {self._min_val}")
        if self._max_val is not None:
            range_parts.append(f"{col} > {self._max_val}")

        range_violation = " OR ".join(range_parts)

        # Failure = condition is TRUE AND (column is NULL OR outside range)
        return f"({condition}) AND ({range_violation})"
@@ -0,0 +1,138 @@
1
+ # src/kontra/rules/builtin/contains.py
2
+ """
3
+ Contains rule - Column must contain the specified substring.
4
+
5
+ Uses literal substring matching (not regex) for maximum efficiency.
6
+ For regex patterns, use the `regex` rule instead.
7
+
8
+ Usage:
9
+ - name: contains
10
+ params:
11
+ column: email
12
+ substring: "@"
13
+
14
+ Fails when:
15
+ - Value does NOT contain the substring
16
+ - Value is NULL (can't search in NULL)
17
+ """
18
+ from __future__ import annotations
19
+
20
+ from typing import Any, Dict, List, Optional, Set
21
+
22
+ import polars as pl
23
+
24
+ from kontra.rules.base import BaseRule
25
+ from kontra.rules.registry import register_rule
26
+ from kontra.rules.predicates import Predicate
27
+ from kontra.state.types import FailureMode
28
+
29
+
30
+ def _escape_like_pattern(value: str, escape_char: str = "\\") -> str:
31
+ """Escape LIKE special characters: %, _, and the escape char."""
32
+ for c in (escape_char, "%", "_"):
33
+ value = value.replace(c, escape_char + c)
34
+ return value
35
+
36
+
37
@register_rule("contains")
class ContainsRule(BaseRule):
    """
    Fails where column value does NOT contain the substring.

    params:
      - column: str (required) - Column to check
      - substring: str (required) - Substring that must be present

    This rule uses literal matching, not regex. For regex patterns,
    use the `regex` rule instead.

    NULL handling:
      - NULL values are failures (can't search in NULL)
    """

    def __init__(self, name: str, params: Dict[str, Any]):
        """
        Read and validate the `column` and `substring` parameters.

        Raises:
            ValueError: if `substring` is empty.
        """
        super().__init__(name, params)
        self._column = self._get_required_param("column", str)
        self._substring = self._get_required_param("substring", str)

        if not self._substring:
            raise ValueError("Rule 'contains' substring cannot be empty")

    def required_columns(self) -> Set[str]:
        """Return the single column this rule reads."""
        return {self._column}

    def _message(self) -> str:
        """Failure message shared by validate() and compile_predicate()."""
        return f"{self._column} does not contain '{self._substring}'"

    def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
        """
        Evaluate the rule against `df` and return the result dict.

        The dict comes from the base-class `_failures` helper, extended with
        `rule_id` and, when at least one row fails, `failure_mode` and `details`.
        """
        # Check column exists before accessing; a non-None result is returned as-is
        col_check = self._check_columns(df, {self._column})
        if col_check is not None:
            return col_check

        # Use literal=True for efficiency (not regex)
        contains_result = df[self._column].cast(pl.Utf8).str.contains(
            self._substring, literal=True
        )

        # Failure = does NOT contain OR is NULL
        mask = (~contains_result).fill_null(True)

        res = super()._failures(df, mask, self._message())
        res["rule_id"] = self.rule_id

        if res["failed_count"] > 0:
            res["failure_mode"] = str(FailureMode.PATTERN_MISMATCH)
            res["details"] = self._explain_failure(df, mask)

        return res

    def _explain_failure(self, df: pl.DataFrame, mask: pl.Series) -> Dict[str, Any]:
        """
        Generate detailed failure explanation.

        Returns the column, the expected substring and up to five sample
        values taken from the failing rows (NULLs included).
        """
        details: Dict[str, Any] = {
            "column": self._column,
            "expected_substring": self._substring,
        }

        # Sample failing values
        samples: List[Any] = df.filter(mask).head(5)[self._column].to_list()
        if samples:
            details["sample_failures"] = samples

        return details

    def compile_predicate(self) -> Optional[Predicate]:
        """Return the failure predicate (expression, message, columns) for this rule."""
        # Use literal=True for efficiency
        contains_expr = pl.col(self._column).cast(pl.Utf8).str.contains(
            self._substring, literal=True
        )

        return Predicate(
            rule_id=self.rule_id,
            # Failure = does NOT contain OR is NULL
            expr=(~contains_expr).fill_null(True),
            message=self._message(),
            columns={self._column},
        )

    def to_sql_spec(self) -> Optional[Dict[str, Any]]:
        """Generate SQL pushdown specification."""
        return {
            "kind": "contains",
            "rule_id": self.rule_id,
            "column": self._column,
            "substring": self._substring,
        }

    def to_sql_filter(self, dialect: str = "postgres") -> str | None:
        """
        Generate SQL filter for sampling failing rows.

        Args:
            dialect: Target SQL dialect. On SQL Server, `[` is additionally
                escaped because it opens a character class in LIKE patterns.

        Returns:
            An unparenthesised `<col> IS NULL OR <col> NOT LIKE ...` fragment.
            NOTE(review): callers combining it with AND should wrap it in
            parentheses - confirm at the call sites.
        """
        # Double embedded quotes so the identifier cannot terminate early.
        col = '"' + self._column.replace('"', '""') + '"'

        # Escape LIKE special characters
        escaped = _escape_like_pattern(self._substring)
        # NOTE(review): "mssql" is accepted as an alias defensively - confirm
        # which dialect names callers actually pass.
        if dialect in ("sqlserver", "mssql"):
            escaped = escaped.replace("[", "\\[")

        # The pattern sits inside a single-quoted SQL literal: double any
        # single quote so the substring cannot terminate the literal.
        pattern = f"%{escaped}%".replace("'", "''")

        # Failure = does NOT contain OR is NULL
        return f"{col} IS NULL OR {col} NOT LIKE '{pattern}' ESCAPE '\\'"
@@ -0,0 +1,182 @@
1
+ from __future__ import annotations
2
+ from typing import Dict, Any, Optional, Tuple
3
+ import polars as pl
4
+ import duckdb
5
+
6
+ from kontra.rules.base import BaseRule
7
+ from kontra.rules.registry import register_rule
8
+ from kontra.state.types import FailureMode
9
+
10
+
11
@register_rule("custom_sql_check")
class CustomSQLCheck(BaseRule):
    """
    Custom SQL check rule for flexible validation logic.

    Executes user-provided SQL and counts violations.
    Supports remote execution on PostgreSQL/SQL Server when safe.

    Parameters:
        sql: SQL query that returns rows representing violations.
             Use {table} as placeholder for the table reference.

    Example:
        - name: custom_sql_check
          params:
            sql: "SELECT * FROM {table} WHERE balance < 0 AND status = 'active'"

    Remote Execution:
        When the data source is PostgreSQL or SQL Server, the SQL is validated
        using sqlglot to ensure it's safe (SELECT-only, no dangerous functions).
        If safe, it executes directly on the database. Otherwise, falls back to
        loading data into DuckDB.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._validation_result: Optional[Any] = None  # Cache validation result

    def _get_query(self) -> Optional[str]:
        """Return the configured SQL, accepting 'sql' (documented) or 'query' (legacy)."""
        return self.params.get("sql") or self.params.get("query")

    def _error_result(self, df: pl.DataFrame, message: str) -> Dict[str, Any]:
        """Build a failed result that counts every row of `df` as failing."""
        return {
            "rule_id": self.rule_id,
            "passed": False,
            "failed_count": int(df.height),
            "message": message,
        }

    def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
        """
        Execute SQL check via DuckDB (fallback path).

        The query is validated, rewritten into a count query and run against
        `df` registered as table "data". Any error (unsafe SQL, failed rewrite,
        execution error) is reported as a failed result covering every row
        rather than raised.
        """
        from kontra.engine.sql_validator import to_count_query, validate_sql

        query = self._get_query()
        if not query:
            return self._error_result(df, "Missing 'sql' parameter")

        # Substitute {table} placeholder with the registered table name
        query = query.replace("{table}", "data")

        try:
            # Validate SQL is safe before execution (blocks read_csv, read_parquet, etc.)
            validation = validate_sql(query, dialect="duckdb")
            if not validation.is_safe:
                raise ValueError(f"SQL validation failed: {validation.reason}")

            # Transform to COUNT(*) query for efficiency
            success, count_query = to_count_query(query, dialect="duckdb")
            if not success:
                raise ValueError(f"Failed to transform SQL: {count_query}")

            # Use DuckDB's native Polars support (zero-copy)
            con = duckdb.connect()
            try:
                con.register("data", df)
                result = con.execute(count_query).fetchone()
            finally:
                # Always release the in-memory database, including on query errors.
                con.close()

            if result is None or len(result) < 1:
                raise ValueError("Query returned no result")

            failed_count = int(result[0]) if result[0] is not None else 0

            res: Dict[str, Any] = {
                "rule_id": self.rule_id,
                "passed": failed_count == 0,
                "failed_count": failed_count,
                "message": "Passed" if failed_count == 0 else f"Custom SQL check failed for {failed_count} rows",
            }

            if failed_count > 0:
                res["failure_mode"] = str(FailureMode.CUSTOM_CHECK_FAILED)
                res["details"] = {
                    "query": query,
                    "failed_row_count": failed_count,
                }

            return res
        except Exception as e:
            # Deliberate best-effort: a broken custom check becomes a failed
            # result instead of aborting the whole validation run.
            return self._error_result(df, f"Rule execution failed: {e}")

    def compile_predicate(self):
        return None  # fallback-only

    # -------------------------------------------------------------------------
    # Remote Execution Support
    # -------------------------------------------------------------------------

    def supports_remote_execution(self, dialect: str) -> Tuple[bool, str]:
        """
        Check if this rule can be executed directly on a remote database.

        Uses sqlglot to validate the SQL is safe (SELECT-only, no side effects).
        A successful validation result is cached on the instance.

        Args:
            dialect: Database dialect ("postgres", "sqlserver")

        Returns:
            Tuple of (is_supported, reason)
        """
        from kontra.engine.sql_validator import validate_sql

        query = self._get_query()
        if not query:
            return False, "Missing SQL parameter"

        # Remove {table} placeholder for validation (it will be replaced later)
        # Use a dummy table name for parsing
        test_sql = query.replace("{table}", "dummy_table")

        result = validate_sql(test_sql, dialect=dialect)
        if not result.is_safe:
            return False, result.reason or "SQL validation failed"

        self._validation_result = result
        return True, "SQL validated as safe for remote execution"

    def get_remote_sql(
        self,
        schema: str,
        table: str,
        dialect: str,
    ) -> str:
        """
        Get the SQL query formatted for remote execution.

        Args:
            schema: Database schema (e.g., "public", "dbo")
            table: Table name
            dialect: Database dialect ("postgres", "sqlserver")

        Returns:
            SQL query with {table} replaced with proper table reference

        Raises:
            ValueError: if neither 'sql' nor 'query' is configured.
        """
        from kontra.engine.sql_validator import replace_table_placeholder

        query = self._get_query()
        if not query:
            raise ValueError("Missing SQL parameter")

        return replace_table_placeholder(
            sql=query,
            schema=schema,
            table=table,
            dialect=dialect,
        )

    def to_sql_spec(self) -> Optional[Dict[str, Any]]:
        """
        Return SQL specification for the executor.

        This is used by SQL executors to determine if/how to execute remotely.
        Returns None when no SQL is configured.
        """
        query = self._get_query()
        if not query:
            return None

        return {
            "kind": "custom_sql_check",
            "rule_id": self.rule_id,
            "sql": query,
            "params": self.params,
        }