duckguard 2.2.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. duckguard/__init__.py +1 -1
  2. duckguard/anomaly/__init__.py +28 -0
  3. duckguard/anomaly/baselines.py +294 -0
  4. duckguard/anomaly/methods.py +16 -2
  5. duckguard/anomaly/ml_methods.py +724 -0
  6. duckguard/checks/__init__.py +26 -0
  7. duckguard/checks/conditional.py +796 -0
  8. duckguard/checks/distributional.py +524 -0
  9. duckguard/checks/multicolumn.py +726 -0
  10. duckguard/checks/query_based.py +643 -0
  11. duckguard/cli/main.py +257 -2
  12. duckguard/connectors/factory.py +30 -2
  13. duckguard/connectors/files.py +7 -3
  14. duckguard/core/column.py +851 -1
  15. duckguard/core/dataset.py +1035 -0
  16. duckguard/core/result.py +236 -0
  17. duckguard/freshness/__init__.py +33 -0
  18. duckguard/freshness/monitor.py +429 -0
  19. duckguard/history/schema.py +119 -1
  20. duckguard/notifications/__init__.py +20 -2
  21. duckguard/notifications/email.py +508 -0
  22. duckguard/profiler/distribution_analyzer.py +384 -0
  23. duckguard/profiler/outlier_detector.py +497 -0
  24. duckguard/profiler/pattern_matcher.py +301 -0
  25. duckguard/profiler/quality_scorer.py +445 -0
  26. duckguard/reports/html_reporter.py +1 -2
  27. duckguard/rules/executor.py +642 -0
  28. duckguard/rules/generator.py +4 -1
  29. duckguard/rules/schema.py +54 -0
  30. duckguard/schema_history/__init__.py +40 -0
  31. duckguard/schema_history/analyzer.py +414 -0
  32. duckguard/schema_history/tracker.py +288 -0
  33. duckguard/semantic/detector.py +17 -1
  34. duckguard-3.0.0.dist-info/METADATA +1072 -0
  35. {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/RECORD +38 -21
  36. duckguard-2.2.0.dist-info/METADATA +0 -351
  37. {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/WHEEL +0 -0
  38. {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/entry_points.txt +0 -0
  39. {duckguard-2.2.0.dist-info → duckguard-3.0.0.dist-info}/licenses/LICENSE +0 -0
duckguard/core/dataset.py CHANGED
@@ -6,9 +6,13 @@ from typing import TYPE_CHECKING, Any
6
6
 
7
7
  from duckguard.core.column import Column
8
8
  from duckguard.core.engine import DuckGuardEngine
9
+ from duckguard.core.result import GroupByResult, ReconciliationResult, ValidationResult
9
10
 
10
11
  if TYPE_CHECKING:
12
+ from datetime import timedelta
13
+
11
14
  from duckguard.core.scoring import QualityScore
15
+ from duckguard.freshness import FreshnessResult
12
16
 
13
17
 
14
18
  class Dataset:
@@ -227,6 +231,129 @@ class Dataset:
227
231
  """Iterate over column names."""
228
232
  return iter(self.columns)
229
233
 
234
+ @property
235
+ def freshness(self) -> FreshnessResult:
236
+ """
237
+ Get freshness information for this dataset.
238
+
239
+ Returns:
240
+ FreshnessResult with freshness information including:
241
+ - last_modified: When data was last updated
242
+ - age_seconds: Age in seconds
243
+ - age_human: Human-readable age string
244
+ - is_fresh: Whether data meets default 24h threshold
245
+
246
+ Example:
247
+ data = connect("data.csv")
248
+ print(data.freshness.age_human) # "2 hours ago"
249
+ print(data.freshness.is_fresh) # True
250
+ """
251
+ from duckguard.freshness import FreshnessMonitor
252
+
253
+ monitor = FreshnessMonitor()
254
+ return monitor.check(self)
255
+
256
+ def is_fresh(self, max_age: timedelta) -> bool:
257
+ """
258
+ Check if data is fresher than the specified maximum age.
259
+
260
+ Args:
261
+ max_age: Maximum acceptable age for the data
262
+
263
+ Returns:
264
+ True if data is fresher than max_age
265
+
266
+ Example:
267
+ from datetime import timedelta
268
+ data = connect("data.csv")
269
+
270
+ if not data.is_fresh(timedelta(hours=6)):
271
+ print("Data is stale!")
272
+ """
273
+ from duckguard.freshness import FreshnessMonitor
274
+
275
+ monitor = FreshnessMonitor(threshold=max_age)
276
+ result = monitor.check(self)
277
+ return result.is_fresh
278
+
279
+ # =========================================================================
280
+ # Cross-Dataset Validation Methods
281
+ # =========================================================================
282
+
283
+ def row_count_matches(
284
+ self,
285
+ other_dataset: Dataset,
286
+ tolerance: int = 0,
287
+ ) -> ValidationResult:
288
+ """
289
+ Check that row count matches another dataset within tolerance.
290
+
291
+ Useful for comparing backup data, validating migrations,
292
+ or ensuring parallel pipelines produce consistent results.
293
+
294
+ Args:
295
+ other_dataset: Dataset to compare against
296
+ tolerance: Allowed difference in row counts (default: 0 = exact match)
297
+
298
+ Returns:
299
+ ValidationResult indicating if row counts match
300
+
301
+ Example:
302
+ orders = connect("orders.parquet")
303
+ backup = connect("orders_backup.parquet")
304
+
305
+ # Exact match
306
+ result = orders.row_count_matches(backup)
307
+
308
+ # Allow small difference
309
+ result = orders.row_count_matches(backup, tolerance=10)
310
+ """
311
+ source_count = self.row_count
312
+ other_count = other_dataset.row_count
313
+ diff = abs(source_count - other_count)
314
+ passed = diff <= tolerance
315
+
316
+ other_name = other_dataset.name or other_dataset.source
317
+ if tolerance == 0:
318
+ message = f"Row counts {'match' if passed else 'differ'}: {self._name}={source_count}, {other_name}={other_count}"
319
+ else:
320
+ message = f"Row count difference is {diff} (tolerance: {tolerance}): {self._name}={source_count}, {other_name}={other_count}"
321
+
322
+ return ValidationResult(
323
+ passed=passed,
324
+ actual_value=diff,
325
+ expected_value=f"<= {tolerance}",
326
+ message=message,
327
+ details={
328
+ "source_count": source_count,
329
+ "other_count": other_count,
330
+ "difference": diff,
331
+ "tolerance": tolerance,
332
+ "source_dataset": self._name,
333
+ "other_dataset": other_name,
334
+ },
335
+ )
336
+
337
+ def row_count_equals(
338
+ self,
339
+ other_dataset: Dataset,
340
+ ) -> ValidationResult:
341
+ """
342
+ Check that row count exactly equals another dataset.
343
+
344
+ This is a convenience alias for row_count_matches(other, tolerance=0).
345
+
346
+ Args:
347
+ other_dataset: Dataset to compare against
348
+
349
+ Returns:
350
+ ValidationResult indicating if row counts are equal
351
+
352
+ Example:
353
+ result = orders.row_count_equals(backup_orders)
354
+ """
355
+ return self.row_count_matches(other_dataset, tolerance=0)
356
+
230
357
  def score(
231
358
  self,
232
359
  weights: dict | None = None,
@@ -282,3 +409,911 @@ class Dataset:
282
409
 
283
410
  scorer = QualityScorer(weights=scorer_weights)
284
411
  return scorer.score(self)
412
+
413
+ # =========================================================================
414
+ # Reconciliation Methods
415
+ # =========================================================================
416
+
417
+ def reconcile(
418
+ self,
419
+ target_dataset: Dataset,
420
+ key_columns: list[str],
421
+ compare_columns: list[str] | None = None,
422
+ tolerance: float = 0.0,
423
+ sample_mismatches: int = 10,
424
+ ) -> ReconciliationResult:
425
+ """
426
+ Reconcile this dataset with a target dataset.
427
+
428
+ Performs comprehensive comparison including row matching, missing/extra
429
+ detection, and column-by-column value comparison. Essential for
430
+ migration validation and data synchronization checks.
431
+
432
+ Args:
433
+ target_dataset: Dataset to compare against
434
+ key_columns: Columns to use for matching rows (like a primary key)
435
+ compare_columns: Columns to compare values (default: all non-key columns)
436
+ tolerance: Numeric tolerance for value comparison (default: 0 = exact match)
437
+ sample_mismatches: Number of sample mismatches to capture (default: 10)
438
+
439
+ Returns:
440
+ ReconciliationResult with detailed comparison metrics
441
+
442
+ Example:
443
+ source = connect("orders_source.parquet")
444
+ target = connect("orders_target.parquet")
445
+
446
+ result = source.reconcile(
447
+ target,
448
+ key_columns=["order_id"],
449
+ compare_columns=["amount", "status", "customer_id"]
450
+ )
451
+
452
+ if not result.passed:
453
+ print(f"Missing in target: {result.missing_in_target}")
454
+ print(f"Extra in target: {result.extra_in_target}")
455
+ print(result.summary())
456
+ """
457
+ from duckguard.core.result import ReconciliationMismatch, ReconciliationResult
458
+
459
+ source_ref = self._engine.get_source_reference(self._source)
460
+ target_ref = target_dataset.engine.get_source_reference(target_dataset.source)
461
+ target_name = target_dataset.name or target_dataset.source
462
+
463
+ # Determine columns to compare
464
+ if compare_columns is None:
465
+ compare_columns = [c for c in self.columns if c not in key_columns]
466
+
467
+ # Build key column references
468
+ key_join_condition = " AND ".join(
469
+ f's."{k}" = t."{k}"' for k in key_columns
470
+ )
471
+
472
+ # Count rows in source not in target (missing)
473
+ sql_missing = f"""
474
+ SELECT COUNT(*) FROM {source_ref} s
475
+ WHERE NOT EXISTS (
476
+ SELECT 1 FROM {target_ref} t
477
+ WHERE {key_join_condition}
478
+ )
479
+ """
480
+ missing_count = self._engine.fetch_value(sql_missing) or 0
481
+
482
+ # Count rows in target not in source (extra)
483
+ sql_extra = f"""
484
+ SELECT COUNT(*) FROM {target_ref} t
485
+ WHERE NOT EXISTS (
486
+ SELECT 1 FROM {source_ref} s
487
+ WHERE {key_join_condition}
488
+ )
489
+ """
490
+ extra_count = self._engine.fetch_value(sql_extra) or 0
491
+
492
+ # Count value mismatches per column for matching rows
493
+ value_mismatches: dict[str, int] = {}
494
+ for col in compare_columns:
495
+ if tolerance > 0:
496
+ # Numeric tolerance comparison
497
+ sql_mismatch = f"""
498
+ SELECT COUNT(*) FROM {source_ref} s
499
+ INNER JOIN {target_ref} t ON {key_join_condition}
500
+ WHERE ABS(COALESCE(CAST(s."{col}" AS DOUBLE), 0) -
501
+ COALESCE(CAST(t."{col}" AS DOUBLE), 0)) > {tolerance}
502
+ OR (s."{col}" IS NULL AND t."{col}" IS NOT NULL)
503
+ OR (s."{col}" IS NOT NULL AND t."{col}" IS NULL)
504
+ """
505
+ else:
506
+ # Exact match comparison
507
+ sql_mismatch = f"""
508
+ SELECT COUNT(*) FROM {source_ref} s
509
+ INNER JOIN {target_ref} t ON {key_join_condition}
510
+ WHERE s."{col}" IS DISTINCT FROM t."{col}"
511
+ """
512
+ mismatch_count = self._engine.fetch_value(sql_mismatch) or 0
513
+ if mismatch_count > 0:
514
+ value_mismatches[col] = mismatch_count
515
+
516
+ # Calculate match percentage
517
+ source_count = self.row_count
518
+ target_count = target_dataset.row_count
519
+ matched_rows = source_count - missing_count
520
+ total_comparisons = matched_rows * len(compare_columns) if compare_columns else matched_rows
521
+ total_mismatches = sum(value_mismatches.values())
522
+
523
+ if total_comparisons > 0:
524
+ match_percentage = ((total_comparisons - total_mismatches) / total_comparisons) * 100
525
+ else:
526
+ match_percentage = 100.0 if missing_count == 0 and extra_count == 0 else 0.0
527
+
528
+ # Capture sample mismatches
529
+ mismatches: list[ReconciliationMismatch] = []
530
+ if sample_mismatches > 0 and (missing_count > 0 or extra_count > 0 or value_mismatches):
531
+ mismatches = self._get_sample_mismatches(
532
+ target_dataset,
533
+ key_columns,
534
+ compare_columns,
535
+ tolerance,
536
+ sample_mismatches,
537
+ )
538
+
539
+ # Determine if passed (no mismatches at all)
540
+ passed = missing_count == 0 and extra_count == 0 and len(value_mismatches) == 0
541
+
542
+ return ReconciliationResult(
543
+ passed=passed,
544
+ source_row_count=source_count,
545
+ target_row_count=target_count,
546
+ missing_in_target=missing_count,
547
+ extra_in_target=extra_count,
548
+ value_mismatches=value_mismatches,
549
+ match_percentage=match_percentage,
550
+ key_columns=key_columns,
551
+ compared_columns=compare_columns,
552
+ mismatches=mismatches,
553
+ details={
554
+ "source_dataset": self._name,
555
+ "target_dataset": target_name,
556
+ "tolerance": tolerance,
557
+ },
558
+ )
559
+
560
+ def _get_sample_mismatches(
561
+ self,
562
+ target_dataset: Dataset,
563
+ key_columns: list[str],
564
+ compare_columns: list[str],
565
+ tolerance: float,
566
+ limit: int,
567
+ ) -> list:
568
+ """Get sample of reconciliation mismatches."""
569
+ from duckguard.core.result import ReconciliationMismatch
570
+
571
+ source_ref = self._engine.get_source_reference(self._source)
572
+ target_ref = target_dataset.engine.get_source_reference(target_dataset.source)
573
+
574
+ key_cols_sql = ", ".join(f's."{k}"' for k in key_columns)
575
+ key_join_condition = " AND ".join(
576
+ f's."{k}" = t."{k}"' for k in key_columns
577
+ )
578
+
579
+ mismatches: list[ReconciliationMismatch] = []
580
+
581
+ # Sample missing in target
582
+ sql_missing = f"""
583
+ SELECT {key_cols_sql} FROM {source_ref} s
584
+ WHERE NOT EXISTS (
585
+ SELECT 1 FROM {target_ref} t
586
+ WHERE {key_join_condition}
587
+ )
588
+ LIMIT {limit}
589
+ """
590
+ missing_rows = self._engine.fetch_all(sql_missing)
591
+ for row in missing_rows:
592
+ key_values = dict(zip(key_columns, row))
593
+ mismatches.append(ReconciliationMismatch(
594
+ key_values=key_values,
595
+ column="(row)",
596
+ source_value="exists",
597
+ target_value="missing",
598
+ mismatch_type="missing_in_target",
599
+ ))
600
+
601
+ # Sample value mismatches
602
+ remaining = limit - len(mismatches)
603
+ if remaining > 0 and compare_columns:
604
+ for col in compare_columns[:3]: # Limit columns for sampling
605
+ if tolerance > 0:
606
+ sql_diff = f"""
607
+ SELECT {key_cols_sql}, s."{col}" as source_val, t."{col}" as target_val
608
+ FROM {source_ref} s
609
+ INNER JOIN {target_ref} t ON {key_join_condition}
610
+ WHERE ABS(COALESCE(CAST(s."{col}" AS DOUBLE), 0) -
611
+ COALESCE(CAST(t."{col}" AS DOUBLE), 0)) > {tolerance}
612
+ OR (s."{col}" IS NULL AND t."{col}" IS NOT NULL)
613
+ OR (s."{col}" IS NOT NULL AND t."{col}" IS NULL)
614
+ LIMIT {remaining}
615
+ """
616
+ else:
617
+ sql_diff = f"""
618
+ SELECT {key_cols_sql}, s."{col}" as source_val, t."{col}" as target_val
619
+ FROM {source_ref} s
620
+ INNER JOIN {target_ref} t ON {key_join_condition}
621
+ WHERE s."{col}" IS DISTINCT FROM t."{col}"
622
+ LIMIT {remaining}
623
+ """
624
+ diff_rows = self._engine.fetch_all(sql_diff)
625
+ for row in diff_rows:
626
+ key_values = dict(zip(key_columns, row[:len(key_columns)]))
627
+ mismatches.append(ReconciliationMismatch(
628
+ key_values=key_values,
629
+ column=col,
630
+ source_value=row[len(key_columns)],
631
+ target_value=row[len(key_columns) + 1],
632
+ mismatch_type="value_diff",
633
+ ))
634
+
635
+ return mismatches[:limit]
636
+
637
+ # =========================================================================
638
+ # Group By Methods
639
+ # =========================================================================
640
+
641
+ def group_by(self, columns: list[str] | str) -> GroupedDataset:
642
+ """
643
+ Group the dataset by one or more columns for segmented validation.
644
+
645
+ Returns a GroupedDataset that allows running validation checks
646
+ on each group separately. Useful for partition-level data quality
647
+ checks and segmented analysis.
648
+
649
+ Args:
650
+ columns: Column name(s) to group by
651
+
652
+ Returns:
653
+ GroupedDataset for running group-level validations
654
+
655
+ Example:
656
+ # Validate each region has data
657
+ result = orders.group_by("region").row_count_greater_than(0)
658
+
659
+ # Validate per-date quality
660
+ result = orders.group_by(["date", "region"]).validate(
661
+ lambda g: g["amount"].null_percent < 5
662
+ )
663
+
664
+ # Get group statistics
665
+ stats = orders.group_by("status").stats()
666
+ """
667
+ if isinstance(columns, str):
668
+ columns = [columns]
669
+
670
+ # Validate columns exist
671
+ for col in columns:
672
+ if col not in self.columns:
673
+ raise KeyError(
674
+ f"Column '{col}' not found. Available columns: {', '.join(self.columns)}"
675
+ )
676
+
677
+ return GroupedDataset(self, columns)
678
+
679
+ # =================================================================
680
+ # Multi-Column Validation Methods (DuckGuard 3.0)
681
+ # =================================================================
682
+
683
+ def expect_column_pair_satisfy(
684
+ self,
685
+ column_a: str,
686
+ column_b: str,
687
+ expression: str,
688
+ threshold: float = 1.0
689
+ ) -> ValidationResult:
690
+ """Check that column pair satisfies expression.
691
+
692
+ Args:
693
+ column_a: First column name
694
+ column_b: Second column name
695
+ expression: Expression to evaluate (e.g., "A > B", "A + B = 100")
696
+ threshold: Maximum allowed failure rate (0.0-1.0)
697
+
698
+ Returns:
699
+ ValidationResult with pass/fail status
700
+
701
+ Examples:
702
+ >>> data = connect("orders.csv")
703
+ >>> # Date range validation
704
+ >>> result = data.expect_column_pair_satisfy(
705
+ ... column_a="end_date",
706
+ ... column_b="start_date",
707
+ ... expression="end_date >= start_date"
708
+ ... )
709
+ >>> assert result.passed
710
+
711
+ >>> # Arithmetic validation
712
+ >>> result = data.expect_column_pair_satisfy(
713
+ ... column_a="total",
714
+ ... column_b="subtotal",
715
+ ... expression="total = subtotal * 1.1"
716
+ ... )
717
+ """
718
+ from duckguard.checks.multicolumn import MultiColumnCheckHandler
719
+
720
+ handler = MultiColumnCheckHandler()
721
+ return handler.execute_column_pair_satisfy(
722
+ dataset=self,
723
+ column_a=column_a,
724
+ column_b=column_b,
725
+ expression=expression,
726
+ threshold=threshold
727
+ )
728
+
729
+ def expect_columns_unique(
730
+ self,
731
+ columns: list[str],
732
+ threshold: float = 1.0
733
+ ) -> ValidationResult:
734
+ """Check that combination of columns is unique (composite key).
735
+
736
+ Args:
737
+ columns: List of column names forming composite key
738
+ threshold: Minimum required uniqueness rate (0.0-1.0)
739
+
740
+ Returns:
741
+ ValidationResult with pass/fail status
742
+
743
+ Examples:
744
+ >>> # Two-column composite key
745
+ >>> result = data.expect_columns_unique(
746
+ ... columns=["user_id", "session_id"]
747
+ ... )
748
+ >>> assert result.passed
749
+
750
+ >>> # Three-column composite key
751
+ >>> result = data.expect_columns_unique(
752
+ ... columns=["year", "month", "product_id"]
753
+ ... )
754
+ """
755
+ from duckguard.checks.multicolumn import MultiColumnCheckHandler
756
+
757
+ handler = MultiColumnCheckHandler()
758
+ return handler.execute_columns_unique(
759
+ dataset=self,
760
+ columns=columns,
761
+ threshold=threshold
762
+ )
763
+
764
+ def expect_multicolumn_sum_to_equal(
765
+ self,
766
+ columns: list[str],
767
+ expected_sum: float,
768
+ threshold: float = 0.01
769
+ ) -> ValidationResult:
770
+ """Check that sum of columns equals expected value.
771
+
772
+ Args:
773
+ columns: List of columns to sum
774
+ expected_sum: Expected sum value
775
+ threshold: Maximum allowed deviation
776
+
777
+ Returns:
778
+ ValidationResult with pass/fail status
779
+
780
+ Examples:
781
+ >>> # Components must sum to 100%
782
+ >>> result = data.expect_multicolumn_sum_to_equal(
783
+ ... columns=["q1_pct", "q2_pct", "q3_pct", "q4_pct"],
784
+ ... expected_sum=100.0
785
+ ... )
786
+ >>> assert result.passed
787
+
788
+ >>> # Budget allocation check
789
+ >>> result = data.expect_multicolumn_sum_to_equal(
790
+ ... columns=["marketing", "sales", "r_and_d"],
791
+ ... expected_sum=data.total_budget
792
+ ... )
793
+ """
794
+ from duckguard.checks.multicolumn import MultiColumnCheckHandler
795
+
796
+ handler = MultiColumnCheckHandler()
797
+ return handler.execute_multicolumn_sum_equal(
798
+ dataset=self,
799
+ columns=columns,
800
+ expected_sum=expected_sum,
801
+ threshold=threshold
802
+ )
803
+
804
+ # =================================================================
805
+ # Query-Based Validation Methods (DuckGuard 3.0)
806
+ # =================================================================
807
+
808
+ def expect_query_to_return_no_rows(
809
+ self,
810
+ query: str,
811
+ message: str | None = None
812
+ ) -> ValidationResult:
813
+ """Check that custom SQL query returns no rows (finds no violations).
814
+
815
+ Use case: Write a query that finds violations. The check passes if
816
+ the query returns no rows (no violations found).
817
+
818
+ Args:
819
+ query: SQL SELECT query (use 'table' to reference the dataset)
820
+ message: Optional custom message
821
+
822
+ Returns:
823
+ ValidationResult (passed if query returns 0 rows)
824
+
825
+ Examples:
826
+ >>> data = connect("orders.csv")
827
+ >>> # Find invalid totals (total < subtotal)
828
+ >>> result = data.expect_query_to_return_no_rows(
829
+ ... query="SELECT * FROM table WHERE total < subtotal"
830
+ ... )
831
+ >>> assert result.passed
832
+
833
+ >>> # Find future dates
834
+ >>> result = data.expect_query_to_return_no_rows(
835
+ ... query="SELECT * FROM table WHERE order_date > CURRENT_DATE"
836
+ ... )
837
+
838
+ Security:
839
+ - Query is validated to prevent SQL injection
840
+ - Only SELECT queries allowed
841
+ - READ-ONLY mode enforced
842
+ - 30-second timeout
843
+ - 10,000 row result limit
844
+ """
845
+ from duckguard.checks.query_based import QueryCheckHandler
846
+
847
+ handler = QueryCheckHandler()
848
+ return handler.execute_query_no_rows(
849
+ dataset=self,
850
+ query=query,
851
+ message=message
852
+ )
853
+
854
+ def expect_query_to_return_rows(
855
+ self,
856
+ query: str,
857
+ message: str | None = None
858
+ ) -> ValidationResult:
859
+ """Check that custom SQL query returns at least one row.
860
+
861
+ Use case: Ensure expected data exists in the dataset.
862
+
863
+ Args:
864
+ query: SQL SELECT query (use 'table' to reference the dataset)
865
+ message: Optional custom message
866
+
867
+ Returns:
868
+ ValidationResult (passed if query returns > 0 rows)
869
+
870
+ Examples:
871
+ >>> data = connect("products.csv")
872
+ >>> # Ensure we have active products
873
+ >>> result = data.expect_query_to_return_rows(
874
+ ... query="SELECT * FROM table WHERE status = 'active'"
875
+ ... )
876
+ >>> assert result.passed
877
+
878
+ >>> # Ensure we have recent data
879
+ >>> result = data.expect_query_to_return_rows(
880
+ ... query="SELECT * FROM table WHERE created_at >= CURRENT_DATE - 7"
881
+ ... )
882
+
883
+ Security:
884
+ - Query is validated to prevent SQL injection
885
+ - Only SELECT queries allowed
886
+ - READ-ONLY mode enforced
887
+ """
888
+ from duckguard.checks.query_based import QueryCheckHandler
889
+
890
+ handler = QueryCheckHandler()
891
+ return handler.execute_query_returns_rows(
892
+ dataset=self,
893
+ query=query,
894
+ message=message
895
+ )
896
+
897
+ def expect_query_result_to_equal(
898
+ self,
899
+ query: str,
900
+ expected: Any,
901
+ tolerance: float | None = None,
902
+ message: str | None = None
903
+ ) -> ValidationResult:
904
+ """Check that custom SQL query returns a specific value.
905
+
906
+ Use case: Aggregate validation (COUNT, SUM, AVG, etc.)
907
+
908
+ Args:
909
+ query: SQL query returning single value (use 'table' to reference dataset)
910
+ expected: Expected value
911
+ tolerance: Optional tolerance for numeric comparisons
912
+ message: Optional custom message
913
+
914
+ Returns:
915
+ ValidationResult (passed if query result equals expected)
916
+
917
+ Examples:
918
+ >>> data = connect("orders.csv")
919
+ >>> # Check pending order count
920
+ >>> result = data.expect_query_result_to_equal(
921
+ ... query="SELECT COUNT(*) FROM table WHERE status = 'pending'",
922
+ ... expected=0
923
+ ... )
924
+ >>> assert result.passed
925
+
926
+ >>> # Check average with tolerance
927
+ >>> result = data.expect_query_result_to_equal(
928
+ ... query="SELECT AVG(price) FROM table",
929
+ ... expected=100.0,
930
+ ... tolerance=5.0
931
+ ... )
932
+
933
+ >>> # Check sum constraint
934
+ >>> result = data.expect_query_result_to_equal(
935
+ ... query="SELECT SUM(quantity) FROM table WHERE category = 'electronics'",
936
+ ... expected=1000
937
+ ... )
938
+
939
+ Security:
940
+ - Query must return exactly 1 row with 1 column
941
+ - Query is validated to prevent SQL injection
942
+ """
943
+ from duckguard.checks.query_based import QueryCheckHandler
944
+
945
+ handler = QueryCheckHandler()
946
+ return handler.execute_query_result_equals(
947
+ dataset=self,
948
+ query=query,
949
+ expected=expected,
950
+ tolerance=tolerance,
951
+ message=message
952
+ )
953
+
954
+ def expect_query_result_to_be_between(
955
+ self,
956
+ query: str,
957
+ min_value: float,
958
+ max_value: float,
959
+ message: str | None = None
960
+ ) -> ValidationResult:
961
+ """Check that custom SQL query result is within a range.
962
+
963
+ Use case: Metric validation (e.g., average must be between X and Y)
964
+
965
+ Args:
966
+ query: SQL query returning single numeric value
967
+ min_value: Minimum allowed value (inclusive)
968
+ max_value: Maximum allowed value (inclusive)
969
+ message: Optional custom message
970
+
971
+ Returns:
972
+ ValidationResult (passed if min_value <= result <= max_value)
973
+
974
+ Examples:
975
+ >>> data = connect("metrics.csv")
976
+ >>> # Average price in range
977
+ >>> result = data.expect_query_result_to_be_between(
978
+ ... query="SELECT AVG(price) FROM table",
979
+ ... min_value=10.0,
980
+ ... max_value=1000.0
981
+ ... )
982
+ >>> assert result.passed
983
+
984
+ >>> # Null rate validation
985
+ >>> result = data.expect_query_result_to_be_between(
986
+ ... query='''
987
+ ... SELECT (COUNT(*) FILTER (WHERE price IS NULL)) * 100.0 / COUNT(*)
988
+ ... FROM table
989
+ ... ''',
990
+ ... min_value=0.0,
991
+ ... max_value=5.0 # Max 5% nulls
992
+ ... )
993
+
994
+ Security:
995
+ - Query must return exactly 1 row with 1 numeric column
996
+ - Query is validated to prevent SQL injection
997
+ """
998
+ from duckguard.checks.query_based import QueryCheckHandler
999
+
1000
+ handler = QueryCheckHandler()
1001
+ return handler.execute_query_result_between(
1002
+ dataset=self,
1003
+ query=query,
1004
+ min_value=min_value,
1005
+ max_value=max_value,
1006
+ message=message
1007
+ )
1008
+
1009
+
1010
+ class GroupedDataset:
1011
+ """
1012
+ Represents a dataset grouped by one or more columns.
1013
+
1014
+ Provides methods for running validation checks and getting statistics
1015
+ at the group level. Created via Dataset.group_by().
1016
+
1017
+ Example:
1018
+ grouped = orders.group_by("region")
1019
+ result = grouped.row_count_greater_than(100)
1020
+ stats = grouped.stats()
1021
+ """
1022
+
1023
+ def __init__(self, dataset: Dataset, group_columns: list[str]):
1024
+ """Initialize a grouped dataset.
1025
+
1026
+ Args:
1027
+ dataset: The source dataset
1028
+ group_columns: Columns to group by
1029
+ """
1030
+ self._dataset = dataset
1031
+ self._group_columns = group_columns
1032
+
1033
+ @property
1034
+ def groups(self) -> list[dict[str, Any]]:
1035
+ """Get all distinct group key combinations."""
1036
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1037
+ cols_sql = ", ".join(f'"{c}"' for c in self._group_columns)
1038
+
1039
+ sql = f"""
1040
+ SELECT DISTINCT {cols_sql}
1041
+ FROM {ref}
1042
+ ORDER BY {cols_sql}
1043
+ """
1044
+
1045
+ rows = self._dataset.engine.fetch_all(sql)
1046
+ return [dict(zip(self._group_columns, row)) for row in rows]
1047
+
1048
+ @property
1049
+ def group_count(self) -> int:
1050
+ """Get the number of distinct groups."""
1051
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1052
+ cols_sql = ", ".join(f'"{c}"' for c in self._group_columns)
1053
+
1054
+ sql = f"""
1055
+ SELECT COUNT(DISTINCT ({cols_sql}))
1056
+ FROM {ref}
1057
+ """
1058
+
1059
+ return self._dataset.engine.fetch_value(sql) or 0
1060
+
1061
+ def stats(self) -> list[dict[str, Any]]:
1062
+ """
1063
+ Get statistics for each group.
1064
+
1065
+ Returns a list of dictionaries with group key values and statistics
1066
+ including row count, null counts, and basic aggregations.
1067
+
1068
+ Returns:
1069
+ List of group statistics dictionaries
1070
+
1071
+ Example:
1072
+ stats = orders.group_by("status").stats()
1073
+ for g in stats:
1074
+ print(f"{g['status']}: {g['row_count']} rows")
1075
+ """
1076
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1077
+ cols_sql = ", ".join(f'"{c}"' for c in self._group_columns)
1078
+
1079
+ sql = f"""
1080
+ SELECT {cols_sql}, COUNT(*) as row_count
1081
+ FROM {ref}
1082
+ GROUP BY {cols_sql}
1083
+ ORDER BY row_count DESC
1084
+ """
1085
+
1086
+ rows = self._dataset.engine.fetch_all(sql)
1087
+ return [
1088
+ {**dict(zip(self._group_columns, row[:-1])), "row_count": row[-1]}
1089
+ for row in rows
1090
+ ]
1091
+
1092
+ def row_count_greater_than(self, min_count: int) -> GroupByResult:
1093
+ """
1094
+ Validate that each group has more than min_count rows.
1095
+
1096
+ Args:
1097
+ min_count: Minimum required rows per group
1098
+
1099
+ Returns:
1100
+ GroupByResult with per-group validation results
1101
+
1102
+ Example:
1103
+ result = orders.group_by("region").row_count_greater_than(100)
1104
+ if not result.passed:
1105
+ for g in result.get_failed_groups():
1106
+ print(f"Region {g.group_key} has only {g.row_count} rows")
1107
+ """
1108
+ from duckguard.core.result import GroupByResult, GroupResult, ValidationResult
1109
+
1110
+ group_results: list[GroupResult] = []
1111
+ passed_count = 0
1112
+
1113
+ for group_stats in self.stats():
1114
+ group_key = {k: group_stats[k] for k in self._group_columns}
1115
+ row_count = group_stats["row_count"]
1116
+ passed = row_count > min_count
1117
+
1118
+ check_result = ValidationResult(
1119
+ passed=passed,
1120
+ actual_value=row_count,
1121
+ expected_value=f"> {min_count}",
1122
+ message=f"row_count = {row_count} {'>' if passed else '<='} {min_count}",
1123
+ )
1124
+
1125
+ group_results.append(GroupResult(
1126
+ group_key=group_key,
1127
+ row_count=row_count,
1128
+ passed=passed,
1129
+ check_results=[check_result],
1130
+ ))
1131
+
1132
+ if passed:
1133
+ passed_count += 1
1134
+
1135
+ total_groups = len(group_results)
1136
+ all_passed = passed_count == total_groups
1137
+
1138
+ return GroupByResult(
1139
+ passed=all_passed,
1140
+ total_groups=total_groups,
1141
+ passed_groups=passed_count,
1142
+ failed_groups=total_groups - passed_count,
1143
+ group_results=group_results,
1144
+ group_columns=self._group_columns,
1145
+ )
1146
+
1147
+ def validate(
1148
+ self,
1149
+ check_fn,
1150
+ column: str | None = None,
1151
+ ) -> GroupByResult:
1152
+ """
1153
+ Run a custom validation function on each group.
1154
+
1155
+ Args:
1156
+ check_fn: Function that takes a group's column and returns ValidationResult
1157
+ column: Column to validate (required for column-level checks)
1158
+
1159
+ Returns:
1160
+ GroupByResult with per-group validation results
1161
+
1162
+ Example:
1163
+ # Check null percent per group
1164
+ result = orders.group_by("region").validate(
1165
+ lambda col: col.null_percent < 5,
1166
+ column="customer_id"
1167
+ )
1168
+
1169
+ # Check amount range per group
1170
+ result = orders.group_by("date").validate(
1171
+ lambda col: col.between(0, 10000),
1172
+ column="amount"
1173
+ )
1174
+ """
1175
+ from duckguard.core.result import GroupByResult, GroupResult, ValidationResult
1176
+
1177
+ group_results: list[GroupResult] = []
1178
+ passed_count = 0
1179
+
1180
+ for group_key in self.groups:
1181
+ # Build WHERE clause for this group
1182
+ conditions = " AND ".join(
1183
+ f'"{k}" = {self._format_value(v)}'
1184
+ for k, v in group_key.items()
1185
+ )
1186
+
1187
+ # Create a filtered view of the data for this group
1188
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1189
+
1190
+ # Get row count for this group
1191
+ sql_count = f"SELECT COUNT(*) FROM {ref} WHERE {conditions}"
1192
+ row_count = self._dataset.engine.fetch_value(sql_count) or 0
1193
+
1194
+ # Create a temporary filtered column for validation
1195
+ if column:
1196
+ group_col = _GroupColumn(
1197
+ name=column,
1198
+ dataset=self._dataset,
1199
+ filter_condition=conditions,
1200
+ )
1201
+ try:
1202
+ result = check_fn(group_col)
1203
+ if not isinstance(result, ValidationResult):
1204
+ # If check_fn returns a boolean (e.g., col.null_percent < 5)
1205
+ result = ValidationResult(
1206
+ passed=bool(result),
1207
+ actual_value=result,
1208
+ message=f"Custom check {'passed' if result else 'failed'}",
1209
+ )
1210
+ except Exception as e:
1211
+ result = ValidationResult(
1212
+ passed=False,
1213
+ actual_value=None,
1214
+ message=f"Check error: {e}",
1215
+ )
1216
+ else:
1217
+ result = ValidationResult(
1218
+ passed=True,
1219
+ actual_value=row_count,
1220
+ message="No column check specified",
1221
+ )
1222
+
1223
+ group_results.append(GroupResult(
1224
+ group_key=group_key,
1225
+ row_count=row_count,
1226
+ passed=result.passed,
1227
+ check_results=[result],
1228
+ ))
1229
+
1230
+ if result.passed:
1231
+ passed_count += 1
1232
+
1233
+ total_groups = len(group_results)
1234
+ all_passed = passed_count == total_groups
1235
+
1236
+ return GroupByResult(
1237
+ passed=all_passed,
1238
+ total_groups=total_groups,
1239
+ passed_groups=passed_count,
1240
+ failed_groups=total_groups - passed_count,
1241
+ group_results=group_results,
1242
+ group_columns=self._group_columns,
1243
+ )
1244
+
1245
+ def _format_value(self, value: Any) -> str:
1246
+ """Format a value for SQL WHERE clause."""
1247
+ if value is None:
1248
+ return "NULL"
1249
+ elif isinstance(value, str):
1250
+ return f"'{value}'"
1251
+ elif isinstance(value, bool):
1252
+ return "TRUE" if value else "FALSE"
1253
+ else:
1254
+ return str(value)
1255
+
1256
+
1257
+ class _GroupColumn:
1258
+ """
1259
+ A column wrapper that applies a filter condition.
1260
+
1261
+ Used internally by GroupedDataset to validate columns within groups.
1262
+ """
1263
+
1264
+ def __init__(self, name: str, dataset: Dataset, filter_condition: str):
1265
+ self._name = name
1266
+ self._dataset = dataset
1267
+ self._filter_condition = filter_condition
1268
+
1269
+ @property
1270
+ def null_percent(self) -> float:
1271
+ """Get null percentage for this column within the group."""
1272
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1273
+ col = f'"{self._name}"'
1274
+
1275
+ sql = f"""
1276
+ SELECT
1277
+ (COUNT(*) - COUNT({col})) * 100.0 / NULLIF(COUNT(*), 0) as null_pct
1278
+ FROM {ref}
1279
+ WHERE {self._filter_condition}
1280
+ """
1281
+
1282
+ result = self._dataset.engine.fetch_value(sql)
1283
+ return float(result) if result is not None else 0.0
1284
+
1285
+ @property
1286
+ def count(self) -> int:
1287
+ """Get non-null count for this column within the group."""
1288
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1289
+ col = f'"{self._name}"'
1290
+
1291
+ sql = f"""
1292
+ SELECT COUNT({col})
1293
+ FROM {ref}
1294
+ WHERE {self._filter_condition}
1295
+ """
1296
+
1297
+ return self._dataset.engine.fetch_value(sql) or 0
1298
+
1299
+ def between(self, min_val: Any, max_val: Any) -> ValidationResult:
1300
+ """Check if all values are between min and max within the group."""
1301
+ ref = self._dataset.engine.get_source_reference(self._dataset.source)
1302
+ col = f'"{self._name}"'
1303
+
1304
+ sql = f"""
1305
+ SELECT COUNT(*) FROM {ref}
1306
+ WHERE {self._filter_condition}
1307
+ AND {col} IS NOT NULL
1308
+ AND ({col} < {min_val} OR {col} > {max_val})
1309
+ """
1310
+
1311
+ out_of_range = self._dataset.engine.fetch_value(sql) or 0
1312
+ passed = out_of_range == 0
1313
+
1314
+ return ValidationResult(
1315
+ passed=passed,
1316
+ actual_value=out_of_range,
1317
+ expected_value=0,
1318
+ message=f"{out_of_range} values outside [{min_val}, {max_val}]",
1319
+ )