kontra-0.5.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kontra/__init__.py +1871 -0
- kontra/api/__init__.py +22 -0
- kontra/api/compare.py +340 -0
- kontra/api/decorators.py +153 -0
- kontra/api/results.py +2121 -0
- kontra/api/rules.py +681 -0
- kontra/cli/__init__.py +0 -0
- kontra/cli/commands/__init__.py +1 -0
- kontra/cli/commands/config.py +153 -0
- kontra/cli/commands/diff.py +450 -0
- kontra/cli/commands/history.py +196 -0
- kontra/cli/commands/profile.py +289 -0
- kontra/cli/commands/validate.py +468 -0
- kontra/cli/constants.py +6 -0
- kontra/cli/main.py +48 -0
- kontra/cli/renderers.py +304 -0
- kontra/cli/utils.py +28 -0
- kontra/config/__init__.py +34 -0
- kontra/config/loader.py +127 -0
- kontra/config/models.py +49 -0
- kontra/config/settings.py +797 -0
- kontra/connectors/__init__.py +0 -0
- kontra/connectors/db_utils.py +251 -0
- kontra/connectors/detection.py +323 -0
- kontra/connectors/handle.py +368 -0
- kontra/connectors/postgres.py +127 -0
- kontra/connectors/sqlserver.py +226 -0
- kontra/engine/__init__.py +0 -0
- kontra/engine/backends/duckdb_session.py +227 -0
- kontra/engine/backends/duckdb_utils.py +18 -0
- kontra/engine/backends/polars_backend.py +47 -0
- kontra/engine/engine.py +1205 -0
- kontra/engine/executors/__init__.py +15 -0
- kontra/engine/executors/base.py +50 -0
- kontra/engine/executors/database_base.py +528 -0
- kontra/engine/executors/duckdb_sql.py +607 -0
- kontra/engine/executors/postgres_sql.py +162 -0
- kontra/engine/executors/registry.py +69 -0
- kontra/engine/executors/sqlserver_sql.py +163 -0
- kontra/engine/materializers/__init__.py +14 -0
- kontra/engine/materializers/base.py +42 -0
- kontra/engine/materializers/duckdb.py +110 -0
- kontra/engine/materializers/factory.py +22 -0
- kontra/engine/materializers/polars_connector.py +131 -0
- kontra/engine/materializers/postgres.py +157 -0
- kontra/engine/materializers/registry.py +138 -0
- kontra/engine/materializers/sqlserver.py +160 -0
- kontra/engine/result.py +15 -0
- kontra/engine/sql_utils.py +611 -0
- kontra/engine/sql_validator.py +609 -0
- kontra/engine/stats.py +194 -0
- kontra/engine/types.py +138 -0
- kontra/errors.py +533 -0
- kontra/logging.py +85 -0
- kontra/preplan/__init__.py +5 -0
- kontra/preplan/planner.py +253 -0
- kontra/preplan/postgres.py +179 -0
- kontra/preplan/sqlserver.py +191 -0
- kontra/preplan/types.py +24 -0
- kontra/probes/__init__.py +20 -0
- kontra/probes/compare.py +400 -0
- kontra/probes/relationship.py +283 -0
- kontra/reporters/__init__.py +0 -0
- kontra/reporters/json_reporter.py +190 -0
- kontra/reporters/rich_reporter.py +11 -0
- kontra/rules/__init__.py +35 -0
- kontra/rules/base.py +186 -0
- kontra/rules/builtin/__init__.py +40 -0
- kontra/rules/builtin/allowed_values.py +156 -0
- kontra/rules/builtin/compare.py +188 -0
- kontra/rules/builtin/conditional_not_null.py +213 -0
- kontra/rules/builtin/conditional_range.py +310 -0
- kontra/rules/builtin/contains.py +138 -0
- kontra/rules/builtin/custom_sql_check.py +182 -0
- kontra/rules/builtin/disallowed_values.py +140 -0
- kontra/rules/builtin/dtype.py +203 -0
- kontra/rules/builtin/ends_with.py +129 -0
- kontra/rules/builtin/freshness.py +240 -0
- kontra/rules/builtin/length.py +193 -0
- kontra/rules/builtin/max_rows.py +35 -0
- kontra/rules/builtin/min_rows.py +46 -0
- kontra/rules/builtin/not_null.py +121 -0
- kontra/rules/builtin/range.py +222 -0
- kontra/rules/builtin/regex.py +143 -0
- kontra/rules/builtin/starts_with.py +129 -0
- kontra/rules/builtin/unique.py +124 -0
- kontra/rules/condition_parser.py +203 -0
- kontra/rules/execution_plan.py +455 -0
- kontra/rules/factory.py +103 -0
- kontra/rules/predicates.py +25 -0
- kontra/rules/registry.py +24 -0
- kontra/rules/static_predicates.py +120 -0
- kontra/scout/__init__.py +9 -0
- kontra/scout/backends/__init__.py +17 -0
- kontra/scout/backends/base.py +111 -0
- kontra/scout/backends/duckdb_backend.py +359 -0
- kontra/scout/backends/postgres_backend.py +519 -0
- kontra/scout/backends/sqlserver_backend.py +577 -0
- kontra/scout/dtype_mapping.py +150 -0
- kontra/scout/patterns.py +69 -0
- kontra/scout/profiler.py +801 -0
- kontra/scout/reporters/__init__.py +39 -0
- kontra/scout/reporters/json_reporter.py +165 -0
- kontra/scout/reporters/markdown_reporter.py +152 -0
- kontra/scout/reporters/rich_reporter.py +144 -0
- kontra/scout/store.py +208 -0
- kontra/scout/suggest.py +200 -0
- kontra/scout/types.py +652 -0
- kontra/state/__init__.py +29 -0
- kontra/state/backends/__init__.py +79 -0
- kontra/state/backends/base.py +348 -0
- kontra/state/backends/local.py +480 -0
- kontra/state/backends/postgres.py +1010 -0
- kontra/state/backends/s3.py +543 -0
- kontra/state/backends/sqlserver.py +969 -0
- kontra/state/fingerprint.py +166 -0
- kontra/state/types.py +1061 -0
- kontra/version.py +1 -0
- kontra-0.5.2.dist-info/METADATA +122 -0
- kontra-0.5.2.dist-info/RECORD +124 -0
- kontra-0.5.2.dist-info/WHEEL +5 -0
- kontra-0.5.2.dist-info/entry_points.txt +2 -0
- kontra-0.5.2.dist-info/licenses/LICENSE +17 -0
- kontra-0.5.2.dist-info/top_level.txt +1 -0
kontra/probes/compare.py
ADDED
@@ -0,0 +1,400 @@

# src/kontra/probes/compare.py
"""
Compare probe: Measure transformation effects between before/after datasets.

This probe answers: "Did my transformation preserve rows and keys as expected?"

It does NOT answer: whether the transformation is "correct".
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

import polars as pl

from kontra.api.compare import CompareResult


def compare(
    before: Union[pl.DataFrame, str],
    after: Union[pl.DataFrame, str],
    key: Union[str, List[str]],
    *,
    sample_limit: int = 5,
    save: bool = False,
) -> CompareResult:
    """
    Compare two datasets to measure transformation effects.

    Answers: "Did my transformation preserve rows and keys as expected?"

    Does NOT answer: whether the transformation is "correct".

    This probe provides deterministic, structured measurements that allow
    agents (and humans) to reason about transformation effects with confidence.

    Args:
        before: Dataset before transformation (DataFrame or path/URI)
        after: Dataset after transformation (DataFrame or path/URI)
        key: Column(s) to use as row identifier
        sample_limit: Max samples per category (default 5)
        save: Persist result to state backend (not yet implemented)

    Returns:
        CompareResult with row_stats, key_stats, change_stats,
        column_stats, and bounded samples.

    Example:
        # Basic comparison
        result = compare(raw_df, transformed_df, key="order_id")

        # With composite key
        result = compare(before, after, key=["customer_id", "date"])

        # Check for issues
        if result.duplicated_after > 0:
            print(f"Warning: {result.duplicated_after} keys are duplicated")
            print(f"Sample: {result.samples_duplicated_keys}")

        # Get structured output for LLM
        print(result.to_llm())
    """
    # Normalize key to list
    if isinstance(key, str):
        key = [key]

    # Load data if paths provided
    before_df = _load_data(before)
    after_df = _load_data(after)

    # Compute the comparison
    result = _compute_compare(before_df, after_df, key, sample_limit)

    # TODO: Implement save functionality
    if save:
        pass

    return result


def _load_data(data: Union[pl.DataFrame, str]) -> pl.DataFrame:
    """
    Load data from DataFrame or path/URI.

    For MVP, only Polars DataFrames are fully supported.
    File paths are loaded via Polars read functions.
    """
    if isinstance(data, pl.DataFrame):
        return data

    if isinstance(data, str):
        # Simple file loading for MVP
        if data.lower().endswith(".parquet"):
            return pl.read_parquet(data)
        elif data.lower().endswith(".csv"):
            return pl.read_csv(data)
        elif data.startswith("s3://"):
            return pl.read_parquet(data)
        else:
            # Try parquet first, then CSV
            try:
                return pl.read_parquet(data)
            except Exception:
                return pl.read_csv(data)

    raise ValueError(f"Unsupported data type: {type(data)}")


def _compute_compare(
    before: pl.DataFrame,
    after: pl.DataFrame,
    key: List[str],
    sample_limit: int,
) -> CompareResult:
    """
    Compute all comparison metrics between before and after datasets.

    This is the core algorithm implementing the MVP schema.
    """
    # Validate key columns exist
    for k in key:
        if k not in before.columns:
            raise ValueError(f"Key column '{k}' not found in before dataset")
        if k not in after.columns:
            raise ValueError(f"Key column '{k}' not found in after dataset")

    # ==========================================================================
    # 1. Row stats
    # ==========================================================================
    before_rows = len(before)
    after_rows = len(after)
    row_delta = after_rows - before_rows
    row_ratio = after_rows / before_rows if before_rows > 0 else float("inf")

    # ==========================================================================
    # 2. Key stats
    # ==========================================================================
    before_keys = before.select(key).unique()
    after_keys = after.select(key).unique()

    unique_before = len(before_keys)
    unique_after = len(after_keys)

    # Keys in both (preserved)
    preserved_keys = before_keys.join(after_keys, on=key, how="inner")
    preserved = len(preserved_keys)

    # Keys dropped (in before but not in after)
    dropped = unique_before - preserved

    # Keys added (in after but not in before)
    added = unique_after - preserved

    # Duplicated after: count of keys appearing >1x in after
    # (This is key count, not row count)
    after_key_counts = after.group_by(key).agg(pl.len().alias("_count"))
    duplicated_keys_df = after_key_counts.filter(pl.col("_count") > 1)
    duplicated_after = len(duplicated_keys_df)

    # ==========================================================================
    # 3. Change stats (for preserved keys only)
    # ==========================================================================
    # Join before and after on key to find matching rows
    # Use suffix to disambiguate columns
    non_key_cols_before = [c for c in before.columns if c not in key]
    non_key_cols_after = [c for c in after.columns if c not in key]
    common_non_key_cols = set(non_key_cols_before) & set(non_key_cols_after)

    unchanged_rows = 0
    changed_rows = 0

    if preserved > 0 and common_non_key_cols:
        # For each preserved key, compare values
        # Join on key, suffix the after columns
        merged = before.join(after, on=key, how="inner", suffix="_after")

        # Build a change mask: True if any common non-key column differs
        # Handle NULL comparison: NULL != value should be True
        change_exprs = []
        for col in common_non_key_cols:
            after_col = f"{col}_after"
            if after_col in merged.columns:
                # NULL-aware inequality: a column changed if its non-NULL
                # values differ, or if exactly one side is NULL
                # (NULL == NULL counts as unchanged)
                change_exprs.append(
                    (pl.col(col).ne(pl.col(after_col))) |
                    (pl.col(col).is_null() != pl.col(after_col).is_null())
                )

        if change_exprs:
            # Combine all expressions with OR
            combined_mask = change_exprs[0]
            for expr in change_exprs[1:]:
                combined_mask = combined_mask | expr

            # Count changed and unchanged
            changed_df = merged.filter(combined_mask)
            changed_rows = len(changed_df)
            unchanged_rows = len(merged) - changed_rows
        else:
            # No common columns to compare
            unchanged_rows = len(merged)
            changed_rows = 0
    elif preserved > 0:
        # No common non-key columns, so no changes possible
        unchanged_rows = preserved
        changed_rows = 0

    # ==========================================================================
    # 4. Column stats
    # ==========================================================================
    before_cols = set(before.columns)
    after_cols = set(after.columns)

    columns_added = sorted(after_cols - before_cols)
    columns_removed = sorted(before_cols - after_cols)

    # Modified columns: columns in both where at least one value differs
    # Also compute modified_fraction
    columns_modified = []
    modified_fraction: Dict[str, float] = {}

    if preserved > 0:
        merged = before.join(after, on=key, how="inner", suffix="_after")
        preserved_count = len(merged)

        for col in sorted(common_non_key_cols):
            after_col = f"{col}_after"
            if after_col in merged.columns and preserved_count > 0:
                # Count rows where this column changed
                changed_count = len(merged.filter(
                    (pl.col(col).ne(pl.col(after_col))) |
                    (pl.col(col).is_null() != pl.col(after_col).is_null())
                ))
                if changed_count > 0:
                    columns_modified.append(col)
                    modified_fraction[col] = changed_count / preserved_count

    # Nullability delta
    nullability_delta: Dict[str, Dict[str, Optional[float]]] = {}

    # For modified columns, compute before and after null rates
    for col in columns_modified:
        before_null = before[col].null_count() / before_rows if before_rows > 0 else 0.0
        after_null = after[col].null_count() / after_rows if after_rows > 0 else 0.0
        nullability_delta[col] = {"before": before_null, "after": after_null}

    # For added columns, only after rate
    for col in columns_added:
        after_null = after[col].null_count() / after_rows if after_rows > 0 else 0.0
        nullability_delta[col] = {"before": None, "after": after_null}

    # ==========================================================================
    # 5. Samples
    # ==========================================================================

    # Sample duplicated keys
    samples_duplicated_keys = _extract_key_samples(
        duplicated_keys_df.select(key),
        key,
        sample_limit
    )

    # Sample dropped keys (in before but not in after)
    dropped_keys_df = before_keys.join(after_keys, on=key, how="anti")
    samples_dropped_keys = _extract_key_samples(dropped_keys_df, key, sample_limit)

    # Sample changed rows (with before/after values)
    samples_changed_rows = _extract_changed_row_samples(
        before, after, key, common_non_key_cols, sample_limit
    )

    # ==========================================================================
    # Build result
    # ==========================================================================
    return CompareResult(
        # Meta
        before_rows=before_rows,
        after_rows=after_rows,
        key=key,
        execution_tier="polars",

        # Row stats
        row_delta=row_delta,
        row_ratio=row_ratio,

        # Key stats
        unique_before=unique_before,
        unique_after=unique_after,
        preserved=preserved,
        dropped=dropped,
        added=added,
        duplicated_after=duplicated_after,

        # Change stats
        unchanged_rows=unchanged_rows,
        changed_rows=changed_rows,

        # Column stats
        columns_added=columns_added,
        columns_removed=columns_removed,
        columns_modified=columns_modified,
        modified_fraction=modified_fraction,
        nullability_delta=nullability_delta,

        # Samples
        samples_duplicated_keys=samples_duplicated_keys,
        samples_dropped_keys=samples_dropped_keys,
        samples_changed_rows=samples_changed_rows,

        # Config
        sample_limit=sample_limit,
    )


def _extract_key_samples(
    keys_df: pl.DataFrame,
    key: List[str],
    limit: int,
) -> List[Any]:
    """
    Extract sample key values from a DataFrame.

    Returns list of key values (single value if single key, dict if composite).
    """
    if len(keys_df) == 0:
        return []

    samples = keys_df.head(limit)

    if len(key) == 1:
        # Single key - return list of values
        return samples[key[0]].to_list()
    else:
        # Composite key - return list of dicts
        return samples.to_dicts()


def _extract_changed_row_samples(
    before: pl.DataFrame,
    after: pl.DataFrame,
    key: List[str],
    common_cols: set,
    limit: int,
) -> List[Dict[str, Any]]:
    """
    Extract sample changed rows with before/after values.

    Returns list of dicts with key, before values, and after values
    for columns that changed.
    """
    if not common_cols:
        return []

    # Join on key
    merged = before.join(after, on=key, how="inner", suffix="_after")

    if len(merged) == 0:
        return []

    samples = []
    for row in merged.head(limit * 2).iter_rows(named=True):
        # Check if any column changed
        changes_before = {}
        changes_after = {}
        has_change = False

        for col in common_cols:
            after_col = f"{col}_after"
            if after_col in row:
                before_val = row[col]
                after_val = row[after_col]

                # Check for change (handle NULL)
                is_changed = (before_val != after_val) or (
                    (before_val is None) != (after_val is None)
                )

                if is_changed:
                    has_change = True
                    changes_before[col] = before_val
                    changes_after[col] = after_val

        if has_change:
            # Extract key value(s)
            if len(key) == 1:
                key_val = row[key[0]]
            else:
                key_val = {k: row[k] for k in key}

            samples.append({
                "key": key_val,
                "before": changes_before,
                "after": changes_after,
            })

        if len(samples) >= limit:
            break

    return samples
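The probe above is pure Polars, so its accounting is easy to sanity-check locally. Below is a minimal usage sketch, not part of the package diff: the toy frames and the expected counts in the comments are illustrative assumptions, and it presumes kontra 0.5.2 and polars are installed.

import polars as pl

from kontra.probes.compare import compare

# Hypothetical before/after frames for illustration only.
before = pl.DataFrame({
    "order_id": [1, 2, 3],
    "amount": [10.0, 20.0, 30.0],
})
after = pl.DataFrame({
    "order_id": [1, 2, 2, 4],            # key 3 dropped, key 4 added, key 2 duplicated
    "amount": [10.0, 25.0, 25.0, 40.0],
})

result = compare(before, after, key="order_id")

# Expected accounting per the algorithm above:
# row_delta = 1 (4 - 3); dropped = 1 (key 3); added = 1 (key 4);
# duplicated_after = 1 (key 2 appears twice in `after`).
print(result.row_delta, result.dropped, result.added, result.duplicated_after)

# The NULL-aware change mask flags key 2 (20.0 -> 25.0); the bounded
# samples carry before/after values only for the columns that differ.
print(result.samples_changed_rows)

Note the design choice visible here: preserved/dropped/added are computed over unique keys, so the duplicated key 2 counts once in the key accounting, and duplication is surfaced separately through duplicated_after.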
kontra/probes/relationship.py
ADDED

@@ -0,0 +1,283 @@

# src/kontra/probes/relationship.py
"""
Relationship probe: Measure JOIN viability between two datasets.

This probe answers: "What is the shape of this join?"

It does NOT answer: which join type to use, or whether the join is correct.
"""

from __future__ import annotations

from typing import Any, List, Union

import polars as pl

from kontra.api.compare import RelationshipProfile


def profile_relationship(
    left: Union[pl.DataFrame, str],
    right: Union[pl.DataFrame, str],
    on: Union[str, List[str]],
    *,
    sample_limit: int = 5,
    save: bool = False,
) -> RelationshipProfile:
    """
    Profile the structural relationship between two datasets.

    Answers: "What is the shape of this join?"

    Does NOT answer: which join type to use, or whether the join is correct.

    This probe provides deterministic, structured measurements that allow
    agents (and humans) to understand JOIN viability before writing SQL.

    Args:
        left: Left dataset (DataFrame or path/URI)
        right: Right dataset (DataFrame or path/URI)
        on: Column(s) to join on
        sample_limit: Max samples per category (default 5)
        save: Persist result to state backend (not yet implemented)

    Returns:
        RelationshipProfile with key_stats, cardinality, coverage,
        and bounded samples.

    Example:
        # Profile before writing JOIN
        profile = profile_relationship(orders, customers, on="customer_id")

        # Check for issues
        if profile.right_key_multiplicity_max > 1:
            print("Warning: right side has duplicates, JOIN may explode rows")
            print(f"Sample duplicated keys: {profile.samples_right_duplicates}")

        if profile.left_keys_without_match > 0:
            print(f"Warning: {profile.left_keys_without_match} keys won't match")

        # Get structured output for LLM
        print(profile.to_llm())
    """
    # Normalize on to list
    if isinstance(on, str):
        on = [on]

    # Load data if paths provided
    left_df = _load_data(left)
    right_df = _load_data(right)

    # Compute the profile
    result = _compute_relationship(left_df, right_df, on, sample_limit)

    # TODO: Implement save functionality
    if save:
        pass

    return result


def _load_data(data: Union[pl.DataFrame, str]) -> pl.DataFrame:
    """
    Load data from DataFrame or path/URI.

    For MVP, only Polars DataFrames are fully supported.
    File paths are loaded via Polars read functions.
    """
    if isinstance(data, pl.DataFrame):
        return data

    if isinstance(data, str):
        # Simple file loading for MVP
        if data.lower().endswith(".parquet"):
            return pl.read_parquet(data)
        elif data.lower().endswith(".csv"):
            return pl.read_csv(data)
        elif data.startswith("s3://"):
            return pl.read_parquet(data)
        else:
            # Try parquet first, then CSV
            try:
                return pl.read_parquet(data)
            except Exception:
                return pl.read_csv(data)

    raise ValueError(f"Unsupported data type: {type(data)}")


def _compute_relationship(
    left: pl.DataFrame,
    right: pl.DataFrame,
    on: List[str],
    sample_limit: int,
) -> RelationshipProfile:
    """
    Compute all relationship metrics between left and right datasets.

    This is the core algorithm implementing the MVP schema.
    """
    # Validate key columns exist
    for col in on:
        if col not in left.columns:
            raise ValueError(f"Key column '{col}' not found in left dataset")
        if col not in right.columns:
            raise ValueError(f"Key column '{col}' not found in right dataset")

    # ==========================================================================
    # 1. Basic counts
    # ==========================================================================
    left_rows = len(left)
    right_rows = len(right)

    # ==========================================================================
    # 2. Key stats - left
    # ==========================================================================
    # Null rate: fraction of rows with any NULL in join key columns
    if left_rows > 0:
        null_mask = pl.lit(False)
        for col in on:
            null_mask = null_mask | pl.col(col).is_null()
        left_null_count = len(left.filter(null_mask))
        left_null_rate = left_null_count / left_rows
    else:
        left_null_rate = 0.0

    # Unique keys (excluding NULLs)
    left_keys = left.select(on).drop_nulls().unique()
    left_unique_keys = len(left_keys)

    # Duplicate keys: count of keys appearing >1x
    left_key_counts = left.drop_nulls(subset=on).group_by(on).agg(pl.len().alias("_count"))
    left_duplicate_keys = len(left_key_counts.filter(pl.col("_count") > 1))

    # ==========================================================================
    # 3. Key stats - right
    # ==========================================================================
    if right_rows > 0:
        null_mask = pl.lit(False)
        for col in on:
            null_mask = null_mask | pl.col(col).is_null()
        right_null_count = len(right.filter(null_mask))
        right_null_rate = right_null_count / right_rows
    else:
        right_null_rate = 0.0

    # Unique keys (excluding NULLs)
    right_keys = right.select(on).drop_nulls().unique()
    right_unique_keys = len(right_keys)

    # Duplicate keys: count of keys appearing >1x
    right_key_counts = right.drop_nulls(subset=on).group_by(on).agg(pl.len().alias("_count"))
    right_duplicate_keys = len(right_key_counts.filter(pl.col("_count") > 1))

    # ==========================================================================
    # 4. Cardinality (rows per key)
    # ==========================================================================
    # Left multiplicity
    if len(left_key_counts) > 0:
        left_key_multiplicity_min = left_key_counts["_count"].min()
        left_key_multiplicity_max = left_key_counts["_count"].max()
    else:
        left_key_multiplicity_min = 0
        left_key_multiplicity_max = 0

    # Right multiplicity
    if len(right_key_counts) > 0:
        right_key_multiplicity_min = right_key_counts["_count"].min()
        right_key_multiplicity_max = right_key_counts["_count"].max()
    else:
        right_key_multiplicity_min = 0
        right_key_multiplicity_max = 0

    # ==========================================================================
    # 5. Coverage
    # ==========================================================================
    # Left keys with match in right
    left_matched = left_keys.join(right_keys, on=on, how="inner")
    left_keys_with_match = len(left_matched)
    left_keys_without_match = left_unique_keys - left_keys_with_match

    # Right keys with match in left
    right_matched = right_keys.join(left_keys, on=on, how="inner")
    right_keys_with_match = len(right_matched)
    right_keys_without_match = right_unique_keys - right_keys_with_match

    # ==========================================================================
    # 6. Samples
    # ==========================================================================
    # Sample left unmatched keys
    left_unmatched = left_keys.join(right_keys, on=on, how="anti")
    samples_left_unmatched = _extract_key_samples(left_unmatched, on, sample_limit)

    # Sample right unmatched keys
    right_unmatched = right_keys.join(left_keys, on=on, how="anti")
    samples_right_unmatched = _extract_key_samples(right_unmatched, on, sample_limit)

    # Sample right duplicate keys
    right_duplicates = right_key_counts.filter(pl.col("_count") > 1).select(on)
    samples_right_duplicates = _extract_key_samples(right_duplicates, on, sample_limit)

    # ==========================================================================
    # Build result
    # ==========================================================================
    return RelationshipProfile(
        # Meta
        on=on,
        left_rows=left_rows,
        right_rows=right_rows,
        execution_tier="polars",

        # Key stats - left
        left_null_rate=left_null_rate,
        left_unique_keys=left_unique_keys,
        left_duplicate_keys=left_duplicate_keys,

        # Key stats - right
        right_null_rate=right_null_rate,
        right_unique_keys=right_unique_keys,
        right_duplicate_keys=right_duplicate_keys,

        # Cardinality
        left_key_multiplicity_min=left_key_multiplicity_min,
        left_key_multiplicity_max=left_key_multiplicity_max,
        right_key_multiplicity_min=right_key_multiplicity_min,
        right_key_multiplicity_max=right_key_multiplicity_max,

        # Coverage
        left_keys_with_match=left_keys_with_match,
        left_keys_without_match=left_keys_without_match,
        right_keys_with_match=right_keys_with_match,
        right_keys_without_match=right_keys_without_match,

        # Samples
        samples_left_unmatched=samples_left_unmatched,
        samples_right_unmatched=samples_right_unmatched,
        samples_right_duplicates=samples_right_duplicates,

        # Config
        sample_limit=sample_limit,
    )


def _extract_key_samples(
    keys_df: pl.DataFrame,
    on: List[str],
    limit: int,
) -> List[Any]:
    """
    Extract sample key values from a DataFrame.

    Returns list of key values (single value if single key, dict if composite).
    """
    if len(keys_df) == 0:
        return []

    samples = keys_df.head(limit)

    if len(on) == 1:
        # Single key - return list of values
        return samples[on[0]].to_list()
    else:
        # Composite key - return list of dicts
        return samples.to_dicts()
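As with the compare probe, the relationship metrics can be exercised on toy data. A minimal usage sketch, illustrative only: the frames and the expected values in the comments are assumptions, presuming kontra 0.5.2 and polars are installed.

import polars as pl

from kontra.probes.relationship import profile_relationship

# Hypothetical frames for illustration only.
orders = pl.DataFrame({
    "customer_id": [1, 1, 2, 3, None],   # one NULL key; customer 3 has no match
    "order_id": [10, 11, 12, 13, 14],
})
customers = pl.DataFrame({
    "customer_id": [1, 2, 2],            # customer 2 is duplicated on the right
    "name": ["Ada", "Bob", "Bob"],
})

profile = profile_relationship(orders, customers, on="customer_id")

# Expected shape per the algorithm above:
# left_null_rate = 0.2 (1 of 5 left rows has a NULL join key);
# left_keys_without_match = 1 (customer 3, since NULL keys are excluded);
# right_duplicate_keys = 1 with right_key_multiplicity_max = 2, so an
# inner join on customer_id would fan out rows for customer 2.
print(profile.left_null_rate, profile.left_keys_without_match)
print(profile.right_duplicate_keys, profile.right_key_multiplicity_max)
print(profile.samples_right_duplicates)   # [2]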