duckguard 2.0.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. duckguard/__init__.py +55 -28
  2. duckguard/anomaly/__init__.py +29 -1
  3. duckguard/anomaly/baselines.py +294 -0
  4. duckguard/anomaly/detector.py +1 -5
  5. duckguard/anomaly/methods.py +17 -5
  6. duckguard/anomaly/ml_methods.py +724 -0
  7. duckguard/cli/main.py +561 -56
  8. duckguard/connectors/__init__.py +2 -2
  9. duckguard/connectors/bigquery.py +1 -1
  10. duckguard/connectors/databricks.py +1 -1
  11. duckguard/connectors/factory.py +2 -3
  12. duckguard/connectors/files.py +1 -1
  13. duckguard/connectors/kafka.py +2 -2
  14. duckguard/connectors/mongodb.py +1 -1
  15. duckguard/connectors/mysql.py +1 -1
  16. duckguard/connectors/oracle.py +1 -1
  17. duckguard/connectors/postgres.py +1 -2
  18. duckguard/connectors/redshift.py +1 -1
  19. duckguard/connectors/snowflake.py +1 -2
  20. duckguard/connectors/sqlite.py +1 -1
  21. duckguard/connectors/sqlserver.py +10 -13
  22. duckguard/contracts/__init__.py +6 -6
  23. duckguard/contracts/diff.py +1 -1
  24. duckguard/contracts/generator.py +5 -6
  25. duckguard/contracts/loader.py +4 -4
  26. duckguard/contracts/validator.py +3 -4
  27. duckguard/core/__init__.py +3 -3
  28. duckguard/core/column.py +588 -5
  29. duckguard/core/dataset.py +708 -3
  30. duckguard/core/result.py +328 -1
  31. duckguard/core/scoring.py +1 -2
  32. duckguard/errors.py +362 -0
  33. duckguard/freshness/__init__.py +33 -0
  34. duckguard/freshness/monitor.py +429 -0
  35. duckguard/history/__init__.py +44 -0
  36. duckguard/history/schema.py +301 -0
  37. duckguard/history/storage.py +479 -0
  38. duckguard/history/trends.py +348 -0
  39. duckguard/integrations/__init__.py +31 -0
  40. duckguard/integrations/airflow.py +387 -0
  41. duckguard/integrations/dbt.py +458 -0
  42. duckguard/notifications/__init__.py +61 -0
  43. duckguard/notifications/email.py +508 -0
  44. duckguard/notifications/formatter.py +118 -0
  45. duckguard/notifications/notifiers.py +357 -0
  46. duckguard/profiler/auto_profile.py +3 -3
  47. duckguard/pytest_plugin/__init__.py +1 -1
  48. duckguard/pytest_plugin/plugin.py +1 -1
  49. duckguard/reporting/console.py +2 -2
  50. duckguard/reports/__init__.py +42 -0
  51. duckguard/reports/html_reporter.py +514 -0
  52. duckguard/reports/pdf_reporter.py +114 -0
  53. duckguard/rules/__init__.py +3 -3
  54. duckguard/rules/executor.py +3 -4
  55. duckguard/rules/generator.py +8 -5
  56. duckguard/rules/loader.py +5 -5
  57. duckguard/rules/schema.py +23 -0
  58. duckguard/schema_history/__init__.py +40 -0
  59. duckguard/schema_history/analyzer.py +414 -0
  60. duckguard/schema_history/tracker.py +288 -0
  61. duckguard/semantic/__init__.py +1 -1
  62. duckguard/semantic/analyzer.py +0 -2
  63. duckguard/semantic/detector.py +17 -1
  64. duckguard/semantic/validators.py +2 -1
  65. duckguard-2.3.0.dist-info/METADATA +953 -0
  66. duckguard-2.3.0.dist-info/RECORD +77 -0
  67. duckguard-2.0.0.dist-info/METADATA +0 -221
  68. duckguard-2.0.0.dist-info/RECORD +0 -55
  69. {duckguard-2.0.0.dist-info → duckguard-2.3.0.dist-info}/WHEEL +0 -0
  70. {duckguard-2.0.0.dist-info → duckguard-2.3.0.dist-info}/entry_points.txt +0 -0
  71. {duckguard-2.0.0.dist-info → duckguard-2.3.0.dist-info}/licenses/LICENSE +0 -0
duckguard/core/dataset.py CHANGED
@@ -4,11 +4,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any
 
-from duckguard.core.engine import DuckGuardEngine
 from duckguard.core.column import Column
+from duckguard.core.engine import DuckGuardEngine
+from duckguard.core.result import GroupByResult, ReconciliationResult, ValidationResult
 
 if TYPE_CHECKING:
+    from datetime import timedelta
+
     from duckguard.core.scoring import QualityScore
+    from duckguard.freshness import FreshnessResult
 
 
 class Dataset:
@@ -227,10 +231,133 @@ class Dataset:
         """Iterate over column names."""
         return iter(self.columns)
 
+    @property
+    def freshness(self) -> FreshnessResult:
+        """
+        Get freshness information for this dataset.
+
+        Returns:
+            FreshnessResult with freshness information including:
+            - last_modified: When data was last updated
+            - age_seconds: Age in seconds
+            - age_human: Human-readable age string
+            - is_fresh: Whether data meets default 24h threshold
+
+        Example:
+            data = connect("data.csv")
+            print(data.freshness.age_human)  # "2 hours ago"
+            print(data.freshness.is_fresh)   # True
+        """
+        from duckguard.freshness import FreshnessMonitor
+
+        monitor = FreshnessMonitor()
+        return monitor.check(self)
+
+    def is_fresh(self, max_age: timedelta) -> bool:
+        """
+        Check if data is fresher than the specified maximum age.
+
+        Args:
+            max_age: Maximum acceptable age for the data
+
+        Returns:
+            True if data is fresher than max_age
+
+        Example:
+            from datetime import timedelta
+            data = connect("data.csv")
+
+            if not data.is_fresh(timedelta(hours=6)):
+                print("Data is stale!")
+        """
+        from duckguard.freshness import FreshnessMonitor
+
+        monitor = FreshnessMonitor(threshold=max_age)
+        result = monitor.check(self)
+        return result.is_fresh
+
+    # =========================================================================
+    # Cross-Dataset Validation Methods
+    # =========================================================================
+
+    def row_count_matches(
+        self,
+        other_dataset: Dataset,
+        tolerance: int = 0,
+    ) -> ValidationResult:
+        """
+        Check that row count matches another dataset within tolerance.
+
+        Useful for comparing backup data, validating migrations,
+        or ensuring parallel pipelines produce consistent results.
+
+        Args:
+            other_dataset: Dataset to compare against
+            tolerance: Allowed difference in row counts (default: 0 = exact match)
+
+        Returns:
+            ValidationResult indicating if row counts match
+
+        Example:
+            orders = connect("orders.parquet")
+            backup = connect("orders_backup.parquet")
+
+            # Exact match
+            result = orders.row_count_matches(backup)
+
+            # Allow small difference
+            result = orders.row_count_matches(backup, tolerance=10)
+        """
+        source_count = self.row_count
+        other_count = other_dataset.row_count
+        diff = abs(source_count - other_count)
+        passed = diff <= tolerance
+
+        other_name = other_dataset.name or other_dataset.source
+        if tolerance == 0:
+            message = f"Row counts {'match' if passed else 'differ'}: {self._name}={source_count}, {other_name}={other_count}"
+        else:
+            message = f"Row count difference is {diff} (tolerance: {tolerance}): {self._name}={source_count}, {other_name}={other_count}"
+
+        return ValidationResult(
+            passed=passed,
+            actual_value=diff,
+            expected_value=f"<= {tolerance}",
+            message=message,
+            details={
+                "source_count": source_count,
+                "other_count": other_count,
+                "difference": diff,
+                "tolerance": tolerance,
+                "source_dataset": self._name,
+                "other_dataset": other_name,
+            },
+        )
+
+    def row_count_equals(
+        self,
+        other_dataset: Dataset,
+    ) -> ValidationResult:
+        """
+        Check that row count exactly equals another dataset.
+
+        This is a convenience alias for row_count_matches(other, tolerance=0).
+
+        Args:
+            other_dataset: Dataset to compare against
+
+        Returns:
+            ValidationResult indicating if row counts are equal
+
+        Example:
+            result = orders.row_count_equals(backup_orders)
+        """
+        return self.row_count_matches(other_dataset, tolerance=0)
+
     def score(
         self,
         weights: dict | None = None,
-    ) -> "QualityScore":
+    ) -> QualityScore:
         """
         Calculate data quality score for this dataset.
 
@@ -262,7 +389,7 @@ class Dataset:
                 'consistency': 0.1,
             })
         """
-        from duckguard.core.scoring import QualityScorer, QualityDimension
+        from duckguard.core.scoring import QualityDimension, QualityScorer
 
         # Convert string keys to QualityDimension enums if needed
         scorer_weights = None
@@ -282,3 +409,581 @@ class Dataset:
 
         scorer = QualityScorer(weights=scorer_weights)
         return scorer.score(self)
+
+    # =========================================================================
+    # Reconciliation Methods
+    # =========================================================================
+
+    def reconcile(
+        self,
+        target_dataset: Dataset,
+        key_columns: list[str],
+        compare_columns: list[str] | None = None,
+        tolerance: float = 0.0,
+        sample_mismatches: int = 10,
+    ) -> ReconciliationResult:
+        """
+        Reconcile this dataset with a target dataset.
+
+        Performs comprehensive comparison including row matching, missing/extra
+        detection, and column-by-column value comparison. Essential for
+        migration validation and data synchronization checks.
+
+        Args:
+            target_dataset: Dataset to compare against
+            key_columns: Columns to use for matching rows (like a primary key)
+            compare_columns: Columns to compare values (default: all non-key columns)
+            tolerance: Numeric tolerance for value comparison (default: 0 = exact match)
+            sample_mismatches: Number of sample mismatches to capture (default: 10)
+
+        Returns:
+            ReconciliationResult with detailed comparison metrics
+
+        Example:
+            source = connect("orders_source.parquet")
+            target = connect("orders_target.parquet")
+
+            result = source.reconcile(
+                target,
+                key_columns=["order_id"],
+                compare_columns=["amount", "status", "customer_id"]
+            )
+
+            if not result.passed:
+                print(f"Missing in target: {result.missing_in_target}")
+                print(f"Extra in target: {result.extra_in_target}")
+                print(result.summary())
+        """
+        from duckguard.core.result import ReconciliationMismatch, ReconciliationResult
+
+        source_ref = self._engine.get_source_reference(self._source)
+        target_ref = target_dataset.engine.get_source_reference(target_dataset.source)
+        target_name = target_dataset.name or target_dataset.source
+
+        # Determine columns to compare
+        if compare_columns is None:
+            compare_columns = [c for c in self.columns if c not in key_columns]
+
+        # Build key column references
+        key_join_condition = " AND ".join(
+            f's."{k}" = t."{k}"' for k in key_columns
+        )
+
+        # Count rows in source not in target (missing)
+        sql_missing = f"""
+            SELECT COUNT(*) FROM {source_ref} s
+            WHERE NOT EXISTS (
+                SELECT 1 FROM {target_ref} t
+                WHERE {key_join_condition}
+            )
+        """
+        missing_count = self._engine.fetch_value(sql_missing) or 0
+
+        # Count rows in target not in source (extra)
+        sql_extra = f"""
+            SELECT COUNT(*) FROM {target_ref} t
+            WHERE NOT EXISTS (
+                SELECT 1 FROM {source_ref} s
+                WHERE {key_join_condition}
+            )
+        """
+        extra_count = self._engine.fetch_value(sql_extra) or 0
+
+        # Count value mismatches per column for matching rows
+        value_mismatches: dict[str, int] = {}
+        for col in compare_columns:
+            if tolerance > 0:
+                # Numeric tolerance comparison
+                sql_mismatch = f"""
+                    SELECT COUNT(*) FROM {source_ref} s
+                    INNER JOIN {target_ref} t ON {key_join_condition}
+                    WHERE ABS(COALESCE(CAST(s."{col}" AS DOUBLE), 0) -
+                              COALESCE(CAST(t."{col}" AS DOUBLE), 0)) > {tolerance}
+                       OR (s."{col}" IS NULL AND t."{col}" IS NOT NULL)
+                       OR (s."{col}" IS NOT NULL AND t."{col}" IS NULL)
+                """
+            else:
+                # Exact match comparison
+                sql_mismatch = f"""
+                    SELECT COUNT(*) FROM {source_ref} s
+                    INNER JOIN {target_ref} t ON {key_join_condition}
+                    WHERE s."{col}" IS DISTINCT FROM t."{col}"
+                """
+            mismatch_count = self._engine.fetch_value(sql_mismatch) or 0
+            if mismatch_count > 0:
+                value_mismatches[col] = mismatch_count
+
+        # Calculate match percentage
+        source_count = self.row_count
+        target_count = target_dataset.row_count
+        matched_rows = source_count - missing_count
+        total_comparisons = matched_rows * len(compare_columns) if compare_columns else matched_rows
+        total_mismatches = sum(value_mismatches.values())
+
+        if total_comparisons > 0:
+            match_percentage = ((total_comparisons - total_mismatches) / total_comparisons) * 100
+        else:
+            match_percentage = 100.0 if missing_count == 0 and extra_count == 0 else 0.0
+
+        # Capture sample mismatches
+        mismatches: list[ReconciliationMismatch] = []
+        if sample_mismatches > 0 and (missing_count > 0 or extra_count > 0 or value_mismatches):
+            mismatches = self._get_sample_mismatches(
+                target_dataset,
+                key_columns,
+                compare_columns,
+                tolerance,
+                sample_mismatches,
+            )
+
+        # Determine if passed (no mismatches at all)
+        passed = missing_count == 0 and extra_count == 0 and len(value_mismatches) == 0
+
+        return ReconciliationResult(
+            passed=passed,
+            source_row_count=source_count,
+            target_row_count=target_count,
+            missing_in_target=missing_count,
+            extra_in_target=extra_count,
+            value_mismatches=value_mismatches,
+            match_percentage=match_percentage,
+            key_columns=key_columns,
+            compared_columns=compare_columns,
+            mismatches=mismatches,
+            details={
+                "source_dataset": self._name,
+                "target_dataset": target_name,
+                "tolerance": tolerance,
+            },
+        )
+
+    def _get_sample_mismatches(
+        self,
+        target_dataset: Dataset,
+        key_columns: list[str],
+        compare_columns: list[str],
+        tolerance: float,
+        limit: int,
+    ) -> list:
+        """Get sample of reconciliation mismatches."""
+        from duckguard.core.result import ReconciliationMismatch
+
+        source_ref = self._engine.get_source_reference(self._source)
+        target_ref = target_dataset.engine.get_source_reference(target_dataset.source)
+
+        key_cols_sql = ", ".join(f's."{k}"' for k in key_columns)
+        key_join_condition = " AND ".join(
+            f's."{k}" = t."{k}"' for k in key_columns
+        )
+
+        mismatches: list[ReconciliationMismatch] = []
+
+        # Sample missing in target
+        sql_missing = f"""
+            SELECT {key_cols_sql} FROM {source_ref} s
+            WHERE NOT EXISTS (
+                SELECT 1 FROM {target_ref} t
+                WHERE {key_join_condition}
+            )
+            LIMIT {limit}
+        """
+        missing_rows = self._engine.fetch_all(sql_missing)
+        for row in missing_rows:
+            key_values = dict(zip(key_columns, row))
+            mismatches.append(ReconciliationMismatch(
+                key_values=key_values,
+                column="(row)",
+                source_value="exists",
+                target_value="missing",
+                mismatch_type="missing_in_target",
+            ))
+
+        # Sample value mismatches
+        remaining = limit - len(mismatches)
+        if remaining > 0 and compare_columns:
+            for col in compare_columns[:3]:  # Limit columns for sampling
+                if tolerance > 0:
+                    sql_diff = f"""
+                        SELECT {key_cols_sql}, s."{col}" as source_val, t."{col}" as target_val
+                        FROM {source_ref} s
+                        INNER JOIN {target_ref} t ON {key_join_condition}
+                        WHERE ABS(COALESCE(CAST(s."{col}" AS DOUBLE), 0) -
+                                  COALESCE(CAST(t."{col}" AS DOUBLE), 0)) > {tolerance}
+                           OR (s."{col}" IS NULL AND t."{col}" IS NOT NULL)
+                           OR (s."{col}" IS NOT NULL AND t."{col}" IS NULL)
+                        LIMIT {remaining}
+                    """
+                else:
+                    sql_diff = f"""
+                        SELECT {key_cols_sql}, s."{col}" as source_val, t."{col}" as target_val
+                        FROM {source_ref} s
+                        INNER JOIN {target_ref} t ON {key_join_condition}
+                        WHERE s."{col}" IS DISTINCT FROM t."{col}"
+                        LIMIT {remaining}
+                    """
+                diff_rows = self._engine.fetch_all(sql_diff)
+                for row in diff_rows:
+                    key_values = dict(zip(key_columns, row[:len(key_columns)]))
+                    mismatches.append(ReconciliationMismatch(
+                        key_values=key_values,
+                        column=col,
+                        source_value=row[len(key_columns)],
+                        target_value=row[len(key_columns) + 1],
+                        mismatch_type="value_diff",
+                    ))
+
+        return mismatches[:limit]
+
+    # =========================================================================
+    # Group By Methods
+    # =========================================================================
+
+    def group_by(self, columns: list[str] | str) -> GroupedDataset:
+        """
+        Group the dataset by one or more columns for segmented validation.
+
+        Returns a GroupedDataset that allows running validation checks
+        on each group separately. Useful for partition-level data quality
+        checks and segmented analysis.
+
+        Args:
+            columns: Column name(s) to group by
+
+        Returns:
+            GroupedDataset for running group-level validations
+
+        Example:
+            # Validate each region has data
+            result = orders.group_by("region").row_count_greater_than(0)
+
+            # Validate per-date quality
+            result = orders.group_by(["date", "region"]).validate(
+                lambda g: g["amount"].null_percent < 5
+            )
+
+            # Get group statistics
+            stats = orders.group_by("status").stats()
+        """
+        if isinstance(columns, str):
+            columns = [columns]
+
+        # Validate columns exist
+        for col in columns:
+            if col not in self.columns:
+                raise KeyError(
+                    f"Column '{col}' not found. Available columns: {', '.join(self.columns)}"
+                )
+
+        return GroupedDataset(self, columns)
+
+
+class GroupedDataset:
+    """
+    Represents a dataset grouped by one or more columns.
+
+    Provides methods for running validation checks and getting statistics
+    at the group level. Created via Dataset.group_by().
+
+    Example:
+        grouped = orders.group_by("region")
+        result = grouped.row_count_greater_than(100)
+        stats = grouped.stats()
+    """
+
+    def __init__(self, dataset: Dataset, group_columns: list[str]):
+        """Initialize a grouped dataset.
+
+        Args:
+            dataset: The source dataset
+            group_columns: Columns to group by
+        """
+        self._dataset = dataset
+        self._group_columns = group_columns
+
+    @property
+    def groups(self) -> list[dict[str, Any]]:
+        """Get all distinct group key combinations."""
+        ref = self._dataset.engine.get_source_reference(self._dataset.source)
+        cols_sql = ", ".join(f'"{c}"' for c in self._group_columns)
+
+        sql = f"""
+            SELECT DISTINCT {cols_sql}
+            FROM {ref}
+            ORDER BY {cols_sql}
+        """
+
+        rows = self._dataset.engine.fetch_all(sql)
+        return [dict(zip(self._group_columns, row)) for row in rows]
+
+    @property
+    def group_count(self) -> int:
+        """Get the number of distinct groups."""
+        ref = self._dataset.engine.get_source_reference(self._dataset.source)
+        cols_sql = ", ".join(f'"{c}"' for c in self._group_columns)
+
+        sql = f"""
+            SELECT COUNT(DISTINCT ({cols_sql}))
+            FROM {ref}
+        """
+
+        return self._dataset.engine.fetch_value(sql) or 0
+
+    def stats(self) -> list[dict[str, Any]]:
+        """
+        Get statistics for each group.
+
+        Returns a list of dictionaries with group key values and statistics
+        including row count, null counts, and basic aggregations.
+
+        Returns:
+            List of group statistics dictionaries
+
+        Example:
+            stats = orders.group_by("status").stats()
+            for g in stats:
+                print(f"{g['status']}: {g['row_count']} rows")
+        """
+        ref = self._dataset.engine.get_source_reference(self._dataset.source)
+        cols_sql = ", ".join(f'"{c}"' for c in self._group_columns)
+
+        sql = f"""
+            SELECT {cols_sql}, COUNT(*) as row_count
+            FROM {ref}
+            GROUP BY {cols_sql}
+            ORDER BY row_count DESC
+        """
+
+        rows = self._dataset.engine.fetch_all(sql)
+        return [
+            {**dict(zip(self._group_columns, row[:-1])), "row_count": row[-1]}
+            for row in rows
+        ]
+
+    def row_count_greater_than(self, min_count: int) -> GroupByResult:
+        """
+        Validate that each group has more than min_count rows.
+
+        Args:
+            min_count: Minimum required rows per group
+
+        Returns:
+            GroupByResult with per-group validation results
+
+        Example:
+            result = orders.group_by("region").row_count_greater_than(100)
+            if not result.passed:
+                for g in result.get_failed_groups():
+                    print(f"Region {g.group_key} has only {g.row_count} rows")
+        """
+        from duckguard.core.result import GroupByResult, GroupResult, ValidationResult
+
+        group_results: list[GroupResult] = []
+        passed_count = 0
+
+        for group_stats in self.stats():
+            group_key = {k: group_stats[k] for k in self._group_columns}
+            row_count = group_stats["row_count"]
+            passed = row_count > min_count
+
+            check_result = ValidationResult(
+                passed=passed,
+                actual_value=row_count,
+                expected_value=f"> {min_count}",
+                message=f"row_count = {row_count} {'>' if passed else '<='} {min_count}",
+            )
+
+            group_results.append(GroupResult(
+                group_key=group_key,
+                row_count=row_count,
+                passed=passed,
+                check_results=[check_result],
+            ))
+
+            if passed:
+                passed_count += 1
+
+        total_groups = len(group_results)
+        all_passed = passed_count == total_groups
+
+        return GroupByResult(
+            passed=all_passed,
+            total_groups=total_groups,
+            passed_groups=passed_count,
+            failed_groups=total_groups - passed_count,
+            group_results=group_results,
+            group_columns=self._group_columns,
+        )
+
+    def validate(
+        self,
+        check_fn,
+        column: str | None = None,
+    ) -> GroupByResult:
+        """
+        Run a custom validation function on each group.
+
+        Args:
+            check_fn: Function that takes a group's column and returns ValidationResult
+            column: Column to validate (required for column-level checks)
+
+        Returns:
+            GroupByResult with per-group validation results
+
+        Example:
+            # Check null percent per group
+            result = orders.group_by("region").validate(
+                lambda col: col.null_percent < 5,
+                column="customer_id"
+            )
+
+            # Check amount range per group
+            result = orders.group_by("date").validate(
+                lambda col: col.between(0, 10000),
+                column="amount"
+            )
+        """
+        from duckguard.core.result import GroupByResult, GroupResult, ValidationResult
+
+        group_results: list[GroupResult] = []
+        passed_count = 0
+
+        for group_key in self.groups:
+            # Build WHERE clause for this group
+            conditions = " AND ".join(
+                f'"{k}" = {self._format_value(v)}'
+                for k, v in group_key.items()
+            )
+
+            # Create a filtered view of the data for this group
+            ref = self._dataset.engine.get_source_reference(self._dataset.source)
+
+            # Get row count for this group
+            sql_count = f"SELECT COUNT(*) FROM {ref} WHERE {conditions}"
+            row_count = self._dataset.engine.fetch_value(sql_count) or 0
+
+            # Create a temporary filtered column for validation
+            if column:
+                group_col = _GroupColumn(
+                    name=column,
+                    dataset=self._dataset,
+                    filter_condition=conditions,
+                )
+                try:
+                    result = check_fn(group_col)
+                    if not isinstance(result, ValidationResult):
+                        # If check_fn returns a boolean (e.g., col.null_percent < 5)
+                        result = ValidationResult(
+                            passed=bool(result),
+                            actual_value=result,
+                            message=f"Custom check {'passed' if result else 'failed'}",
+                        )
+                except Exception as e:
+                    result = ValidationResult(
+                        passed=False,
+                        actual_value=None,
+                        message=f"Check error: {e}",
+                    )
+            else:
+                result = ValidationResult(
+                    passed=True,
+                    actual_value=row_count,
+                    message="No column check specified",
+                )
+
+            group_results.append(GroupResult(
+                group_key=group_key,
+                row_count=row_count,
+                passed=result.passed,
+                check_results=[result],
+            ))
+
+            if result.passed:
+                passed_count += 1
+
+        total_groups = len(group_results)
+        all_passed = passed_count == total_groups
+
+        return GroupByResult(
+            passed=all_passed,
+            total_groups=total_groups,
+            passed_groups=passed_count,
+            failed_groups=total_groups - passed_count,
+            group_results=group_results,
+            group_columns=self._group_columns,
+        )
+
+    def _format_value(self, value: Any) -> str:
+        """Format a value for SQL WHERE clause."""
+        if value is None:
+            return "NULL"
+        elif isinstance(value, str):
+            return f"'{value}'"
+        elif isinstance(value, bool):
+            return "TRUE" if value else "FALSE"
+        else:
+            return str(value)
+
+
+class _GroupColumn:
+    """
+    A column wrapper that applies a filter condition.
+
+    Used internally by GroupedDataset to validate columns within groups.
+    """
+
+    def __init__(self, name: str, dataset: Dataset, filter_condition: str):
+        self._name = name
+        self._dataset = dataset
+        self._filter_condition = filter_condition
+
+    @property
+    def null_percent(self) -> float:
+        """Get null percentage for this column within the group."""
+        ref = self._dataset.engine.get_source_reference(self._dataset.source)
+        col = f'"{self._name}"'
+
+        sql = f"""
+            SELECT
+                (COUNT(*) - COUNT({col})) * 100.0 / NULLIF(COUNT(*), 0) as null_pct
+            FROM {ref}
+            WHERE {self._filter_condition}
+        """
+
+        result = self._dataset.engine.fetch_value(sql)
+        return float(result) if result is not None else 0.0
+
+    @property
+    def count(self) -> int:
+        """Get non-null count for this column within the group."""
+        ref = self._dataset.engine.get_source_reference(self._dataset.source)
+        col = f'"{self._name}"'
+
+        sql = f"""
+            SELECT COUNT({col})
+            FROM {ref}
+            WHERE {self._filter_condition}
+        """
+
+        return self._dataset.engine.fetch_value(sql) or 0
+
+    def between(self, min_val: Any, max_val: Any) -> ValidationResult:
+        """Check if all values are between min and max within the group."""
+        ref = self._dataset.engine.get_source_reference(self._dataset.source)
+        col = f'"{self._name}"'
+
+        sql = f"""
+            SELECT COUNT(*) FROM {ref}
+            WHERE {self._filter_condition}
+              AND {col} IS NOT NULL
+              AND ({col} < {min_val} OR {col} > {max_val})
+        """
+
+        out_of_range = self._dataset.engine.fetch_value(sql) or 0
+        passed = out_of_range == 0
+
+        return ValidationResult(
+            passed=passed,
+            actual_value=out_of_range,
+            expected_value=0,
+            message=f"{out_of_range} values outside [{min_val}, {max_val}]",
+        )
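
Usage note — a minimal sketch of the freshness and cross-dataset row-count APIs added above, assuming duckguard 2.3.0 is installed and that connect is exported from the top-level duckguard package, as the docstring examples imply:

    from datetime import timedelta

    from duckguard import connect  # assumed top-level export, per the docstrings

    orders = connect("orders.parquet")
    backup = connect("orders_backup.parquet")

    # Freshness: the property uses the default 24h threshold,
    # while is_fresh() takes an explicit timedelta
    print(orders.freshness.age_human)
    if not orders.is_fresh(timedelta(hours=6)):
        print("orders.parquet is stale")

    # Cross-dataset row counts: exact and tolerant comparison
    exact = orders.row_count_equals(backup)
    close = orders.row_count_matches(backup, tolerance=10)
    print(exact.passed, close.message)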
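Usage note — the reconciliation flow, under the same assumptions; the result fields are the ones populated by reconcile() above, and summary() appears in its docstring:

    from duckguard import connect  # assumed top-level export, per the docstrings

    source = connect("orders_source.parquet")
    target = connect("orders_target.parquet")

    # key_columns drive row matching; tolerance applies to numeric comparisons
    result = source.reconcile(
        target,
        key_columns=["order_id"],
        compare_columns=["amount", "status"],
        tolerance=0.01,
    )

    if not result.passed:
        print(f"missing in target: {result.missing_in_target}")
        print(f"extra in target: {result.extra_in_target}")
        print(f"match percentage: {result.match_percentage:.2f}%")
        for m in result.mismatches:  # at most sample_mismatches entries
            print(m.key_values, m.column, m.source_value, m.target_value)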
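Usage note — a sketch of the group-by validation API, same assumptions:

    from duckguard import connect  # assumed top-level export, per the docstrings

    orders = connect("orders.parquet")

    # Per-group row-count floor
    result = orders.group_by("region").row_count_greater_than(100)
    print(f"{result.passed_groups}/{result.total_groups} regions passed")

    # Custom per-group check: the lambda receives a filtered column wrapper;
    # a plain boolean result is wrapped into a ValidationResult by validate()
    result = orders.group_by("date").validate(
        lambda col: col.null_percent < 5,
        column="customer_id",
    )

    # Plain per-group statistics
    for g in orders.group_by("status").stats():
        print(g["status"], g["row_count"])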