duckguard-2.3.0-py3-none-any.whl → duckguard-3.0.0-py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
@@ -0,0 +1,726 @@
+ """Multi-column check implementations for DuckGuard 3.0.
+
+ This module provides cross-column validation checks that evaluate relationships
+ between multiple columns, enabling sophisticated business logic such as:
+ - "End date must be after start date"
+ - "Total must equal sum of parts"
+ - "Composite primary key uniqueness"
+ - "Cross-column arithmetic constraints"
+
+ Examples:
+     >>> data = connect("orders.csv")
+     >>> # Validate date range
+     >>> result = data.expect_column_pair_satisfy(
+     ...     column_a="end_date",
+     ...     column_b="start_date",
+     ...     expression="end_date >= start_date"
+     ... )
+     >>> assert result.passed
+
+     >>> # Composite uniqueness
+     >>> result = data.expect_columns_unique(
+     ...     columns=["user_id", "session_id"]
+     ... )
+     >>> assert result.passed
+ """
+
+ from __future__ import annotations
+
+ import re
+ from dataclasses import dataclass, field
+
+ from duckguard.core.result import ValidationResult
+ from duckguard.errors import ValidationError
+
+
+ @dataclass
+ class ExpressionValidationResult:
+     """Result of expression validation."""
+
+     is_valid: bool
+     error_message: str | None = None
+     parsed_columns: list[str] = field(default_factory=list)
+     operators: list[str] = field(default_factory=list)
+     complexity_score: int = 0
+
+     @property
+     def columns(self) -> list[str]:
+         """Alias for parsed_columns for test compatibility."""
+         return self.parsed_columns
+
+
+ class ExpressionParser:
+     """Parses and validates multi-column expressions.
+
+     Supports:
+     - Comparison operators: >, <, >=, <=, =, != (or <>)
+     - Arithmetic operators: +, -, *, /
+     - Logical operators: AND, OR, NOT
+     - Parentheses for grouping
+     - Column references by name
+
+     Examples:
+         >>> parser = ExpressionParser()
+         >>> result = parser.parse("end_date >= start_date")
+         >>> assert result.is_valid
+
+         >>> result = parser.parse("A + B = C")
+         >>> assert result.is_valid
+         >>> assert set(result.parsed_columns) == {'A', 'B', 'C'}
+     """
+
+     # Supported operators
+     COMPARISON_OPS = ['>=', '<=', '!=', '<>', '>', '<', '=']
+     ARITHMETIC_OPS = ['+', '-', '*', '/']
+     LOGICAL_OPS = ['AND', 'OR', 'NOT']
+
+     FORBIDDEN_KEYWORDS = [
+         'DROP', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'TRUNCATE',
+         'GRANT', 'REVOKE', 'EXECUTE', 'EXEC', 'CALL', 'ATTACH', 'DETACH'
+     ]
+
+     def __init__(self, max_complexity: int = 50):
+         """Initialize expression parser.
+
+         Args:
+             max_complexity: Maximum allowed expression complexity (0-100)
+         """
+         self.max_complexity = max_complexity
+
+     def parse(self, expression: str) -> ExpressionValidationResult:
+         """Parse and validate a multi-column expression.
+
+         Args:
+             expression: Expression string to parse
+
+         Returns:
+             ExpressionValidationResult with validation status
+
+         Examples:
+             >>> parser = ExpressionParser()
+             >>> result = parser.parse("amount > min_amount")
+             >>> assert result.is_valid
+         """
+         if not expression or not expression.strip():
+             return ExpressionValidationResult(
+                 is_valid=False,
+                 error_message="Expression cannot be empty"
+             )
+
+         expression = expression.strip()
+
+         # Check balanced parentheses
+         if expression.count('(') != expression.count(')'):
+             return ExpressionValidationResult(
+                 is_valid=False,
+                 error_message="Unbalanced parentheses in expression"
+             )
+
+         # Check for forbidden SQL keywords
+         expression_upper = expression.upper()
+         for keyword in self.FORBIDDEN_KEYWORDS:
+             if re.search(r'\b' + keyword + r'\b', expression_upper):
+                 return ExpressionValidationResult(
+                     is_valid=False,
+                     error_message=f"Forbidden keyword detected: {keyword}"
+                 )
+
+         # Extract column names
+         columns = self._extract_columns(expression)
+         if not columns:
+             return ExpressionValidationResult(
+                 is_valid=False,
+                 error_message="No column references found in expression"
+             )
+
+         # Extract operators
+         operators = self._extract_operators(expression)
+
+         # Calculate complexity
+         complexity = self._calculate_complexity(expression, columns, operators)
+         if complexity > self.max_complexity:
+             return ExpressionValidationResult(
+                 is_valid=False,
+                 error_message=f"Expression too complex (score: {complexity}, "
+                               f"max: {self.max_complexity})"
+             )
+
+         return ExpressionValidationResult(
+             is_valid=True,
+             parsed_columns=columns,
+             operators=operators,
+             complexity_score=complexity
+         )
+
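For illustration (a doctest-style sketch derived from the code above, not part of the packaged module), the rejection paths in parse() behave as follows:

    >>> parser = ExpressionParser()
    >>> parser.parse("end_date >= start_date; DROP TABLE users").error_message
    'Forbidden keyword detected: DROP'
    >>> parser.parse("(amount + tax > total").error_message
    'Unbalanced parentheses in expression'
    >>> parser.parse("1 + 2 = 3").error_message
    'No column references found in expression'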
+     def _extract_columns(self, expression: str) -> list[str]:
+         """Extract column names from expression.
+
+         Column names are identifiers that aren't SQL keywords or numbers.
+
+         Args:
+             expression: Expression string
+
+         Returns:
+             List of unique column names
+         """
+         # Pattern to match identifiers (column names)
+         identifier_pattern = r'\b[a-zA-Z_][a-zA-Z0-9_]*\b'
+
+         # SQL keywords to exclude
+         keywords = {
+             'AND', 'OR', 'NOT', 'NULL', 'TRUE', 'FALSE',
+             'UPPER', 'LOWER', 'LENGTH', 'COALESCE', 'CAST',
+             'DATE', 'TIME', 'TIMESTAMP', 'INT', 'FLOAT',
+             'VARCHAR', 'TEXT', 'BOOLEAN'
+         }
+
+         # Find all identifiers
+         identifiers = re.findall(identifier_pattern, expression)
+
+         # Filter out keywords and numbers
+         columns = []
+         for ident in identifiers:
+             if ident.upper() not in keywords and not ident.isdigit():
+                 if ident not in columns:
+                     columns.append(ident)
+
+         return columns
+
+     def _extract_operators(self, expression: str) -> list[str]:
+         """Extract operators from expression.
+
+         Args:
+             expression: Expression string
+
+         Returns:
+             List of operators found
+         """
+         operators = []
+         expression_upper = expression.upper()
+
+         # Check for comparison operators, longest first; mask each match so
+         # that ">=" is not also counted as ">" and "="
+         remaining = expression
+         for op in self.COMPARISON_OPS:
+             if op in remaining:
+                 operators.append(op)
+                 remaining = remaining.replace(op, ' ')
+
+         # Check for arithmetic operators
+         for op in self.ARITHMETIC_OPS:
+             if op in expression:
+                 operators.append(op)
+
+         # Check for logical operators
+         for op in self.LOGICAL_OPS:
+             if re.search(rf'\b{op}\b', expression_upper):
+                 operators.append(op)
+
+         return operators
+
+     def _calculate_complexity(
+         self,
+         expression: str,
+         columns: list[str],
+         operators: list[str]
+     ) -> int:
+         """Calculate expression complexity score (0-100).
+
+         Factors:
+         - Length of expression
+         - Number of columns
+         - Number of operators
+         - Nesting depth
+
+         Args:
+             expression: Expression string
+             columns: List of columns
+             operators: List of operators
+
+         Returns:
+             Complexity score (0-100)
+         """
+         score = 0
+
+         # Length factor (0-20 points)
+         score += min(20, len(expression) // 10)
+
+         # Column count (1 point each)
+         score += len(columns)
+
+         # Operator count (1 point each)
+         score += len(operators)
+
+         # Nesting depth (10 points per level)
+         max_depth = 0
+         current_depth = 0
+         for char in expression:
+             if char == '(':
+                 current_depth += 1
+                 max_depth = max(max_depth, current_depth)
+             elif char == ')':
+                 current_depth -= 1
+         score += max_depth * 10
+
+         return min(100, score)
+
+
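A worked example of the scoring above (illustrative, not part of the packaged module): "total = part_a + part_b" earns 2 points for length (23 // 10), 3 for its columns, 2 for its operators ('=' and '+'), and 0 for nesting, giving a complexity of 7, well under the default max_complexity of 50.

    >>> ExpressionParser().parse("total = part_a + part_b").complexity_score
    7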
+ class MultiColumnCheckHandler:
+     """Executes multi-column validation checks.
+
+     This handler performs cross-column validations including:
+     - Column pair comparisons
+     - Composite uniqueness
+     - Multi-column sum constraints
+     - Expression-based validations
+
+     Attributes:
+         parser: ExpressionParser instance for validating expressions
+
+     Examples:
+         >>> handler = MultiColumnCheckHandler()
+         >>> result = handler.execute_column_pair_satisfy(
+         ...     dataset=data,
+         ...     column_a="end_date",
+         ...     column_b="start_date",
+         ...     expression="end_date >= start_date"
+         ... )
+     """
+
+     def __init__(self, parser: ExpressionParser | None = None):
+         """Initialize multi-column check handler.
+
+         Args:
+             parser: Expression parser (creates default if None)
+         """
+         self.parser = parser or ExpressionParser()
+
+     def execute_column_pair_satisfy(
+         self,
+         dataset,
+         column_a: str,
+         column_b: str,
+         expression: str,
+         threshold: float = 1.0
+     ) -> ValidationResult:
+         """Check that column pair satisfies expression.
+
+         Args:
+             dataset: Dataset to validate
+             column_a: First column name
+             column_b: Second column name
+             expression: Expression to evaluate (e.g., "A > B", "A + B = 100")
+             threshold: Minimum required pass rate (0.0-1.0); the check passes
+                 when the violation rate is at most 1.0 - threshold
+
+         Returns:
+             ValidationResult with pass/fail status
+
+         Raises:
+             ValidationError: If expression is invalid
+             ValueError: If column_a or column_b does not exist in the dataset
+
+         Examples:
+             >>> # Date range validation
+             >>> result = handler.execute_column_pair_satisfy(
+             ...     dataset=data,
+             ...     column_a="end_date",
+             ...     column_b="start_date",
+             ...     expression="end_date >= start_date"
+             ... )
+
+             >>> # Arithmetic validation
+             >>> result = handler.execute_column_pair_satisfy(
+             ...     dataset=data,
+             ...     column_a="total",
+             ...     column_b="subtotal",
+             ...     expression="total = subtotal * 1.1"  # 10% markup
+             ... )
+         """
+         # Validate expression
+         validation = self.parser.parse(expression)
+         if not validation.is_valid:
+             raise ValidationError(
+                 f"Invalid expression: {validation.error_message}"
+             )
+
+         # Validate columns exist
+         available_columns = dataset.columns
+         if column_a not in available_columns:
+             raise ValueError(
+                 f"Column '{column_a}' does not exist. Available columns: {available_columns}"
+             )
+         if column_b not in available_columns:
+             raise ValueError(
+                 f"Column '{column_b}' does not exist. Available columns: {available_columns}"
+             )
+
+         # Normalize path for DuckDB (forward slashes work on all platforms)
+         source_path = dataset._source.replace('\\', '/')
+
+         # Replace column placeholders with actual column names
+         sql_expression = self._build_sql_expression(
+             expression, column_a, column_b
+         )
+
+         # Build SQL query to find violations
+         sql = f"""
+             SELECT COUNT(*) as violations
+             FROM '{source_path}'
+             WHERE NOT ({sql_expression})
+         """
+
+         try:
+             violations = dataset._engine.fetch_value(sql)
+             total_rows = dataset.row_count
+
+             if total_rows == 0:
+                 return ValidationResult(
+                     passed=True,
+                     actual_value=0,
+                     expected_value=0,
+                     message="Dataset is empty",
+                     details={}
+                 )
+
+             violation_rate = violations / total_rows
+             passed = violation_rate <= (1.0 - threshold)
+
+             return ValidationResult(
+                 passed=passed,
+                 actual_value=violations,
+                 expected_value=0,
+                 message=self._format_pair_message(
+                     passed=passed,
+                     column_a=column_a,
+                     column_b=column_b,
+                     expression=expression,
+                     violations=violations,
+                     total=total_rows,
+                     violation_rate=violation_rate
+                 ),
+                 details={
+                     'column_a': column_a,
+                     'column_b': column_b,
+                     'expression': expression,
+                     'violations': violations,
+                     'total_rows': total_rows,
+                     'violation_rate': violation_rate,
+                     'threshold': threshold
+                 }
+             )
+
+         except Exception as e:
+             # Handle type mismatch errors gracefully (e.g., comparing VARCHAR with DOUBLE)
+             error_msg = str(e).lower()
+             if "cannot compare" in error_msg or "type" in error_msg:
+                 return ValidationResult(
+                     passed=False,
+                     actual_value=None,
+                     expected_value=None,
+                     message=f"Type mismatch in column comparison: {str(e)}",
+                     details={
+                         'column_a': column_a,
+                         'column_b': column_b,
+                         'expression': expression,
+                         'error': str(e)
+                     }
+                 )
+             # For other errors, raise ValidationError
+             raise ValidationError(
+                 f"Error executing column pair check: {str(e)}"
+             ) from e
+
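To make the threshold semantics above concrete: with 1,000 rows and 30 violating rows the violation rate is 0.03, so the check passes at threshold=0.95 (0.03 <= 0.05) but fails at the default threshold=1.0. A usage sketch, where data is a hypothetical dataset handle:

    >>> result = handler.execute_column_pair_satisfy(
    ...     dataset=data,
    ...     column_a="end_date",
    ...     column_b="start_date",
    ...     expression="end_date >= start_date",
    ...     threshold=0.95,  # tolerate up to 5% violating rows
    ... )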
+     def execute_columns_unique(
+         self,
+         dataset,
+         columns: list[str],
+         threshold: float = 1.0
+     ) -> ValidationResult:
+         """Check that combination of columns is unique (composite key).
+
+         Args:
+             dataset: Dataset to validate
+             columns: List of column names forming composite key
+             threshold: Minimum required uniqueness rate (0.0-1.0)
+
+         Returns:
+             ValidationResult with pass/fail status
+
+         Examples:
+             >>> # Two-column composite key
+             >>> result = handler.execute_columns_unique(
+             ...     dataset=data,
+             ...     columns=["user_id", "session_id"]
+             ... )
+
+             >>> # Three-column composite key
+             >>> result = handler.execute_columns_unique(
+             ...     dataset=data,
+             ...     columns=["year", "month", "day"]
+             ... )
+         """
+         if not columns or len(columns) < 2:
+             raise ValidationError(
+                 "At least 2 columns required for composite uniqueness check"
+             )
+
+         # Normalize path for DuckDB (forward slashes work on all platforms)
+         source_path = dataset._source.replace('\\', '/')
+
+         # Build SQL to find duplicate combinations
+         column_list = ", ".join(columns)
+         sql = f"""
+             SELECT COUNT(*) as duplicate_combinations
+             FROM (
+                 SELECT {column_list}, COUNT(*) as cnt
+                 FROM '{source_path}'
+                 GROUP BY {column_list}
+                 HAVING cnt > 1
+             ) as dups
+         """
+
+         try:
+             duplicate_combinations = dataset._engine.fetch_value(sql)
+
+             # Count total distinct combinations
+             distinct_sql = f"""
+                 SELECT COUNT(DISTINCT ({column_list})) as distinct_count
+                 FROM '{source_path}'
+             """
+             distinct_count = dataset._engine.fetch_value(distinct_sql)
+
+             total_rows = dataset.row_count
+
+             if total_rows == 0:
+                 return ValidationResult(
+                     passed=True,
+                     actual_value=0,
+                     expected_value=0,
+                     message="Dataset is empty",
+                     details={'columns': columns}
+                 )
+
+             uniqueness_rate = distinct_count / total_rows
+             passed = uniqueness_rate >= threshold
+
+             return ValidationResult(
+                 passed=passed,
+                 actual_value=duplicate_combinations,
+                 expected_value=0,
+                 message=self._format_unique_message(
+                     passed=passed,
+                     columns=columns,
+                     duplicates=duplicate_combinations,
+                     total=total_rows,
+                     uniqueness_rate=uniqueness_rate
+                 ),
+                 details={
+                     'columns': columns,
+                     'duplicate_combinations': duplicate_combinations,
+                     'distinct_combinations': distinct_count,
+                     'total_rows': total_rows,
+                     'uniqueness_rate': uniqueness_rate,
+                     'threshold': threshold
+                 }
+             )
+
+         except Exception as e:
+             raise ValidationError(
+                 f"Error executing composite uniqueness check: {str(e)}"
+             ) from e
+
+     def execute_multicolumn_sum_equal(
+         self,
+         dataset,
+         columns: list[str],
+         expected_sum: float | None,
+         threshold: float = 0.01
+     ) -> ValidationResult:
+         """Check that sum of columns equals expected value.
+
+         Args:
+             dataset: Dataset to validate
+             columns: List of columns to sum
+             expected_sum: Expected sum value (None computes the sum without
+                 comparing it against anything)
+             threshold: Maximum allowed deviation
+
+         Returns:
+             ValidationResult with pass/fail status
+
+         Examples:
+             >>> # Components must sum to 100%
+             >>> result = handler.execute_multicolumn_sum_equal(
+             ...     dataset=data,
+             ...     columns=["component_a", "component_b", "component_c"],
+             ...     expected_sum=100.0
+             ... )
+
+             >>> # Budget allocation check
+             >>> result = handler.execute_multicolumn_sum_equal(
+             ...     dataset=data,
+             ...     columns=["q1", "q2", "q3", "q4"],
+             ...     expected_sum=data.annual_total
+             ... )
+         """
+         if not columns:
+             raise ValidationError("At least 1 column required for sum check")
+
+         # Normalize path for DuckDB (forward slashes work on all platforms)
+         source_path = dataset._source.replace('\\', '/')
+
+         # Build SQL to check sum
+         column_sum = " + ".join([f"COALESCE({col}, 0)" for col in columns])
+
+         # Handle None expected_sum (just compute sum without comparison)
+         if expected_sum is None:
+             # Just compute the sum for testing purposes
+             sql = f"""
+                 SELECT ({column_sum}) as total_sum
+                 FROM '{source_path}'
+                 LIMIT 1
+             """
+         else:
+             sql = f"""
+                 SELECT COUNT(*) as violations
+                 FROM '{source_path}'
+                 WHERE ABS(({column_sum}) - {expected_sum}) > {threshold}
+             """
+
+         try:
+             result_value = dataset._engine.fetch_value(sql)
+             total_rows = dataset.row_count
+
+             if total_rows == 0:
+                 return ValidationResult(
+                     passed=True,
+                     actual_value=0,
+                     expected_value=expected_sum,
+                     message="Dataset is empty",
+                     details={'columns': columns}
+                 )
+
+             # Handle None expected_sum (just testing mechanics)
+             if expected_sum is None:
+                 return ValidationResult(
+                     passed=True,
+                     actual_value=result_value,
+                     expected_value=None,
+                     message=f"Sum computed: {result_value}",
+                     details={'columns': columns, 'sum': result_value}
+                 )
+
+             violations = result_value
+             violation_rate = violations / total_rows
+             passed = violations == 0
+
+             return ValidationResult(
+                 passed=passed,
+                 actual_value=violations,
+                 expected_value=0,
+                 message=self._format_sum_message(
+                     passed=passed,
+                     columns=columns,
+                     expected_sum=expected_sum,
+                     violations=violations,
+                     total=total_rows
+                 ),
+                 details={
+                     'columns': columns,
+                     'expected_sum': expected_sum,
+                     'violations': violations,
+                     'total_rows': total_rows,
+                     'violation_rate': violation_rate,
+                     'threshold': threshold
+                 }
+             )
+
+         except Exception as e:
+             raise ValidationError(
+                 f"Error executing multicolumn sum check: {str(e)}"
+             ) from e
+
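For a concrete picture of the query built above, columns=["q1", "q2"], expected_sum=100.0 and threshold=0.01 render the f-string template to roughly the following, with 'orders.csv' standing in for the dataset's source path:

    SELECT COUNT(*) as violations
    FROM 'orders.csv'
    WHERE ABS((COALESCE(q1, 0) + COALESCE(q2, 0)) - 100.0) > 0.01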
+     def _build_sql_expression(
+         self,
+         expression: str,
+         column_a: str,
+         column_b: str
+     ) -> str:
+         """Build SQL expression from template.
+
+         Replaces the placeholders 'A' and 'B' with the actual column names;
+         expressions that already use the real column names pass through
+         unchanged.
+
+         Args:
+             expression: Expression template
+             column_a: First column name
+             column_b: Second column name
+
+         Returns:
+             SQL-ready expression string
+         """
+         # Replace common placeholders
+         sql_expr = expression
+         sql_expr = re.sub(r'\bA\b', column_a, sql_expr)
+         sql_expr = re.sub(r'\bB\b', column_b, sql_expr)
+
+         # Expressions written with the actual column names need no
+         # substitution; the replacements above are then no-ops
+         return sql_expr
+
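A quick sketch of the placeholder substitution above (illustrative only, handler being a MultiColumnCheckHandler instance):

    >>> handler._build_sql_expression("A >= B", "end_date", "start_date")
    'end_date >= start_date'
    >>> handler._build_sql_expression("end_date >= start_date", "end_date", "start_date")
    'end_date >= start_date'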
+     def _format_pair_message(
+         self,
+         passed: bool,
+         column_a: str,
+         column_b: str,
+         expression: str,
+         violations: int,
+         total: int,
+         violation_rate: float
+     ) -> str:
+         """Format message for column pair check."""
+         if passed:
+             return (
+                 f"Column pair ({column_a}, {column_b}) satisfies '{expression}': "
+                 f"PASSED ({violations}/{total} violations, {violation_rate:.1%} rate)"
+             )
+         else:
+             return (
+                 f"Column pair ({column_a}, {column_b}) fails '{expression}': "
+                 f"FAILED ({violations}/{total} violations, {violation_rate:.1%} rate)"
+             )
+
+     def _format_unique_message(
+         self,
+         passed: bool,
+         columns: list[str],
+         duplicates: int,
+         total: int,
+         uniqueness_rate: float
+     ) -> str:
+         """Format message for composite uniqueness check."""
+         column_str = ", ".join(columns)
+         if passed:
+             return (
+                 f"Columns ({column_str}) form unique composite key: "
+                 f"PASSED ({duplicates} duplicate combinations, "
+                 f"{uniqueness_rate:.1%} uniqueness)"
+             )
+         else:
+             return (
+                 f"Columns ({column_str}) not unique: "
+                 f"FAILED ({duplicates} duplicate combinations, "
+                 f"{uniqueness_rate:.1%} uniqueness)"
+             )
+
+     def _format_sum_message(
+         self,
+         passed: bool,
+         columns: list[str],
+         expected_sum: float,
+         violations: int,
+         total: int
+     ) -> str:
+         """Format message for multicolumn sum check."""
+         column_str = ", ".join(columns)
+         if passed:
+             return (
+                 f"Sum of ({column_str}) equals {expected_sum}: "
+                 "PASSED (0 violations)"
+             )
+         else:
+             return (
+                 f"Sum of ({column_str}) does not equal {expected_sum}: "
+                 f"FAILED ({violations}/{total} rows)"
+             )