kontra-0.5.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. kontra/__init__.py +1871 -0
  2. kontra/api/__init__.py +22 -0
  3. kontra/api/compare.py +340 -0
  4. kontra/api/decorators.py +153 -0
  5. kontra/api/results.py +2121 -0
  6. kontra/api/rules.py +681 -0
  7. kontra/cli/__init__.py +0 -0
  8. kontra/cli/commands/__init__.py +1 -0
  9. kontra/cli/commands/config.py +153 -0
  10. kontra/cli/commands/diff.py +450 -0
  11. kontra/cli/commands/history.py +196 -0
  12. kontra/cli/commands/profile.py +289 -0
  13. kontra/cli/commands/validate.py +468 -0
  14. kontra/cli/constants.py +6 -0
  15. kontra/cli/main.py +48 -0
  16. kontra/cli/renderers.py +304 -0
  17. kontra/cli/utils.py +28 -0
  18. kontra/config/__init__.py +34 -0
  19. kontra/config/loader.py +127 -0
  20. kontra/config/models.py +49 -0
  21. kontra/config/settings.py +797 -0
  22. kontra/connectors/__init__.py +0 -0
  23. kontra/connectors/db_utils.py +251 -0
  24. kontra/connectors/detection.py +323 -0
  25. kontra/connectors/handle.py +368 -0
  26. kontra/connectors/postgres.py +127 -0
  27. kontra/connectors/sqlserver.py +226 -0
  28. kontra/engine/__init__.py +0 -0
  29. kontra/engine/backends/duckdb_session.py +227 -0
  30. kontra/engine/backends/duckdb_utils.py +18 -0
  31. kontra/engine/backends/polars_backend.py +47 -0
  32. kontra/engine/engine.py +1205 -0
  33. kontra/engine/executors/__init__.py +15 -0
  34. kontra/engine/executors/base.py +50 -0
  35. kontra/engine/executors/database_base.py +528 -0
  36. kontra/engine/executors/duckdb_sql.py +607 -0
  37. kontra/engine/executors/postgres_sql.py +162 -0
  38. kontra/engine/executors/registry.py +69 -0
  39. kontra/engine/executors/sqlserver_sql.py +163 -0
  40. kontra/engine/materializers/__init__.py +14 -0
  41. kontra/engine/materializers/base.py +42 -0
  42. kontra/engine/materializers/duckdb.py +110 -0
  43. kontra/engine/materializers/factory.py +22 -0
  44. kontra/engine/materializers/polars_connector.py +131 -0
  45. kontra/engine/materializers/postgres.py +157 -0
  46. kontra/engine/materializers/registry.py +138 -0
  47. kontra/engine/materializers/sqlserver.py +160 -0
  48. kontra/engine/result.py +15 -0
  49. kontra/engine/sql_utils.py +611 -0
  50. kontra/engine/sql_validator.py +609 -0
  51. kontra/engine/stats.py +194 -0
  52. kontra/engine/types.py +138 -0
  53. kontra/errors.py +533 -0
  54. kontra/logging.py +85 -0
  55. kontra/preplan/__init__.py +5 -0
  56. kontra/preplan/planner.py +253 -0
  57. kontra/preplan/postgres.py +179 -0
  58. kontra/preplan/sqlserver.py +191 -0
  59. kontra/preplan/types.py +24 -0
  60. kontra/probes/__init__.py +20 -0
  61. kontra/probes/compare.py +400 -0
  62. kontra/probes/relationship.py +283 -0
  63. kontra/reporters/__init__.py +0 -0
  64. kontra/reporters/json_reporter.py +190 -0
  65. kontra/reporters/rich_reporter.py +11 -0
  66. kontra/rules/__init__.py +35 -0
  67. kontra/rules/base.py +186 -0
  68. kontra/rules/builtin/__init__.py +40 -0
  69. kontra/rules/builtin/allowed_values.py +156 -0
  70. kontra/rules/builtin/compare.py +188 -0
  71. kontra/rules/builtin/conditional_not_null.py +213 -0
  72. kontra/rules/builtin/conditional_range.py +310 -0
  73. kontra/rules/builtin/contains.py +138 -0
  74. kontra/rules/builtin/custom_sql_check.py +182 -0
  75. kontra/rules/builtin/disallowed_values.py +140 -0
  76. kontra/rules/builtin/dtype.py +203 -0
  77. kontra/rules/builtin/ends_with.py +129 -0
  78. kontra/rules/builtin/freshness.py +240 -0
  79. kontra/rules/builtin/length.py +193 -0
  80. kontra/rules/builtin/max_rows.py +35 -0
  81. kontra/rules/builtin/min_rows.py +46 -0
  82. kontra/rules/builtin/not_null.py +121 -0
  83. kontra/rules/builtin/range.py +222 -0
  84. kontra/rules/builtin/regex.py +143 -0
  85. kontra/rules/builtin/starts_with.py +129 -0
  86. kontra/rules/builtin/unique.py +124 -0
  87. kontra/rules/condition_parser.py +203 -0
  88. kontra/rules/execution_plan.py +455 -0
  89. kontra/rules/factory.py +103 -0
  90. kontra/rules/predicates.py +25 -0
  91. kontra/rules/registry.py +24 -0
  92. kontra/rules/static_predicates.py +120 -0
  93. kontra/scout/__init__.py +9 -0
  94. kontra/scout/backends/__init__.py +17 -0
  95. kontra/scout/backends/base.py +111 -0
  96. kontra/scout/backends/duckdb_backend.py +359 -0
  97. kontra/scout/backends/postgres_backend.py +519 -0
  98. kontra/scout/backends/sqlserver_backend.py +577 -0
  99. kontra/scout/dtype_mapping.py +150 -0
  100. kontra/scout/patterns.py +69 -0
  101. kontra/scout/profiler.py +801 -0
  102. kontra/scout/reporters/__init__.py +39 -0
  103. kontra/scout/reporters/json_reporter.py +165 -0
  104. kontra/scout/reporters/markdown_reporter.py +152 -0
  105. kontra/scout/reporters/rich_reporter.py +144 -0
  106. kontra/scout/store.py +208 -0
  107. kontra/scout/suggest.py +200 -0
  108. kontra/scout/types.py +652 -0
  109. kontra/state/__init__.py +29 -0
  110. kontra/state/backends/__init__.py +79 -0
  111. kontra/state/backends/base.py +348 -0
  112. kontra/state/backends/local.py +480 -0
  113. kontra/state/backends/postgres.py +1010 -0
  114. kontra/state/backends/s3.py +543 -0
  115. kontra/state/backends/sqlserver.py +969 -0
  116. kontra/state/fingerprint.py +166 -0
  117. kontra/state/types.py +1061 -0
  118. kontra/version.py +1 -0
  119. kontra-0.5.2.dist-info/METADATA +122 -0
  120. kontra-0.5.2.dist-info/RECORD +124 -0
  121. kontra-0.5.2.dist-info/WHEEL +5 -0
  122. kontra-0.5.2.dist-info/entry_points.txt +2 -0
  123. kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
  124. kontra-0.5.2.dist-info/top_level.txt +1 -0
kontra/rules/builtin/range.py
@@ -0,0 +1,222 @@
+ from __future__ import annotations
+ from typing import Dict, Any, Optional, Union
+ import polars as pl
+
+ from kontra.rules.base import BaseRule
+ from kontra.rules.registry import register_rule
+ from kontra.rules.predicates import Predicate
+ from kontra.state.types import FailureMode
+
+
+ @register_rule("range")
+ class RangeRule(BaseRule):
+     """
+     Fails where `column` is outside the specified range [min, max].
+     At least one of `min` or `max` must be provided.
+
+     params:
+       - column: str (required)
+       - min: numeric (optional) - minimum allowed value (inclusive)
+       - max: numeric (optional) - maximum allowed value (inclusive)
+
+     NULLs are treated as failures (out of range).
+
+     Examples:
+       - name: range
+         params:
+           column: age
+           min: 0
+           max: 120
+
+       - name: range
+         params:
+           column: price
+           min: 0  # Only minimum, no upper bound
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         from kontra.errors import RuleParameterError
+
+         # Validate required column param
+         self._get_required_param("column", str)
+
+         min_val = self.params.get("min")
+         max_val = self.params.get("max")
+
+         # Validate at least one bound is provided
+         if min_val is None and max_val is None:
+             raise RuleParameterError(
+                 "range", "min/max",
+                 "at least one of 'min' or 'max' must be provided"
+             )
+
+         # Validate min <= max at construction time
+         if min_val is not None and max_val is not None:
+             if min_val > max_val:
+                 raise RuleParameterError(
+                     "range", "min/max",
+                     f"min ({min_val}) must be <= max ({max_val})"
+                 )
+
+     def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
+         column = self.params["column"]
+         min_val = self.params.get("min")
+         max_val = self.params.get("max")
+
+         # Check column exists before accessing
+         col_check = self._check_columns(df, {column})
+         if col_check is not None:
+             return col_check
+
+         # Note: min/max validation is done in __init__, so we know at least one is set
+         try:
+             col = df[column]
+
+             # Build condition for out-of-range values
+             if min_val is not None and max_val is not None:
+                 mask = (col < min_val) | (col > max_val)
+             elif min_val is not None:
+                 mask = col < min_val
+             else:
+                 mask = col > max_val
+
+             # NULLs are also failures
+             mask = mask.fill_null(True)
+
+             res = super()._failures(df, mask, self._build_message(column, min_val, max_val))
+             res["rule_id"] = self.rule_id
+
+             # Add failure details
+             if res["failed_count"] > 0:
+                 res["failure_mode"] = str(FailureMode.RANGE_VIOLATION)
+                 res["details"] = self._explain_failure(df, column, min_val, max_val)
+
+             return res
+         except Exception as e:
+             return {
+                 "rule_id": self.rule_id,
+                 "passed": False,
+                 "failed_count": int(df.height),
+                 "message": f"Rule execution failed: {e}",
+             }
+
+     def _explain_failure(
+         self,
+         df: pl.DataFrame,
+         column: str,
+         min_val: Optional[Union[int, float]],
+         max_val: Optional[Union[int, float]],
+     ) -> Dict[str, Any]:
+         """Generate detailed failure explanation."""
+         col = df[column]
+         details: Dict[str, Any] = {}
+
+         # Get actual min/max
+         actual_min = col.min()
+         actual_max = col.max()
+         if actual_min is not None:
+             details["actual_min"] = actual_min
+         if actual_max is not None:
+             details["actual_max"] = actual_max
+
+         # Expected bounds
+         if min_val is not None:
+             details["expected_min"] = min_val
+         if max_val is not None:
+             details["expected_max"] = max_val
+
+         # Count below min
+         if min_val is not None:
+             below_min = (col < min_val).sum()
+             if below_min > 0:
+                 details["below_min_count"] = int(below_min)
+
+         # Count above max
+         if max_val is not None:
+             above_max = (col > max_val).sum()
+             if above_max > 0:
+                 details["above_max_count"] = int(above_max)
+
+         # Count nulls
+         null_count = col.null_count()
+         if null_count > 0:
+             details["null_count"] = int(null_count)
+
+         return details
+
+     def compile_predicate(self) -> Optional[Predicate]:
+         column = self.params["column"]
+         min_val = self.params.get("min")
+         max_val = self.params.get("max")
+
+         if min_val is None and max_val is None:
+             return None
+
+         col = pl.col(column)
+
+         # Build expression for out-of-range values
+         if min_val is not None and max_val is not None:
+             expr = (col < min_val) | (col > max_val)
+         elif min_val is not None:
+             expr = col < min_val
+         else:
+             expr = col > max_val
+
+         # NULLs are also failures
+         expr = expr.fill_null(True)
+
+         return Predicate(
+             rule_id=self.rule_id,
+             expr=expr,
+             message=self._build_message(column, min_val, max_val),
+             columns={column},
+         )
+
+     def to_sql_spec(self) -> Optional[Dict[str, Any]]:
+         """Generate SQL pushdown specification."""
+         column = self.params.get("column")
+         min_val = self.params.get("min")
+         max_val = self.params.get("max")
+
+         if not column or (min_val is None and max_val is None):
+             return None
+
+         return {
+             "kind": "range",
+             "rule_id": self.rule_id,
+             "column": column,
+             "min": min_val,
+             "max": max_val,
+         }
+
+     def _build_message(
+         self, column: str, min_val: Optional[Union[int, float]], max_val: Optional[Union[int, float]]
+     ) -> str:
+         if min_val is not None and max_val is not None:
+             return f"{column} values outside range [{min_val}, {max_val}]"
+         elif min_val is not None:
+             return f"{column} values below minimum {min_val}"
+         else:
+             return f"{column} values above maximum {max_val}"
+
+     def to_sql_filter(self, dialect: str = "postgres") -> str | None:
+         column = self.params.get("column")
+         min_val = self.params.get("min")
+         max_val = self.params.get("max")
+
+         if not column or (min_val is None and max_val is None):
+             return None
+
+         col = f'"{column}"'
+         conditions = []
+
+         if min_val is not None:
+             conditions.append(f"{col} < {min_val}")
+         if max_val is not None:
+             conditions.append(f"{col} > {max_val}")
+
+         # NULL is also a failure
+         conditions.append(f"{col} IS NULL")
+
+         return " OR ".join(conditions)
kontra/rules/builtin/regex.py
@@ -0,0 +1,143 @@
+ from __future__ import annotations
+ from typing import Dict, Any, List, Optional
+ import re
+ import polars as pl
+
+ from kontra.rules.base import BaseRule
+ from kontra.rules.registry import register_rule
+ from kontra.rules.predicates import Predicate
+ from kontra.state.types import FailureMode
+ from kontra.errors import RuleParameterError
+
+
+ @register_rule("regex")
+ class RegexRule(BaseRule):
+     """
+     Fails where `column` does not match the regex `pattern`. NULLs are failures.
+
+     params:
+       - column: str (required)
+       - pattern: str (required)
+
+     Notes:
+       - Uses vectorized `str.contains` (regex by default in this Polars version).
+       - No `regex=`/`strict=` kwargs are passed to maintain compatibility.
+     """
+
+     def __init__(self, name: str, params: Dict[str, Any]):
+         super().__init__(name, params)
+         # Validate regex pattern early to provide a helpful error message
+         pattern = params.get("pattern", "")
+         try:
+             re.compile(pattern)
+         except re.error as e:
+             pos_info = f" at position {e.pos}" if e.pos is not None else ""
+             raise RuleParameterError(
+                 "regex",
+                 "pattern",
+                 f"Invalid regex pattern{pos_info}: {e.msg}\n Pattern: {pattern}"
+             )
+
+     def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
+         column = self.params["column"]
+         pattern = self.params["pattern"]
+
+         # Check column exists before accessing
+         col_check = self._check_columns(df, {column})
+         if col_check is not None:
+             return col_check
+
+         try:
+             mask = (
+                 ~df[column]
+                 .cast(pl.Utf8)
+                 .str.contains(pattern)  # regex by default
+             ).fill_null(True)
+             res = super()._failures(df, mask, f"{column} failed regex pattern {pattern}")
+             res["rule_id"] = self.rule_id
+
+             # Add failure details
+             if res["failed_count"] > 0:
+                 res["failure_mode"] = str(FailureMode.PATTERN_MISMATCH)
+                 res["details"] = self._explain_failure(df, column, pattern, mask)
+
+             return res
+         except Exception as e:
+             return {
+                 "rule_id": self.rule_id,
+                 "passed": False,
+                 "failed_count": int(df.height),
+                 "message": f"Rule execution failed: {e}",
+             }
+
+     def _explain_failure(
+         self, df: pl.DataFrame, column: str, pattern: str, mask: pl.Series
+     ) -> Dict[str, Any]:
+         """Generate detailed failure explanation."""
+         details: Dict[str, Any] = {
+             "pattern": pattern,
+         }
+
+         # Sample non-matching values (first 5)
+         failed_df = df.filter(mask)
+         if failed_df.height > 0:
+             sample_values: List[Any] = []
+             for val in failed_df[column].head(5):
+                 sample_values.append(val)
+             if sample_values:
+                 details["sample_mismatches"] = sample_values
+
+         return details
+
+     def compile_predicate(self) -> Optional[Predicate]:
+         column = self.params["column"]
+         pattern = self.params["pattern"]
+         expr = (
+             ~pl.col(column)
+             .cast(pl.Utf8)
+             .str.contains(pattern)  # regex by default
+         ).fill_null(True)
+         return Predicate(
+             rule_id=self.rule_id,
+             expr=expr,
+             message=f"{column} failed regex pattern {pattern}",
+             columns={column},
+         )
+
+     def to_sql_spec(self) -> Optional[Dict[str, Any]]:
+         """Generate SQL pushdown specification for regex rule."""
+         column = self.params.get("column")
+         pattern = self.params.get("pattern")
+
+         if not column or not pattern:
+             return None
+
+         return {
+             "kind": "regex",
+             "rule_id": self.rule_id,
+             "column": column,
+             "pattern": pattern,
+         }
+
+     def to_sql_filter(self, dialect: str = "postgres") -> str | None:
+         column = self.params.get("column")
+         pattern = self.params.get("pattern")
+
+         if not column or not pattern:
+             return None
+
+         col = f'"{column}"'
+         # Escape single quotes in pattern
+         escaped_pattern = pattern.replace("'", "''")
+
+         if dialect in ("postgres", "postgresql"):
+             # PostgreSQL uses ~ for regex match, !~ for non-match
+             return f"{col} !~ '{escaped_pattern}' OR {col} IS NULL"
+         elif dialect == "duckdb":
+             # DuckDB uses regexp_matches
+             return f"NOT regexp_matches({col}, '{escaped_pattern}') OR {col} IS NULL"
+         elif dialect == "mssql":
+             # SQL Server doesn't have native regex - skip SQL filter
+             return None
+         else:
+             return None
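
For the SQL pushdown path, the filter selects failing rows directly in the database, inverting the match and adding the NULL case. A small sketch of the strings `to_sql_filter` builds, using a hypothetical `sku` column and pattern:

# Hypothetical column and pattern, illustrating the dialect branches above
pattern = "^[A-Z]{2}-\\d+$"
escaped = pattern.replace("'", "''")  # double single quotes for the SQL literal
col = '"sku"'

print(f"{col} !~ '{escaped}' OR {col} IS NULL")
# "sku" !~ '^[A-Z]{2}-\d+$' OR "sku" IS NULL          (postgres)
print(f"NOT regexp_matches({col}, '{escaped}') OR {col} IS NULL")
# NOT regexp_matches("sku", '^[A-Z]{2}-\d+$') OR "sku" IS NULL   (duckdb)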
kontra/rules/builtin/starts_with.py
@@ -0,0 +1,129 @@
+ # src/kontra/rules/builtin/starts_with.py
+ """
+ Starts with rule - Column must start with the specified prefix.
+
+ Uses LIKE pattern matching for maximum efficiency (faster than regex).
+
+ Usage:
+   - name: starts_with
+     params:
+       column: url
+       prefix: "https://"
+
+ Fails when:
+   - Value does NOT start with the prefix
+   - Value is NULL (can't check NULL)
+ """
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, Set
+
+ import polars as pl
+
+ from kontra.rules.base import BaseRule
+ from kontra.rules.registry import register_rule
+ from kontra.rules.predicates import Predicate
+ from kontra.state.types import FailureMode
+
+
+ def _escape_like_pattern(value: str, escape_char: str = "\\") -> str:
+     """Escape LIKE special characters: %, _, and the escape char."""
+     for c in (escape_char, "%", "_"):
+         value = value.replace(c, escape_char + c)
+     return value
+
+
+ @register_rule("starts_with")
+ class StartsWithRule(BaseRule):
+     """
+     Fails where column value does NOT start with the prefix.
+
+     params:
+       - column: str (required) - Column to check
+       - prefix: str (required) - Prefix that must be present
+
+     NULL handling:
+       - NULL values are failures (can't check NULL)
+     """
+
+     def __init__(self, name: str, params: Dict[str, Any]):
+         super().__init__(name, params)
+         self._column = self._get_required_param("column", str)
+         self._prefix = self._get_required_param("prefix", str)
+
+         if not self._prefix:
+             raise ValueError("Rule 'starts_with' prefix cannot be empty")
+
+     def required_columns(self) -> Set[str]:
+         return {self._column}
+
+     def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
+         # Check column exists before accessing
+         col_check = self._check_columns(df, {self._column})
+         if col_check is not None:
+             return col_check
+
+         # Use Polars str.starts_with for efficiency
+         starts_result = df[self._column].cast(pl.Utf8).str.starts_with(self._prefix)
+
+         # Failure = does NOT start with OR is NULL
+         mask = (~starts_result).fill_null(True)
+
+         msg = f"{self._column} does not start with '{self._prefix}'"
+         res = super()._failures(df, mask, msg)
+         res["rule_id"] = self.rule_id
+
+         if res["failed_count"] > 0:
+             res["failure_mode"] = str(FailureMode.PATTERN_MISMATCH)
+             res["details"] = self._explain_failure(df, mask)
+
+         return res
+
+     def _explain_failure(self, df: pl.DataFrame, mask: pl.Series) -> Dict[str, Any]:
+         """Generate detailed failure explanation."""
+         details: Dict[str, Any] = {
+             "column": self._column,
+             "expected_prefix": self._prefix,
+         }
+
+         # Sample failing values
+         failed_df = df.filter(mask).head(5)
+         samples: List[Any] = []
+         for val in failed_df[self._column]:
+             samples.append(val)
+
+         if samples:
+             details["sample_failures"] = samples
+
+         return details
+
+     def compile_predicate(self) -> Optional[Predicate]:
+         starts_expr = pl.col(self._column).cast(pl.Utf8).str.starts_with(self._prefix)
+         expr = (~starts_expr).fill_null(True)
+
+         return Predicate(
+             rule_id=self.rule_id,
+             expr=expr,
+             message=f"{self._column} does not start with '{self._prefix}'",
+             columns={self._column},
+         )
+
+     def to_sql_spec(self) -> Optional[Dict[str, Any]]:
+         """Generate SQL pushdown specification."""
+         return {
+             "kind": "starts_with",
+             "rule_id": self.rule_id,
+             "column": self._column,
+             "prefix": self._prefix,
+         }
+
+     def to_sql_filter(self, dialect: str = "postgres") -> str | None:
+         """Generate SQL filter for sampling failing rows."""
+         col = f'"{self._column}"'
+
+         # Escape LIKE special characters, then double single quotes so the
+         # prefix is safe inside the SQL string literal
+         escaped = _escape_like_pattern(self._prefix).replace("'", "''")
+         pattern = f"{escaped}%"
+
+         # Failure = does NOT start with OR is NULL
+         return f"{col} IS NULL OR {col} NOT LIKE '{pattern}' ESCAPE '\\'"
kontra/rules/builtin/unique.py
@@ -0,0 +1,124 @@
+ from __future__ import annotations
+ from typing import Dict, Any, List, Optional
+ import polars as pl
+
+ from kontra.rules.base import BaseRule
+ from kontra.rules.registry import register_rule
+ from kontra.rules.predicates import Predicate
+ from kontra.state.types import FailureMode
+
+
+ @register_rule("unique")
+ class UniqueRule(BaseRule):
+     def __init__(self, name: str, params: Dict[str, Any]):
+         super().__init__(name, params)
+         self._get_required_param("column", str)
+
+     def validate(self, df: pl.DataFrame) -> Dict[str, Any]:
+         column = self.params["column"]
+
+         # Check column exists before accessing
+         col_check = self._check_columns(df, {column})
+         if col_check is not None:
+             return col_check
+
+         col = df[column]
+
+         # SQL semantics: COUNT(*) - COUNT(DISTINCT col)
+         # This counts "extra" rows beyond one unique occurrence.
+         # NULLs are excluded from the DISTINCT count but included in the total.
+         total_count = len(df)
+         distinct_count = col.n_unique()  # includes NULL as one value if present
+
+         # Adjust for NULL handling: SQL COUNT(DISTINCT) excludes NULLs,
+         # but n_unique() counts NULL as a distinct value
+         null_count = col.null_count()
+         if null_count > 0:
+             distinct_count -= 1  # Remove NULL from distinct count
+
+         failed_count = total_count - distinct_count - null_count
+
+         # For sampling, still identify duplicated rows (non-null)
+         non_null_mask = col.is_not_null()
+         duplicates = col.is_duplicated() & non_null_mask
+
+         # Build result manually to use SQL-semantics count
+         res = {
+             "rule_id": self.rule_id,
+             "name": self.name,
+             "passed": failed_count == 0,
+             "failed_count": failed_count,
+             "message": f"{column} has duplicate values" if failed_count > 0 else f"{column} values are unique",
+             "severity": self.params.get("severity", "blocking"),
+         }
+
+         # Add failure details and samples
+         if failed_count > 0:
+             res["failure_mode"] = str(FailureMode.DUPLICATE_VALUES)
+             res["details"] = self._explain_failure(df, column)
+             # Store mask for sampling (still shows all duplicate rows)
+             res["_failure_mask"] = duplicates
+
+         return res
+
+     def _explain_failure(self, df: pl.DataFrame, column: str) -> Dict[str, Any]:
+         """Generate detailed failure explanation."""
+         # Find duplicated values and their counts
+         duplicates_df = (
+             df.group_by(column)
+             .agg(pl.len().alias("count"))
+             .filter(pl.col("count") > 1)
+             .sort("count", descending=True)
+             .head(10)  # Top 10 duplicates
+         )
+
+         top_duplicates: List[Dict[str, Any]] = []
+         for row in duplicates_df.iter_rows(named=True):
+             val = row[column]
+             count = row["count"]
+             top_duplicates.append({
+                 "value": val,
+                 "count": count,
+             })
+
+         total_duplicates = (
+             df.group_by(column)
+             .agg(pl.len().alias("count"))
+             .filter(pl.col("count") > 1)
+             .height
+         )
+
+         return {
+             "duplicate_value_count": total_duplicates,
+             "top_duplicates": top_duplicates,
+         }
+
+     def compile_predicate(self) -> Optional[Predicate]:
+         # Return None to force fallback to validate() for counting.
+         # validate() uses SQL semantics: total - distinct (counts "extra" rows).
+         # Use sample_predicate() for identifying which rows are duplicates.
+         return None
+
+     def sample_predicate(self) -> Optional[Predicate]:
+         """Return predicate for sampling duplicate rows (not for counting)."""
+         column = self.params["column"]
+         col = pl.col(column)
+         # Identifies all rows participating in duplicates for sampling.
+         # NULLs are not considered duplicates (NULL != NULL in SQL).
+         expr = col.is_duplicated() & col.is_not_null()
+         return Predicate(
+             rule_id=self.rule_id,
+             expr=expr,
+             message=f"{column} has duplicate values",
+             columns={column},
+         )
+
+     def to_sql_filter(self, dialect: str = "postgres") -> str | None:
+         # Unique requires a subquery to find duplicated values.
+         # This is more complex but still much faster than loading 1M rows:
+         # find values that appear more than once, then select rows with those
+         # values. That requires knowing the table name, which we don't have
+         # here, so return None to fall back to Polars for this rule.
+         return None
+ return None