pointblank 0.17.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pointblank/__init__.py CHANGED
@@ -20,6 +20,7 @@ from pointblank.column import (
20
20
  first_n,
21
21
  last_n,
22
22
  matches,
23
+ ref,
23
24
  starts_with,
24
25
  )
25
26
  from pointblank.datascan import DataScan, col_summary_tbl
@@ -59,6 +60,7 @@ __all__ = [
59
60
  "DataScan",
60
61
  "DraftValidation",
61
62
  "col",
63
+ "ref",
62
64
  "expr_col",
63
65
  "col_summary_tbl",
64
66
  "starts_with",
pointblank/_agg.py ADDED
@@ -0,0 +1,120 @@
1
+ from __future__ import annotations
2
+
3
+ import itertools
4
+ from collections.abc import Callable
5
+ from typing import Any
6
+
7
+ import narwhals as nw
8
+
9
+ # TODO: Should take any frame type
10
+ Aggregator = Callable[[nw.DataFrame], float | int]  # signature of the agg_* functions: frame in, number out
11
+ Comparator = Callable[[Any, Any, Any], bool]  # signature of the comp_* functions: (real, lower, upper) -> bool
12
+
13
+ AGGREGATOR_REGISTRY: dict[str, Aggregator] = {}  # filled by register(); keys are function names without "agg_"
14
+
15
+ COMPARATOR_REGISTRY: dict[str, Comparator] = {}  # filled by register(); keys are function names without "comp_"
16
+
17
+
18
def register(fn):
    """Register an aggregator or comparator function.

    The target registry is chosen from the function's ``__name__``: a ``comp_``
    prefix stores it in ``COMPARATOR_REGISTRY`` and an ``agg_`` prefix stores it
    in ``AGGREGATOR_REGISTRY``. In both cases the key is the name with the prefix
    removed. The function is returned unchanged, so this can be used as a decorator.

    Args:
        fn: The function to register.

    Returns:
        The same function object that was passed in.

    Raises:
        NotImplementedError: If the function name starts with neither prefix.
    """
    name: str = fn.__name__
    if name.startswith("comp_"):
        COMPARATOR_REGISTRY[name.removeprefix("comp_")] = fn
    elif name.startswith("agg_"):
        AGGREGATOR_REGISTRY[name.removeprefix("agg_")] = fn
    else:
        # Name the offending function so a misnamed registration is easy to find.
        raise NotImplementedError(  # pragma: no cover
            f"Cannot register '{name}': expected a name starting with 'comp_' or 'agg_'."
        )
    return fn
28
+
29
+
30
+ ## Aggregator Functions
31
@register
def agg_sum(column: nw.DataFrame) -> float:
    """Return the sum of the values held in ``column``.

    NOTE(review): ``.item()` is called without arguments, so ``column`` is
    presumably a one-column frame — confirm against the caller.
    """
    totals = column.select(nw.all().sum())
    return totals.item()
34
+
35
+
36
@register
def agg_avg(column: nw.DataFrame) -> float:
    """Return the mean of the values held in ``column``.

    NOTE(review): ``.item()`` is called without arguments, so ``column`` is
    presumably a one-column frame — confirm against the caller.
    """
    means = column.select(nw.all().mean())
    return means.item()
39
+
40
+
41
@register
def agg_sd(column: nw.DataFrame) -> float:
    """Return the standard deviation of the values held in ``column``.

    ``std()`` is called with no arguments, so whatever defaults narwhals applies
    are used. NOTE(review): ``.item()`` is called without arguments, so ``column``
    is presumably a one-column frame — confirm against the caller.
    """
    deviations = column.select(nw.all().std())
    return deviations.item()
44
+
45
+
46
+ ## Comparator functions:
47
@register
def comp_eq(real: float, lower: float, upper: float) -> bool:
    """Return whether ``real`` equals the target described by the two bounds.

    When ``lower`` and ``upper`` are equal the check is exact equality with that
    value; otherwise it is an inclusive range check via ``_generic_between``.
    """
    bounds_coincide = lower == upper
    return bool(real == lower) if bounds_coincide else _generic_between(real, lower, upper)
52
+
53
+
54
@register
def comp_gt(real: float, lower: float, upper: float) -> bool:
    """Return whether ``real`` is strictly greater than ``lower``.

    ``upper`` is accepted but not used.
    """
    is_above = real > lower
    return bool(is_above)
57
+
58
+
59
@register
def comp_ge(real: Any, lower: float, upper: float) -> bool:
    """Return whether ``real`` is greater than or equal to ``lower``.

    ``upper`` is accepted but not used.
    """
    is_at_or_above = real >= lower
    return bool(is_at_or_above)
62
+
63
+
64
@register
def comp_lt(real: float, lower: float, upper: float) -> bool:
    """Return whether ``real`` is strictly less than ``upper``.

    ``lower`` is accepted but not used.
    """
    is_below = real < upper
    return bool(is_below)
67
+
68
+
69
@register
def comp_le(real: float, lower: float, upper: float) -> bool:
    """Return whether ``real`` is less than or equal to ``upper``.

    ``lower`` is accepted but not used.
    """
    is_at_or_below = real <= upper
    return bool(is_at_or_below)
72
+
73
+
74
+ def _generic_between(real: Any, lower: Any, upper: Any) -> bool:
75
+ """Call if comparator needs to check between two values."""
76
+ return bool(lower <= real <= upper)
77
+
78
+
79
def resolve_agg_registries(name: str) -> tuple[Aggregator, Comparator]:
    """Resolve an assertion name to its aggregator and comparator functions.

    A leading ``col_`` is removed if present, then the last two
    underscore-separated tokens are read as ``<aggregator>_<comparator>`` and
    looked up in ``AGGREGATOR_REGISTRY`` and ``COMPARATOR_REGISTRY``.

    Args:
        name (str): The name of the assertion.

    Returns:
        tuple[Aggregator, Comparator]: The aggregator and comparator functions.

    Raises:
        ValueError: If the name has fewer than two tokens, or if either token is
            missing from its registry.
    """
    tokens = name.removeprefix("col_").split("_")
    if len(tokens) < 2:
        # Unpacking a single token would also raise ValueError, but with a message
        # about tuple unpacking rather than about the offending name.
        raise ValueError(
            f"Cannot resolve '{name}': expected a name ending in '<aggregator>_<comparator>'."
        )
    agg_name, comp_name = tokens[-2:]

    aggregator = AGGREGATOR_REGISTRY.get(agg_name)
    comparator = COMPARATOR_REGISTRY.get(comp_name)

    if aggregator is None:  # pragma: no cover
        raise ValueError(f"Aggregator '{agg_name}' not found in registry.")

    if comparator is None:  # pragma: no cover
        raise ValueError(f"Comparator '{comp_name}' not found in registry.")

    return aggregator, comparator
101
+
102
+
103
def is_valid_agg(name: str) -> bool:
    """Return whether ``name`` resolves to a registered aggregator/comparator pair.

    Resolution is delegated to ``resolve_agg_registries``; a ``ValueError`` from
    it is reported as ``False``.
    """
    try:
        resolve_agg_registries(name)
    except ValueError:
        return False
    else:
        return True
109
+
110
+
111
def load_validation_method_grid() -> tuple[str, ...]:
    """Generate all possible validation methods.

    Every registered aggregator name is paired with every registered comparator
    name and formatted as ``col_<aggregator>_<comparator>``. Names come out in
    ``itertools.product`` order, with the aggregator varying slowest.
    """
    pairs = itertools.product(AGGREGATOR_REGISTRY.keys(), COMPARATOR_REGISTRY.keys())
    return tuple(f"col_{agg_key}_{comp_key}" for agg_key, comp_key in pairs)