duckguard 2.2.0__py3-none-any.whl → 3.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- duckguard/__init__.py +1 -1
- duckguard/anomaly/__init__.py +28 -0
- duckguard/anomaly/baselines.py +294 -0
- duckguard/anomaly/methods.py +16 -2
- duckguard/anomaly/ml_methods.py +724 -0
- duckguard/checks/__init__.py +26 -0
- duckguard/checks/conditional.py +796 -0
- duckguard/checks/distributional.py +524 -0
- duckguard/checks/multicolumn.py +726 -0
- duckguard/checks/query_based.py +643 -0
- duckguard/cli/main.py +257 -2
- duckguard/connectors/factory.py +30 -2
- duckguard/connectors/files.py +7 -3
- duckguard/core/column.py +851 -1
- duckguard/core/dataset.py +1035 -0
- duckguard/core/result.py +236 -0
- duckguard/freshness/__init__.py +33 -0
- duckguard/freshness/monitor.py +429 -0
- duckguard/history/schema.py +119 -1
- duckguard/notifications/__init__.py +20 -2
- duckguard/notifications/email.py +508 -0
- duckguard/profiler/distribution_analyzer.py +384 -0
- duckguard/profiler/outlier_detector.py +497 -0
- duckguard/profiler/pattern_matcher.py +301 -0
- duckguard/profiler/quality_scorer.py +445 -0
- duckguard/reports/html_reporter.py +1 -2
- duckguard/rules/executor.py +642 -0
- duckguard/rules/generator.py +4 -1
- duckguard/rules/schema.py +54 -0
- duckguard/schema_history/__init__.py +40 -0
- duckguard/schema_history/analyzer.py +414 -0
- duckguard/schema_history/tracker.py +288 -0
- duckguard/semantic/detector.py +17 -1
- duckguard-3.0.0.dist-info/METADATA +1072 -0
- {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/RECORD +38 -21
- duckguard-2.2.0.dist-info/METADATA +0 -351
- {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/WHEEL +0 -0
- {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/entry_points.txt +0 -0
- {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/licenses/LICENSE +0 -0
duckguard/core/dataset.py
CHANGED
|
@@ -6,9 +6,13 @@ from typing import TYPE_CHECKING, Any
|
|
|
6
6
|
|
|
7
7
|
from duckguard.core.column import Column
|
|
8
8
|
from duckguard.core.engine import DuckGuardEngine
|
|
9
|
+
from duckguard.core.result import GroupByResult, ReconciliationResult, ValidationResult
|
|
9
10
|
|
|
10
11
|
if TYPE_CHECKING:
|
|
12
|
+
from datetime import timedelta
|
|
13
|
+
|
|
11
14
|
from duckguard.core.scoring import QualityScore
|
|
15
|
+
from duckguard.freshness import FreshnessResult
|
|
12
16
|
|
|
13
17
|
|
|
14
18
|
class Dataset:
|
|
@@ -227,6 +231,129 @@ class Dataset:
|
|
|
227
231
|
"""Iterate over column names."""
|
|
228
232
|
return iter(self.columns)
|
|
229
233
|
|
|
234
|
+
@property
|
|
235
|
+
def freshness(self) -> FreshnessResult:
|
|
236
|
+
"""
|
|
237
|
+
Get freshness information for this dataset.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
FreshnessResult with freshness information including:
|
|
241
|
+
- last_modified: When data was last updated
|
|
242
|
+
- age_seconds: Age in seconds
|
|
243
|
+
- age_human: Human-readable age string
|
|
244
|
+
- is_fresh: Whether data meets default 24h threshold
|
|
245
|
+
|
|
246
|
+
Example:
|
|
247
|
+
data = connect("data.csv")
|
|
248
|
+
print(data.freshness.age_human) # "2 hours ago"
|
|
249
|
+
print(data.freshness.is_fresh) # True
|
|
250
|
+
"""
|
|
251
|
+
from duckguard.freshness import FreshnessMonitor
|
|
252
|
+
|
|
253
|
+
monitor = FreshnessMonitor()
|
|
254
|
+
return monitor.check(self)
|
|
255
|
+
|
|
256
|
+
def is_fresh(self, max_age: timedelta) -> bool:
|
|
257
|
+
"""
|
|
258
|
+
Check if data is fresher than the specified maximum age.
|
|
259
|
+
|
|
260
|
+
Args:
|
|
261
|
+
max_age: Maximum acceptable age for the data
|
|
262
|
+
|
|
263
|
+
Returns:
|
|
264
|
+
True if data is fresher than max_age
|
|
265
|
+
|
|
266
|
+
Example:
|
|
267
|
+
from datetime import timedelta
|
|
268
|
+
data = connect("data.csv")
|
|
269
|
+
|
|
270
|
+
if not data.is_fresh(timedelta(hours=6)):
|
|
271
|
+
print("Data is stale!")
|
|
272
|
+
"""
|
|
273
|
+
from duckguard.freshness import FreshnessMonitor
|
|
274
|
+
|
|
275
|
+
monitor = FreshnessMonitor(threshold=max_age)
|
|
276
|
+
result = monitor.check(self)
|
|
277
|
+
return result.is_fresh
|
|
278
|
+
|
|
279
|
+
# =========================================================================
|
|
280
|
+
# Cross-Dataset Validation Methods
|
|
281
|
+
# =========================================================================
|
|
282
|
+
|
|
283
|
+
def row_count_matches(
|
|
284
|
+
self,
|
|
285
|
+
other_dataset: Dataset,
|
|
286
|
+
tolerance: int = 0,
|
|
287
|
+
) -> ValidationResult:
|
|
288
|
+
"""
|
|
289
|
+
Check that row count matches another dataset within tolerance.
|
|
290
|
+
|
|
291
|
+
Useful for comparing backup data, validating migrations,
|
|
292
|
+
or ensuring parallel pipelines produce consistent results.
|
|
293
|
+
|
|
294
|
+
Args:
|
|
295
|
+
other_dataset: Dataset to compare against
|
|
296
|
+
tolerance: Allowed difference in row counts (default: 0 = exact match)
|
|
297
|
+
|
|
298
|
+
Returns:
|
|
299
|
+
ValidationResult indicating if row counts match
|
|
300
|
+
|
|
301
|
+
Example:
|
|
302
|
+
orders = connect("orders.parquet")
|
|
303
|
+
backup = connect("orders_backup.parquet")
|
|
304
|
+
|
|
305
|
+
# Exact match
|
|
306
|
+
result = orders.row_count_matches(backup)
|
|
307
|
+
|
|
308
|
+
# Allow small difference
|
|
309
|
+
result = orders.row_count_matches(backup, tolerance=10)
|
|
310
|
+
"""
|
|
311
|
+
source_count = self.row_count
|
|
312
|
+
other_count = other_dataset.row_count
|
|
313
|
+
diff = abs(source_count - other_count)
|
|
314
|
+
passed = diff <= tolerance
|
|
315
|
+
|
|
316
|
+
other_name = other_dataset.name or other_dataset.source
|
|
317
|
+
if tolerance == 0:
|
|
318
|
+
message = f"Row counts {'match' if passed else 'differ'}: {self._name}={source_count}, {other_name}={other_count}"
|
|
319
|
+
else:
|
|
320
|
+
message = f"Row count difference is {diff} (tolerance: {tolerance}): {self._name}={source_count}, {other_name}={other_count}"
|
|
321
|
+
|
|
322
|
+
return ValidationResult(
|
|
323
|
+
passed=passed,
|
|
324
|
+
actual_value=diff,
|
|
325
|
+
expected_value=f"<= {tolerance}",
|
|
326
|
+
message=message,
|
|
327
|
+
details={
|
|
328
|
+
"source_count": source_count,
|
|
329
|
+
"other_count": other_count,
|
|
330
|
+
"difference": diff,
|
|
331
|
+
"tolerance": tolerance,
|
|
332
|
+
"source_dataset": self._name,
|
|
333
|
+
"other_dataset": other_name,
|
|
334
|
+
},
|
|
335
|
+
)
|
|
336
|
+
|
|
337
|
+
def row_count_equals(
|
|
338
|
+
self,
|
|
339
|
+
other_dataset: Dataset,
|
|
340
|
+
) -> ValidationResult:
|
|
341
|
+
"""
|
|
342
|
+
Check that row count exactly equals another dataset.
|
|
343
|
+
|
|
344
|
+
This is a convenience alias for row_count_matches(other, tolerance=0).
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
other_dataset: Dataset to compare against
|
|
348
|
+
|
|
349
|
+
Returns:
|
|
350
|
+
ValidationResult indicating if row counts are equal
|
|
351
|
+
|
|
352
|
+
Example:
|
|
353
|
+
result = orders.row_count_equals(backup_orders)
|
|
354
|
+
"""
|
|
355
|
+
return self.row_count_matches(other_dataset, tolerance=0)
|
|
356
|
+
|
|
230
357
|
def score(
|
|
231
358
|
self,
|
|
232
359
|
weights: dict | None = None,
|
|
@@ -282,3 +409,911 @@ class Dataset:
|
|
|
282
409
|
|
|
283
410
|
scorer = QualityScorer(weights=scorer_weights)
|
|
284
411
|
return scorer.score(self)
|
|
412
|
+
|
|
413
|
+
# =========================================================================
|
|
414
|
+
# Reconciliation Methods
|
|
415
|
+
# =========================================================================
|
|
416
|
+
|
|
417
|
+
def reconcile(
|
|
418
|
+
self,
|
|
419
|
+
target_dataset: Dataset,
|
|
420
|
+
key_columns: list[str],
|
|
421
|
+
compare_columns: list[str] | None = None,
|
|
422
|
+
tolerance: float = 0.0,
|
|
423
|
+
sample_mismatches: int = 10,
|
|
424
|
+
) -> ReconciliationResult:
|
|
425
|
+
"""
|
|
426
|
+
Reconcile this dataset with a target dataset.
|
|
427
|
+
|
|
428
|
+
Performs comprehensive comparison including row matching, missing/extra
|
|
429
|
+
detection, and column-by-column value comparison. Essential for
|
|
430
|
+
migration validation and data synchronization checks.
|
|
431
|
+
|
|
432
|
+
Args:
|
|
433
|
+
target_dataset: Dataset to compare against
|
|
434
|
+
key_columns: Columns to use for matching rows (like a primary key)
|
|
435
|
+
compare_columns: Columns to compare values (default: all non-key columns)
|
|
436
|
+
tolerance: Numeric tolerance for value comparison (default: 0 = exact match)
|
|
437
|
+
sample_mismatches: Number of sample mismatches to capture (default: 10)
|
|
438
|
+
|
|
439
|
+
Returns:
|
|
440
|
+
ReconciliationResult with detailed comparison metrics
|
|
441
|
+
|
|
442
|
+
Example:
|
|
443
|
+
source = connect("orders_source.parquet")
|
|
444
|
+
target = connect("orders_target.parquet")
|
|
445
|
+
|
|
446
|
+
result = source.reconcile(
|
|
447
|
+
target,
|
|
448
|
+
key_columns=["order_id"],
|
|
449
|
+
compare_columns=["amount", "status", "customer_id"]
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
if not result.passed:
|
|
453
|
+
print(f"Missing in target: {result.missing_in_target}")
|
|
454
|
+
print(f"Extra in target: {result.extra_in_target}")
|
|
455
|
+
print(result.summary())
|
|
456
|
+
"""
|
|
457
|
+
from duckguard.core.result import ReconciliationMismatch, ReconciliationResult
|
|
458
|
+
|
|
459
|
+
source_ref = self._engine.get_source_reference(self._source)
|
|
460
|
+
target_ref = target_dataset.engine.get_source_reference(target_dataset.source)
|
|
461
|
+
target_name = target_dataset.name or target_dataset.source
|
|
462
|
+
|
|
463
|
+
# Determine columns to compare
|
|
464
|
+
if compare_columns is None:
|
|
465
|
+
compare_columns = [c for c in self.columns if c not in key_columns]
|
|
466
|
+
|
|
467
|
+
# Build key column references
|
|
468
|
+
key_join_condition = " AND ".join(
|
|
469
|
+
f's."{k}" = t."{k}"' for k in key_columns
|
|
470
|
+
)
|
|
471
|
+
|
|
472
|
+
# Count rows in source not in target (missing)
|
|
473
|
+
sql_missing = f"""
|
|
474
|
+
SELECT COUNT(*) FROM {source_ref} s
|
|
475
|
+
WHERE NOT EXISTS (
|
|
476
|
+
SELECT 1 FROM {target_ref} t
|
|
477
|
+
WHERE {key_join_condition}
|
|
478
|
+
)
|
|
479
|
+
"""
|
|
480
|
+
missing_count = self._engine.fetch_value(sql_missing) or 0
|
|
481
|
+
|
|
482
|
+
# Count rows in target not in source (extra)
|
|
483
|
+
sql_extra = f"""
|
|
484
|
+
SELECT COUNT(*) FROM {target_ref} t
|
|
485
|
+
WHERE NOT EXISTS (
|
|
486
|
+
SELECT 1 FROM {source_ref} s
|
|
487
|
+
WHERE {key_join_condition}
|
|
488
|
+
)
|
|
489
|
+
"""
|
|
490
|
+
extra_count = self._engine.fetch_value(sql_extra) or 0
|
|
491
|
+
|
|
492
|
+
# Count value mismatches per column for matching rows
|
|
493
|
+
value_mismatches: dict[str, int] = {}
|
|
494
|
+
for col in compare_columns:
|
|
495
|
+
if tolerance > 0:
|
|
496
|
+
# Numeric tolerance comparison
|
|
497
|
+
sql_mismatch = f"""
|
|
498
|
+
SELECT COUNT(*) FROM {source_ref} s
|
|
499
|
+
INNER JOIN {target_ref} t ON {key_join_condition}
|
|
500
|
+
WHERE ABS(COALESCE(CAST(s."{col}" AS DOUBLE), 0) -
|
|
501
|
+
COALESCE(CAST(t."{col}" AS DOUBLE), 0)) > {tolerance}
|
|
502
|
+
OR (s."{col}" IS NULL AND t."{col}" IS NOT NULL)
|
|
503
|
+
OR (s."{col}" IS NOT NULL AND t."{col}" IS NULL)
|
|
504
|
+
"""
|
|
505
|
+
else:
|
|
506
|
+
# Exact match comparison
|
|
507
|
+
sql_mismatch = f"""
|
|
508
|
+
SELECT COUNT(*) FROM {source_ref} s
|
|
509
|
+
INNER JOIN {target_ref} t ON {key_join_condition}
|
|
510
|
+
WHERE s."{col}" IS DISTINCT FROM t."{col}"
|
|
511
|
+
"""
|
|
512
|
+
mismatch_count = self._engine.fetch_value(sql_mismatch) or 0
|
|
513
|
+
if mismatch_count > 0:
|
|
514
|
+
value_mismatches[col] = mismatch_count
|
|
515
|
+
|
|
516
|
+
# Calculate match percentage
|
|
517
|
+
source_count = self.row_count
|
|
518
|
+
target_count = target_dataset.row_count
|
|
519
|
+
matched_rows = source_count - missing_count
|
|
520
|
+
total_comparisons = matched_rows * len(compare_columns) if compare_columns else matched_rows
|
|
521
|
+
total_mismatches = sum(value_mismatches.values())
|
|
522
|
+
|
|
523
|
+
if total_comparisons > 0:
|
|
524
|
+
match_percentage = ((total_comparisons - total_mismatches) / total_comparisons) * 100
|
|
525
|
+
else:
|
|
526
|
+
match_percentage = 100.0 if missing_count == 0 and extra_count == 0 else 0.0
|
|
527
|
+
|
|
528
|
+
# Capture sample mismatches
|
|
529
|
+
mismatches: list[ReconciliationMismatch] = []
|
|
530
|
+
if sample_mismatches > 0 and (missing_count > 0 or extra_count > 0 or value_mismatches):
|
|
531
|
+
mismatches = self._get_sample_mismatches(
|
|
532
|
+
target_dataset,
|
|
533
|
+
key_columns,
|
|
534
|
+
compare_columns,
|
|
535
|
+
tolerance,
|
|
536
|
+
sample_mismatches,
|
|
537
|
+
)
|
|
538
|
+
|
|
539
|
+
# Determine if passed (no mismatches at all)
|
|
540
|
+
passed = missing_count == 0 and extra_count == 0 and len(value_mismatches) == 0
|
|
541
|
+
|
|
542
|
+
return ReconciliationResult(
|
|
543
|
+
passed=passed,
|
|
544
|
+
source_row_count=source_count,
|
|
545
|
+
target_row_count=target_count,
|
|
546
|
+
missing_in_target=missing_count,
|
|
547
|
+
extra_in_target=extra_count,
|
|
548
|
+
value_mismatches=value_mismatches,
|
|
549
|
+
match_percentage=match_percentage,
|
|
550
|
+
key_columns=key_columns,
|
|
551
|
+
compared_columns=compare_columns,
|
|
552
|
+
mismatches=mismatches,
|
|
553
|
+
details={
|
|
554
|
+
"source_dataset": self._name,
|
|
555
|
+
"target_dataset": target_name,
|
|
556
|
+
"tolerance": tolerance,
|
|
557
|
+
},
|
|
558
|
+
)
|
|
559
|
+
|
|
560
|
+
def _get_sample_mismatches(
|
|
561
|
+
self,
|
|
562
|
+
target_dataset: Dataset,
|
|
563
|
+
key_columns: list[str],
|
|
564
|
+
compare_columns: list[str],
|
|
565
|
+
tolerance: float,
|
|
566
|
+
limit: int,
|
|
567
|
+
) -> list:
|
|
568
|
+
"""Get sample of reconciliation mismatches."""
|
|
569
|
+
from duckguard.core.result import ReconciliationMismatch
|
|
570
|
+
|
|
571
|
+
source_ref = self._engine.get_source_reference(self._source)
|
|
572
|
+
target_ref = target_dataset.engine.get_source_reference(target_dataset.source)
|
|
573
|
+
|
|
574
|
+
key_cols_sql = ", ".join(f's."{k}"' for k in key_columns)
|
|
575
|
+
key_join_condition = " AND ".join(
|
|
576
|
+
f's."{k}" = t."{k}"' for k in key_columns
|
|
577
|
+
)
|
|
578
|
+
|
|
579
|
+
mismatches: list[ReconciliationMismatch] = []
|
|
580
|
+
|
|
581
|
+
# Sample missing in target
|
|
582
|
+
sql_missing = f"""
|
|
583
|
+
SELECT {key_cols_sql} FROM {source_ref} s
|
|
584
|
+
WHERE NOT EXISTS (
|
|
585
|
+
SELECT 1 FROM {target_ref} t
|
|
586
|
+
WHERE {key_join_condition}
|
|
587
|
+
)
|
|
588
|
+
LIMIT {limit}
|
|
589
|
+
"""
|
|
590
|
+
missing_rows = self._engine.fetch_all(sql_missing)
|
|
591
|
+
for row in missing_rows:
|
|
592
|
+
key_values = dict(zip(key_columns, row))
|
|
593
|
+
mismatches.append(ReconciliationMismatch(
|
|
594
|
+
key_values=key_values,
|
|
595
|
+
column="(row)",
|
|
596
|
+
source_value="exists",
|
|
597
|
+
target_value="missing",
|
|
598
|
+
mismatch_type="missing_in_target",
|
|
599
|
+
))
|
|
600
|
+
|
|
601
|
+
# Sample value mismatches
|
|
602
|
+
remaining = limit - len(mismatches)
|
|
603
|
+
if remaining > 0 and compare_columns:
|
|
604
|
+
for col in compare_columns[:3]: # Limit columns for sampling
|
|
605
|
+
if tolerance > 0:
|
|
606
|
+
sql_diff = f"""
|
|
607
|
+
SELECT {key_cols_sql}, s."{col}" as source_val, t."{col}" as target_val
|
|
608
|
+
FROM {source_ref} s
|
|
609
|
+
INNER JOIN {target_ref} t ON {key_join_condition}
|
|
610
|
+
WHERE ABS(COALESCE(CAST(s."{col}" AS DOUBLE), 0) -
|
|
611
|
+
COALESCE(CAST(t."{col}" AS DOUBLE), 0)) > {tolerance}
|
|
612
|
+
OR (s."{col}" IS NULL AND t."{col}" IS NOT NULL)
|
|
613
|
+
OR (s."{col}" IS NOT NULL AND t."{col}" IS NULL)
|
|
614
|
+
LIMIT {remaining}
|
|
615
|
+
"""
|
|
616
|
+
else:
|
|
617
|
+
sql_diff = f"""
|
|
618
|
+
SELECT {key_cols_sql}, s."{col}" as source_val, t."{col}" as target_val
|
|
619
|
+
FROM {source_ref} s
|
|
620
|
+
INNER JOIN {target_ref} t ON {key_join_condition}
|
|
621
|
+
WHERE s."{col}" IS DISTINCT FROM t."{col}"
|
|
622
|
+
LIMIT {remaining}
|
|
623
|
+
"""
|
|
624
|
+
diff_rows = self._engine.fetch_all(sql_diff)
|
|
625
|
+
for row in diff_rows:
|
|
626
|
+
key_values = dict(zip(key_columns, row[:len(key_columns)]))
|
|
627
|
+
mismatches.append(ReconciliationMismatch(
|
|
628
|
+
key_values=key_values,
|
|
629
|
+
column=col,
|
|
630
|
+
source_value=row[len(key_columns)],
|
|
631
|
+
target_value=row[len(key_columns) + 1],
|
|
632
|
+
mismatch_type="value_diff",
|
|
633
|
+
))
|
|
634
|
+
|
|
635
|
+
return mismatches[:limit]
|
|
636
|
+
|
|
637
|
+
# =========================================================================
|
|
638
|
+
# Group By Methods
|
|
639
|
+
# =========================================================================
|
|
640
|
+
|
|
641
|
+
def group_by(self, columns: list[str] | str) -> GroupedDataset:
|
|
642
|
+
"""
|
|
643
|
+
Group the dataset by one or more columns for segmented validation.
|
|
644
|
+
|
|
645
|
+
Returns a GroupedDataset that allows running validation checks
|
|
646
|
+
on each group separately. Useful for partition-level data quality
|
|
647
|
+
checks and segmented analysis.
|
|
648
|
+
|
|
649
|
+
Args:
|
|
650
|
+
columns: Column name(s) to group by
|
|
651
|
+
|
|
652
|
+
Returns:
|
|
653
|
+
GroupedDataset for running group-level validations
|
|
654
|
+
|
|
655
|
+
Example:
|
|
656
|
+
# Validate each region has data
|
|
657
|
+
result = orders.group_by("region").row_count_greater_than(0)
|
|
658
|
+
|
|
659
|
+
# Validate per-date quality
|
|
660
|
+
result = orders.group_by(["date", "region"]).validate(
|
|
661
|
+
lambda g: g["amount"].null_percent < 5
|
|
662
|
+
)
|
|
663
|
+
|
|
664
|
+
# Get group statistics
|
|
665
|
+
stats = orders.group_by("status").stats()
|
|
666
|
+
"""
|
|
667
|
+
if isinstance(columns, str):
|
|
668
|
+
columns = [columns]
|
|
669
|
+
|
|
670
|
+
# Validate columns exist
|
|
671
|
+
for col in columns:
|
|
672
|
+
if col not in self.columns:
|
|
673
|
+
raise KeyError(
|
|
674
|
+
f"Column '{col}' not found. Available columns: {', '.join(self.columns)}"
|
|
675
|
+
)
|
|
676
|
+
|
|
677
|
+
return GroupedDataset(self, columns)
|
|
678
|
+
|
|
679
|
+
# =================================================================
|
|
680
|
+
# Multi-Column Validation Methods (DuckGuard 3.0)
|
|
681
|
+
# =================================================================
|
|
682
|
+
|
|
683
|
+
def expect_column_pair_satisfy(
|
|
684
|
+
self,
|
|
685
|
+
column_a: str,
|
|
686
|
+
column_b: str,
|
|
687
|
+
expression: str,
|
|
688
|
+
threshold: float = 1.0
|
|
689
|
+
) -> ValidationResult:
|
|
690
|
+
"""Check that column pair satisfies expression.
|
|
691
|
+
|
|
692
|
+
Args:
|
|
693
|
+
column_a: First column name
|
|
694
|
+
column_b: Second column name
|
|
695
|
+
expression: Expression to evaluate (e.g., "A > B", "A + B = 100")
|
|
696
|
+
threshold: Maximum allowed failure rate (0.0-1.0)
|
|
697
|
+
|
|
698
|
+
Returns:
|
|
699
|
+
ValidationResult with pass/fail status
|
|
700
|
+
|
|
701
|
+
Examples:
|
|
702
|
+
>>> data = connect("orders.csv")
|
|
703
|
+
>>> # Date range validation
|
|
704
|
+
>>> result = data.expect_column_pair_satisfy(
|
|
705
|
+
... column_a="end_date",
|
|
706
|
+
... column_b="start_date",
|
|
707
|
+
... expression="end_date >= start_date"
|
|
708
|
+
... )
|
|
709
|
+
>>> assert result.passed
|
|
710
|
+
|
|
711
|
+
>>> # Arithmetic validation
|
|
712
|
+
>>> result = data.expect_column_pair_satisfy(
|
|
713
|
+
... column_a="total",
|
|
714
|
+
... column_b="subtotal",
|
|
715
|
+
... expression="total = subtotal * 1.1"
|
|
716
|
+
... )
|
|
717
|
+
"""
|
|
718
|
+
from duckguard.checks.multicolumn import MultiColumnCheckHandler
|
|
719
|
+
|
|
720
|
+
handler = MultiColumnCheckHandler()
|
|
721
|
+
return handler.execute_column_pair_satisfy(
|
|
722
|
+
dataset=self,
|
|
723
|
+
column_a=column_a,
|
|
724
|
+
column_b=column_b,
|
|
725
|
+
expression=expression,
|
|
726
|
+
threshold=threshold
|
|
727
|
+
)
|
|
728
|
+
|
|
729
|
+
def expect_columns_unique(
|
|
730
|
+
self,
|
|
731
|
+
columns: list[str],
|
|
732
|
+
threshold: float = 1.0
|
|
733
|
+
) -> ValidationResult:
|
|
734
|
+
"""Check that combination of columns is unique (composite key).
|
|
735
|
+
|
|
736
|
+
Args:
|
|
737
|
+
columns: List of column names forming composite key
|
|
738
|
+
threshold: Minimum required uniqueness rate (0.0-1.0)
|
|
739
|
+
|
|
740
|
+
Returns:
|
|
741
|
+
ValidationResult with pass/fail status
|
|
742
|
+
|
|
743
|
+
Examples:
|
|
744
|
+
>>> # Two-column composite key
|
|
745
|
+
>>> result = data.expect_columns_unique(
|
|
746
|
+
... columns=["user_id", "session_id"]
|
|
747
|
+
... )
|
|
748
|
+
>>> assert result.passed
|
|
749
|
+
|
|
750
|
+
>>> # Three-column composite key
|
|
751
|
+
>>> result = data.expect_columns_unique(
|
|
752
|
+
... columns=["year", "month", "product_id"]
|
|
753
|
+
... )
|
|
754
|
+
"""
|
|
755
|
+
from duckguard.checks.multicolumn import MultiColumnCheckHandler
|
|
756
|
+
|
|
757
|
+
handler = MultiColumnCheckHandler()
|
|
758
|
+
return handler.execute_columns_unique(
|
|
759
|
+
dataset=self,
|
|
760
|
+
columns=columns,
|
|
761
|
+
threshold=threshold
|
|
762
|
+
)
|
|
763
|
+
|
|
764
|
+
def expect_multicolumn_sum_to_equal(
|
|
765
|
+
self,
|
|
766
|
+
columns: list[str],
|
|
767
|
+
expected_sum: float,
|
|
768
|
+
threshold: float = 0.01
|
|
769
|
+
) -> ValidationResult:
|
|
770
|
+
"""Check that sum of columns equals expected value.
|
|
771
|
+
|
|
772
|
+
Args:
|
|
773
|
+
columns: List of columns to sum
|
|
774
|
+
expected_sum: Expected sum value
|
|
775
|
+
threshold: Maximum allowed deviation
|
|
776
|
+
|
|
777
|
+
Returns:
|
|
778
|
+
ValidationResult with pass/fail status
|
|
779
|
+
|
|
780
|
+
Examples:
|
|
781
|
+
>>> # Components must sum to 100%
|
|
782
|
+
>>> result = data.expect_multicolumn_sum_to_equal(
|
|
783
|
+
... columns=["q1_pct", "q2_pct", "q3_pct", "q4_pct"],
|
|
784
|
+
... expected_sum=100.0
|
|
785
|
+
... )
|
|
786
|
+
>>> assert result.passed
|
|
787
|
+
|
|
788
|
+
>>> # Budget allocation check
|
|
789
|
+
>>> result = data.expect_multicolumn_sum_to_equal(
|
|
790
|
+
... columns=["marketing", "sales", "r_and_d"],
|
|
791
|
+
... expected_sum=data.total_budget
|
|
792
|
+
... )
|
|
793
|
+
"""
|
|
794
|
+
from duckguard.checks.multicolumn import MultiColumnCheckHandler
|
|
795
|
+
|
|
796
|
+
handler = MultiColumnCheckHandler()
|
|
797
|
+
return handler.execute_multicolumn_sum_equal(
|
|
798
|
+
dataset=self,
|
|
799
|
+
columns=columns,
|
|
800
|
+
expected_sum=expected_sum,
|
|
801
|
+
threshold=threshold
|
|
802
|
+
)
|
|
803
|
+
|
|
804
|
+
# =================================================================
|
|
805
|
+
# Query-Based Validation Methods (DuckGuard 3.0)
|
|
806
|
+
# =================================================================
|
|
807
|
+
|
|
808
|
+
def expect_query_to_return_no_rows(
|
|
809
|
+
self,
|
|
810
|
+
query: str,
|
|
811
|
+
message: str | None = None
|
|
812
|
+
) -> ValidationResult:
|
|
813
|
+
"""Check that custom SQL query returns no rows (finds no violations).
|
|
814
|
+
|
|
815
|
+
Use case: Write a query that finds violations. The check passes if
|
|
816
|
+
the query returns no rows (no violations found).
|
|
817
|
+
|
|
818
|
+
Args:
|
|
819
|
+
query: SQL SELECT query (use 'table' to reference the dataset)
|
|
820
|
+
message: Optional custom message
|
|
821
|
+
|
|
822
|
+
Returns:
|
|
823
|
+
ValidationResult (passed if query returns 0 rows)
|
|
824
|
+
|
|
825
|
+
Examples:
|
|
826
|
+
>>> data = connect("orders.csv")
|
|
827
|
+
>>> # Find invalid totals (total < subtotal)
|
|
828
|
+
>>> result = data.expect_query_to_return_no_rows(
|
|
829
|
+
... query="SELECT * FROM table WHERE total < subtotal"
|
|
830
|
+
... )
|
|
831
|
+
>>> assert result.passed
|
|
832
|
+
|
|
833
|
+
>>> # Find future dates
|
|
834
|
+
>>> result = data.expect_query_to_return_no_rows(
|
|
835
|
+
... query="SELECT * FROM table WHERE order_date > CURRENT_DATE"
|
|
836
|
+
... )
|
|
837
|
+
|
|
838
|
+
Security:
|
|
839
|
+
- Query is validated to prevent SQL injection
|
|
840
|
+
- Only SELECT queries allowed
|
|
841
|
+
- READ-ONLY mode enforced
|
|
842
|
+
- 30-second timeout
|
|
843
|
+
- 10,000 row result limit
|
|
844
|
+
"""
|
|
845
|
+
from duckguard.checks.query_based import QueryCheckHandler
|
|
846
|
+
|
|
847
|
+
handler = QueryCheckHandler()
|
|
848
|
+
return handler.execute_query_no_rows(
|
|
849
|
+
dataset=self,
|
|
850
|
+
query=query,
|
|
851
|
+
message=message
|
|
852
|
+
)
|
|
853
|
+
|
|
854
|
+
def expect_query_to_return_rows(
|
|
855
|
+
self,
|
|
856
|
+
query: str,
|
|
857
|
+
message: str | None = None
|
|
858
|
+
) -> ValidationResult:
|
|
859
|
+
"""Check that custom SQL query returns at least one row.
|
|
860
|
+
|
|
861
|
+
Use case: Ensure expected data exists in the dataset.
|
|
862
|
+
|
|
863
|
+
Args:
|
|
864
|
+
query: SQL SELECT query (use 'table' to reference the dataset)
|
|
865
|
+
message: Optional custom message
|
|
866
|
+
|
|
867
|
+
Returns:
|
|
868
|
+
ValidationResult (passed if query returns > 0 rows)
|
|
869
|
+
|
|
870
|
+
Examples:
|
|
871
|
+
>>> data = connect("products.csv")
|
|
872
|
+
>>> # Ensure we have active products
|
|
873
|
+
>>> result = data.expect_query_to_return_rows(
|
|
874
|
+
... query="SELECT * FROM table WHERE status = 'active'"
|
|
875
|
+
... )
|
|
876
|
+
>>> assert result.passed
|
|
877
|
+
|
|
878
|
+
>>> # Ensure we have recent data
|
|
879
|
+
>>> result = data.expect_query_to_return_rows(
|
|
880
|
+
... query="SELECT * FROM table WHERE created_at >= CURRENT_DATE - 7"
|
|
881
|
+
... )
|
|
882
|
+
|
|
883
|
+
Security:
|
|
884
|
+
- Query is validated to prevent SQL injection
|
|
885
|
+
- Only SELECT queries allowed
|
|
886
|
+
- READ-ONLY mode enforced
|
|
887
|
+
"""
|
|
888
|
+
from duckguard.checks.query_based import QueryCheckHandler
|
|
889
|
+
|
|
890
|
+
handler = QueryCheckHandler()
|
|
891
|
+
return handler.execute_query_returns_rows(
|
|
892
|
+
dataset=self,
|
|
893
|
+
query=query,
|
|
894
|
+
message=message
|
|
895
|
+
)
|
|
896
|
+
|
|
897
|
+
def expect_query_result_to_equal(
|
|
898
|
+
self,
|
|
899
|
+
query: str,
|
|
900
|
+
expected: Any,
|
|
901
|
+
tolerance: float | None = None,
|
|
902
|
+
message: str | None = None
|
|
903
|
+
) -> ValidationResult:
|
|
904
|
+
"""Check that custom SQL query returns a specific value.
|
|
905
|
+
|
|
906
|
+
Use case: Aggregate validation (COUNT, SUM, AVG, etc.)
|
|
907
|
+
|
|
908
|
+
Args:
|
|
909
|
+
query: SQL query returning single value (use 'table' to reference dataset)
|
|
910
|
+
expected: Expected value
|
|
911
|
+
tolerance: Optional tolerance for numeric comparisons
|
|
912
|
+
message: Optional custom message
|
|
913
|
+
|
|
914
|
+
Returns:
|
|
915
|
+
ValidationResult (passed if query result equals expected)
|
|
916
|
+
|
|
917
|
+
Examples:
|
|
918
|
+
>>> data = connect("orders.csv")
|
|
919
|
+
>>> # Check pending order count
|
|
920
|
+
>>> result = data.expect_query_result_to_equal(
|
|
921
|
+
... query="SELECT COUNT(*) FROM table WHERE status = 'pending'",
|
|
922
|
+
... expected=0
|
|
923
|
+
... )
|
|
924
|
+
>>> assert result.passed
|
|
925
|
+
|
|
926
|
+
>>> # Check average with tolerance
|
|
927
|
+
>>> result = data.expect_query_result_to_equal(
|
|
928
|
+
... query="SELECT AVG(price) FROM table",
|
|
929
|
+
... expected=100.0,
|
|
930
|
+
... tolerance=5.0
|
|
931
|
+
... )
|
|
932
|
+
|
|
933
|
+
>>> # Check sum constraint
|
|
934
|
+
>>> result = data.expect_query_result_to_equal(
|
|
935
|
+
... query="SELECT SUM(quantity) FROM table WHERE category = 'electronics'",
|
|
936
|
+
... expected=1000
|
|
937
|
+
... )
|
|
938
|
+
|
|
939
|
+
Security:
|
|
940
|
+
- Query must return exactly 1 row with 1 column
|
|
941
|
+
- Query is validated to prevent SQL injection
|
|
942
|
+
"""
|
|
943
|
+
from duckguard.checks.query_based import QueryCheckHandler
|
|
944
|
+
|
|
945
|
+
handler = QueryCheckHandler()
|
|
946
|
+
return handler.execute_query_result_equals(
|
|
947
|
+
dataset=self,
|
|
948
|
+
query=query,
|
|
949
|
+
expected=expected,
|
|
950
|
+
tolerance=tolerance,
|
|
951
|
+
message=message
|
|
952
|
+
)
|
|
953
|
+
|
|
954
|
+
def expect_query_result_to_be_between(
|
|
955
|
+
self,
|
|
956
|
+
query: str,
|
|
957
|
+
min_value: float,
|
|
958
|
+
max_value: float,
|
|
959
|
+
message: str | None = None
|
|
960
|
+
) -> ValidationResult:
|
|
961
|
+
"""Check that custom SQL query result is within a range.
|
|
962
|
+
|
|
963
|
+
Use case: Metric validation (e.g., average must be between X and Y)
|
|
964
|
+
|
|
965
|
+
Args:
|
|
966
|
+
query: SQL query returning single numeric value
|
|
967
|
+
min_value: Minimum allowed value (inclusive)
|
|
968
|
+
max_value: Maximum allowed value (inclusive)
|
|
969
|
+
message: Optional custom message
|
|
970
|
+
|
|
971
|
+
Returns:
|
|
972
|
+
ValidationResult (passed if min_value <= result <= max_value)
|
|
973
|
+
|
|
974
|
+
Examples:
|
|
975
|
+
>>> data = connect("metrics.csv")
|
|
976
|
+
>>> # Average price in range
|
|
977
|
+
>>> result = data.expect_query_result_to_be_between(
|
|
978
|
+
... query="SELECT AVG(price) FROM table",
|
|
979
|
+
... min_value=10.0,
|
|
980
|
+
... max_value=1000.0
|
|
981
|
+
... )
|
|
982
|
+
>>> assert result.passed
|
|
983
|
+
|
|
984
|
+
>>> # Null rate validation
|
|
985
|
+
>>> result = data.expect_query_result_to_be_between(
|
|
986
|
+
... query='''
|
|
987
|
+
... SELECT (COUNT(*) FILTER (WHERE price IS NULL)) * 100.0 / COUNT(*)
|
|
988
|
+
... FROM table
|
|
989
|
+
... ''',
|
|
990
|
+
... min_value=0.0,
|
|
991
|
+
... max_value=5.0 # Max 5% nulls
|
|
992
|
+
... )
|
|
993
|
+
|
|
994
|
+
Security:
|
|
995
|
+
- Query must return exactly 1 row with 1 numeric column
|
|
996
|
+
- Query is validated to prevent SQL injection
|
|
997
|
+
"""
|
|
998
|
+
from duckguard.checks.query_based import QueryCheckHandler
|
|
999
|
+
|
|
1000
|
+
handler = QueryCheckHandler()
|
|
1001
|
+
return handler.execute_query_result_between(
|
|
1002
|
+
dataset=self,
|
|
1003
|
+
query=query,
|
|
1004
|
+
min_value=min_value,
|
|
1005
|
+
max_value=max_value,
|
|
1006
|
+
message=message
|
|
1007
|
+
)
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
class GroupedDataset:
    """
    Represents a dataset grouped by one or more columns.

    Provides methods for running validation checks and getting statistics
    at the group level. Created via Dataset.group_by().

    Example:
        grouped = orders.group_by("region")
        result = grouped.row_count_greater_than(100)
        stats = grouped.stats()
    """

    def __init__(self, dataset: Dataset, group_columns: list[str]):
        """Initialize a grouped dataset.

        Args:
            dataset: The source dataset
            group_columns: Columns to group by
        """
        self._dataset = dataset
        self._group_columns = group_columns

    def _quoted_group_cols(self) -> str:
        """Return the group-by columns as a comma-separated quoted SQL list.

        Embedded double quotes in a column name are doubled so the name
        cannot break out of its quoted identifier.
        """
        return ", ".join('"{}"'.format(c.replace('"', '""')) for c in self._group_columns)

    def _group_filter(self, group_key: dict[str, Any]) -> str:
        """Build a SQL WHERE condition selecting exactly one group.

        Args:
            group_key: Mapping of group column name -> key value

        Returns:
            AND-joined condition string, NULL-safe for NULL key values.
        """
        parts = []
        for key, value in group_key.items():
            col = '"{}"'.format(key.replace('"', '""'))
            if value is None:
                # "col = NULL" never matches in SQL; NULL keys need IS NULL.
                parts.append(f"{col} IS NULL")
            else:
                parts.append(f"{col} = {self._format_value(value)}")
        return " AND ".join(parts)

    @property
    def groups(self) -> list[dict[str, Any]]:
        """Get all distinct group key combinations."""
        ref = self._dataset.engine.get_source_reference(self._dataset.source)
        cols_sql = self._quoted_group_cols()

        sql = f"""
            SELECT DISTINCT {cols_sql}
            FROM {ref}
            ORDER BY {cols_sql}
        """

        rows = self._dataset.engine.fetch_all(sql)
        return [dict(zip(self._group_columns, row)) for row in rows]

    @property
    def group_count(self) -> int:
        """Get the number of distinct groups."""
        ref = self._dataset.engine.get_source_reference(self._dataset.source)
        cols_sql = self._quoted_group_cols()

        sql = f"""
            SELECT COUNT(DISTINCT ({cols_sql}))
            FROM {ref}
        """

        return self._dataset.engine.fetch_value(sql) or 0

    def stats(self) -> list[dict[str, Any]]:
        """
        Get statistics for each group.

        Returns a list of dictionaries with the group key values and the
        group's row count, ordered by row count descending.

        Returns:
            List of group statistics dictionaries

        Example:
            stats = orders.group_by("status").stats()
            for g in stats:
                print(f"{g['status']}: {g['row_count']} rows")
        """
        ref = self._dataset.engine.get_source_reference(self._dataset.source)
        cols_sql = self._quoted_group_cols()

        sql = f"""
            SELECT {cols_sql}, COUNT(*) as row_count
            FROM {ref}
            GROUP BY {cols_sql}
            ORDER BY row_count DESC
        """

        rows = self._dataset.engine.fetch_all(sql)
        # Last column of each row is the COUNT(*); the rest are the key values.
        return [
            {**dict(zip(self._group_columns, row[:-1])), "row_count": row[-1]}
            for row in rows
        ]

    def row_count_greater_than(self, min_count: int) -> GroupByResult:
        """
        Validate that each group has more than min_count rows.

        Args:
            min_count: Minimum required rows per group

        Returns:
            GroupByResult with per-group validation results

        Example:
            result = orders.group_by("region").row_count_greater_than(100)
            if not result.passed:
                for g in result.get_failed_groups():
                    print(f"Region {g.group_key} has only {g.row_count} rows")
        """
        from duckguard.core.result import GroupByResult, GroupResult, ValidationResult

        group_results: list[GroupResult] = []
        passed_count = 0

        for group_stats in self.stats():
            group_key = {k: group_stats[k] for k in self._group_columns}
            row_count = group_stats["row_count"]
            passed = row_count > min_count

            check_result = ValidationResult(
                passed=passed,
                actual_value=row_count,
                expected_value=f"> {min_count}",
                message=f"row_count = {row_count} {'>' if passed else '<='} {min_count}",
            )

            group_results.append(GroupResult(
                group_key=group_key,
                row_count=row_count,
                passed=passed,
                check_results=[check_result],
            ))

            if passed:
                passed_count += 1

        total_groups = len(group_results)
        all_passed = passed_count == total_groups

        return GroupByResult(
            passed=all_passed,
            total_groups=total_groups,
            passed_groups=passed_count,
            failed_groups=total_groups - passed_count,
            group_results=group_results,
            group_columns=self._group_columns,
        )

    def validate(
        self,
        check_fn,
        column: str | None = None,
    ) -> GroupByResult:
        """
        Run a custom validation function on each group.

        Args:
            check_fn: Function that takes a group's column and returns a
                ValidationResult (a plain truthy/falsy return value is
                wrapped in a ValidationResult automatically)
            column: Column to validate (required for column-level checks)

        Returns:
            GroupByResult with per-group validation results

        Example:
            # Check null percent per group
            result = orders.group_by("region").validate(
                lambda col: col.null_percent < 5,
                column="customer_id"
            )

            # Check amount range per group
            result = orders.group_by("date").validate(
                lambda col: col.between(0, 10000),
                column="amount"
            )
        """
        from duckguard.core.result import GroupByResult, GroupResult, ValidationResult

        group_results: list[GroupResult] = []
        passed_count = 0

        for group_key in self.groups:
            # Build a NULL-safe WHERE clause selecting this group's rows.
            conditions = self._group_filter(group_key)

            ref = self._dataset.engine.get_source_reference(self._dataset.source)

            # Get row count for this group
            sql_count = f"SELECT COUNT(*) FROM {ref} WHERE {conditions}"
            row_count = self._dataset.engine.fetch_value(sql_count) or 0

            # Create a temporary filtered column for validation
            if column:
                group_col = _GroupColumn(
                    name=column,
                    dataset=self._dataset,
                    filter_condition=conditions,
                )
                try:
                    result = check_fn(group_col)
                    if not isinstance(result, ValidationResult):
                        # If check_fn returns a boolean (e.g., col.null_percent < 5)
                        result = ValidationResult(
                            passed=bool(result),
                            actual_value=result,
                            message=f"Custom check {'passed' if result else 'failed'}",
                        )
                except Exception as e:
                    result = ValidationResult(
                        passed=False,
                        actual_value=None,
                        message=f"Check error: {e}",
                    )
            else:
                result = ValidationResult(
                    passed=True,
                    actual_value=row_count,
                    message="No column check specified",
                )

            group_results.append(GroupResult(
                group_key=group_key,
                row_count=row_count,
                passed=result.passed,
                check_results=[result],
            ))

            if result.passed:
                passed_count += 1

        total_groups = len(group_results)
        all_passed = passed_count == total_groups

        return GroupByResult(
            passed=all_passed,
            total_groups=total_groups,
            passed_groups=passed_count,
            failed_groups=total_groups - passed_count,
            group_results=group_results,
            group_columns=self._group_columns,
        )

    def _format_value(self, value: Any) -> str:
        """Format a value as a SQL literal for a WHERE clause.

        Note: NULL group keys are handled by _group_filter via IS NULL;
        the "NULL" return here is kept for backward compatibility.
        """
        if value is None:
            return "NULL"
        elif isinstance(value, bool):
            # bool is checked before str/fallthrough so True/False become
            # portable SQL boolean literals rather than str(True)/str(False).
            return "TRUE" if value else "FALSE"
        elif isinstance(value, str):
            # Double embedded single quotes so the value cannot terminate
            # the SQL string literal (prevents broken/injected queries).
            escaped = value.replace("'", "''")
            return f"'{escaped}'"
        else:
            return str(value)
|
1257
|
+
class _GroupColumn:
    """
    A column wrapper that applies a filter condition.

    Used internally by GroupedDataset to validate columns within groups.
    Each property/method issues its own query restricted by the stored
    filter condition, so callers see only the rows of one group.
    """

    def __init__(self, name: str, dataset: Dataset, filter_condition: str):
        self._name = name
        self._dataset = dataset
        self._filter_condition = filter_condition

    @property
    def null_percent(self) -> float:
        """Get null percentage for this column within the group."""
        source_ref = self._dataset.engine.get_source_reference(self._dataset.source)
        quoted = f'"{self._name}"'

        query = f"""
        SELECT
            (COUNT(*) - COUNT({quoted})) * 100.0 / NULLIF(COUNT(*), 0) as null_pct
        FROM {source_ref}
        WHERE {self._filter_condition}
        """

        value = self._dataset.engine.fetch_value(query)
        if value is None:
            # An empty group yields NULL from the division; report 0%.
            return 0.0
        return float(value)

    @property
    def count(self) -> int:
        """Get non-null count for this column within the group."""
        source_ref = self._dataset.engine.get_source_reference(self._dataset.source)
        quoted = f'"{self._name}"'

        query = f"""
        SELECT COUNT({quoted})
        FROM {source_ref}
        WHERE {self._filter_condition}
        """

        return self._dataset.engine.fetch_value(query) or 0

    def between(self, min_val: Any, max_val: Any) -> ValidationResult:
        """Check if all values are between min and max within the group."""
        source_ref = self._dataset.engine.get_source_reference(self._dataset.source)
        quoted = f'"{self._name}"'

        # Count rows whose (non-null) value falls outside [min_val, max_val].
        query = f"""
        SELECT COUNT(*) FROM {source_ref}
        WHERE {self._filter_condition}
          AND {quoted} IS NOT NULL
          AND ({quoted} < {min_val} OR {quoted} > {max_val})
        """

        violation_count = self._dataset.engine.fetch_value(query) or 0

        return ValidationResult(
            passed=violation_count == 0,
            actual_value=violation_count,
            expected_value=0,
            message=f"{violation_count} values outside [{min_val}, {max_val}]",
        )
|