pointblank 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pointblank/validate.py CHANGED
@@ -15,7 +15,7 @@ from enum import Enum
 from functools import partial
 from importlib.metadata import version
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, ParamSpec, TypeVar
 from zipfile import ZipFile
 
 import commonmark
@@ -24,8 +24,8 @@ from great_tables import GT, from_column, google_font, html, loc, md, style, val
 from great_tables.gt import _get_column_of_values
 from great_tables.vals import fmt_integer, fmt_number
 from importlib_resources import files
-from narwhals.typing import FrameT
 
+from pointblank._agg import is_valid_agg, load_validation_method_grid, resolve_agg_registries
 from pointblank._constants import (
     ASSERTION_TYPE_METHOD_MAP,
     CHECK_MARK_SPAN,
@@ -92,6 +92,8 @@ from pointblank._utils import (
     _is_lib_present,
     _is_narwhals_table,
     _is_value_a_df,
+    _PBUnresolvedColumn,
+    _resolve_columns,
     _select_df_lib,
 )
 from pointblank._utils_check_args import (
@@ -102,7 +104,14 @@ from pointblank._utils_check_args import (
     _check_thresholds,
 )
 from pointblank._utils_html import _create_table_dims_html, _create_table_type_html
-from pointblank.column import Column, ColumnLiteral, ColumnSelector, ColumnSelectorNarwhals, col
+from pointblank.column import (
+    Column,
+    ColumnLiteral,
+    ColumnSelector,
+    ColumnSelectorNarwhals,
+    ReferenceColumn,
+    col,
+)
 from pointblank.schema import Schema, _get_schema_validation_info
 from pointblank.segments import Segment
 from pointblank.thresholds import (
@@ -113,10 +122,18 @@ from pointblank.thresholds import (
     _normalize_thresholds_creation,
 )
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
 if TYPE_CHECKING:
     from collections.abc import Collection
+    from typing import Any
+
+    import polars as pl
+    from narwhals.typing import IntoDataFrame, IntoFrame
+
+    from pointblank._typing import AbsoluteBounds, Tolerance, _CompliantValue, _CompliantValues
 
-    from pointblank._typing import AbsoluteBounds, Tolerance
 
 __all__ = [
     "Validate",
@@ -135,6 +152,7 @@ __all__ = [
     "get_validation_summary",
 ]
 
+
 # Create a thread-local storage for the metadata
 _action_context = threading.local()
 
@@ -424,12 +442,13 @@ def config(
     global_config.report_incl_footer_timings = report_incl_footer_timings  # pragma: no cover
     global_config.report_incl_footer_notes = report_incl_footer_notes  # pragma: no cover
     global_config.preview_incl_header = preview_incl_header  # pragma: no cover
+    return global_config  # pragma: no cover
 
 
 def load_dataset(
     dataset: Literal["small_table", "game_revenue", "nycflights", "global_sales"] = "small_table",
     tbl_type: Literal["polars", "pandas", "duckdb"] = "polars",
-) -> FrameT | Any:
+) -> Any:
     """
     Load a dataset hosted in the library as specified table type.
 
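Note: `config()` previously mutated the module-level settings and returned `None`; it now also returns the `global_config` object, so the applied settings can be captured. A hedged usage sketch (keyword name taken from this hunk; importing `config` from the package root is an assumption):

```python
import pointblank as pb

# The return value is the same global configuration object that was mutated
cfg = pb.config(preview_incl_header=True)
print(cfg.preview_incl_header)  # True
```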
@@ -450,7 +469,7 @@ def load_dataset(
 
     Returns
     -------
-    FrameT | Any
+    Any
         The dataset for the `Validate` object. This could be a Polars DataFrame, a Pandas DataFrame,
         or a DuckDB table as an Ibis table.
 
@@ -1523,7 +1542,7 @@ def get_data_path(
         return tmp_file.name
 
 
-def _process_data(data: FrameT | Any) -> FrameT | Any:
+def _process_data(data: Any) -> Any:
     """
     Centralized data processing pipeline that handles all supported input types.
 
@@ -1540,7 +1559,7 @@ def _process_data(data: FrameT | Any) -> FrameT | Any:
 
     Parameters
     ----------
-    data : FrameT | Any
+    data
         The input data which could be:
         - a DataFrame object (Polars, Pandas, Ibis, etc.)
         - a GitHub URL pointing to a CSV or Parquet file
@@ -1551,7 +1570,7 @@ def _process_data(data: FrameT | Any) -> FrameT | Any:
 
     Returns
     -------
-    FrameT | Any
+    Any
         Processed data as a DataFrame if input was a supported data source type,
         otherwise the original data unchanged.
     """
@@ -1570,7 +1589,7 @@ def _process_data(data: FrameT | Any) -> FrameT | Any:
     return data
 
 
-def _process_github_url(data: FrameT | Any) -> FrameT | Any:
+def _process_github_url(data: Any) -> Any:
     """
     Process data parameter to handle GitHub URLs pointing to CSV or Parquet files.
 
@@ -1585,12 +1604,12 @@ def _process_github_url(data: FrameT | Any) -> FrameT | Any:
 
     Parameters
     ----------
-    data : FrameT | Any
+    data
         The data parameter which may be a GitHub URL string or any other data type.
 
     Returns
     -------
-    FrameT | Any
+    Any
         If the input is a supported GitHub URL, returns a DataFrame loaded from the downloaded file.
         Otherwise, returns the original data unchanged.
 
@@ -1675,7 +1694,7 @@ def _process_github_url(data: FrameT | Any) -> FrameT | Any:
     return data
 
 
-def _process_connection_string(data: FrameT | Any) -> FrameT | Any:
+def _process_connection_string(data: Any) -> Any:
     """
     Process data parameter to handle database connection strings.
 
@@ -1702,7 +1721,7 @@
         return connect_to_table(data)
 
 
-def _process_csv_input(data: FrameT | Any) -> FrameT | Any:
+def _process_csv_input(data: Any) -> Any:
     """
     Process data parameter to handle CSV file inputs.
 
@@ -1760,7 +1779,7 @@
     )
 
 
-def _process_parquet_input(data: FrameT | Any) -> FrameT | Any:
+def _process_parquet_input(data: Any) -> Any:
     """
     Process data parameter to handle Parquet file inputs.
 
@@ -1903,7 +1922,7 @@
 
 
 def preview(
-    data: FrameT | Any,
+    data: Any,
     columns_subset: str | list[str] | Column | None = None,
     n_head: int = 5,
     n_tail: int = 5,
@@ -1911,7 +1930,7 @@ def preview(
     show_row_numbers: bool = True,
     max_col_width: int = 250,
     min_tbl_width: int = 500,
-    incl_header: bool = None,
+    incl_header: bool | None = None,
 ) -> GT:
     """
     Display a table preview that shows some rows from the top, some from the bottom.
@@ -2169,7 +2188,7 @@ def preview(
 
 
 def _generate_display_table(
-    data: FrameT | Any,
+    data: Any,
     columns_subset: str | list[str] | Column | None = None,
     n_head: int = 5,
     n_tail: int = 5,
@@ -2177,7 +2196,7 @@ def _generate_display_table(
     show_row_numbers: bool = True,
     max_col_width: int = 250,
     min_tbl_width: int = 500,
-    incl_header: bool = None,
+    incl_header: bool | None = None,
     mark_missing_values: bool = True,
     row_number_list: list[int] | None = None,
 ) -> GT:
@@ -2274,7 +2293,8 @@
         tbl_schema = Schema(tbl=data)
 
         # Get the row count for the table
-        ibis_rows = data.count()
+        # Note: ibis tables have count(), to_polars(), to_pandas() methods
+        ibis_rows = data.count()  # type: ignore[union-attr]
         n_rows = ibis_rows.to_polars() if df_lib_name_gt == "polars" else int(ibis_rows.to_pandas())
 
         # If n_head + n_tail is greater than the row count, display the entire table
@@ -2283,11 +2303,11 @@
             data_subset = data
 
             if row_number_list is None:
-                row_number_list = range(1, n_rows + 1)
+                row_number_list = list(range(1, n_rows + 1))
         else:
             # Get the first n and last n rows of the table
-            data_head = data.head(n_head)
-            data_tail = data.filter(
+            data_head = data.head(n_head)  # type: ignore[union-attr]
+            data_tail = data.filter(  # type: ignore[union-attr]
                 [ibis.row_number() >= (n_rows - n_tail), ibis.row_number() <= n_rows]
            )
             data_subset = data_head.union(data_tail)
@@ -2299,9 +2319,9 @@
 
         # Convert either to Polars or Pandas depending on the available library
         if df_lib_name_gt == "polars":
-            data = data_subset.to_polars()
+            data = data_subset.to_polars()  # type: ignore[union-attr]
         else:
-            data = data_subset.to_pandas()
+            data = data_subset.to_pandas()  # type: ignore[union-attr]
 
         # From a DataFrame:
         # - get the row count
@@ -2312,17 +2332,18 @@
         tbl_schema = Schema(tbl=data)
 
         if tbl_type == "polars":
-            n_rows = int(data.height)
+            # Note: polars DataFrames have height, head(), tail() attributes
+            n_rows = int(data.height)  # type: ignore[union-attr]
 
             # If n_head + n_tail is greater than the row count, display the entire table
             if n_head + n_tail >= n_rows:
                 full_dataset = True
 
                 if row_number_list is None:
-                    row_number_list = range(1, n_rows + 1)
+                    row_number_list = list(range(1, n_rows + 1))
 
             else:
-                data = pl.concat([data.head(n=n_head), data.tail(n=n_tail)])
+                data = pl.concat([data.head(n=n_head), data.tail(n=n_tail)])  # type: ignore[union-attr]
 
                 if row_number_list is None:
                     row_number_list = list(range(1, n_head + 1)) + list(
@@ -2330,40 +2351,42 @@
                     )
 
         if tbl_type == "pandas":
-            n_rows = data.shape[0]
+            # Note: pandas DataFrames have shape, head(), tail() attributes
+            n_rows = data.shape[0]  # type: ignore[union-attr]
 
             # If n_head + n_tail is greater than the row count, display the entire table
             if n_head + n_tail >= n_rows:
                 full_dataset = True
                 data_subset = data
 
-                row_number_list = range(1, n_rows + 1)
+                row_number_list = list(range(1, n_rows + 1))
             else:
-                data = pd.concat([data.head(n=n_head), data.tail(n=n_tail)])
+                data = pd.concat([data.head(n=n_head), data.tail(n=n_tail)])  # type: ignore[union-attr]
 
                 row_number_list = list(range(1, n_head + 1)) + list(
                     range(n_rows - n_tail + 1, n_rows + 1)
                 )
 
         if tbl_type == "pyspark":
-            n_rows = data.count()
+            # Note: pyspark DataFrames have count(), toPandas(), limit(), tail(), sparkSession
+            n_rows = data.count()  # type: ignore[union-attr]
 
             # If n_head + n_tail is greater than the row count, display the entire table
             if n_head + n_tail >= n_rows:
                 full_dataset = True
                 # Convert to pandas for Great Tables compatibility
-                data = data.toPandas()
+                data = data.toPandas()  # type: ignore[union-attr]
 
-                row_number_list = range(1, n_rows + 1)
+                row_number_list = list(range(1, n_rows + 1))
             else:
                 # Get head and tail samples, then convert to pandas
-                head_data = data.limit(n_head).toPandas()
+                head_data = data.limit(n_head).toPandas()  # type: ignore[union-attr]
 
                 # PySpark tail() returns a list of Row objects, need to convert to DataFrame
-                tail_rows = data.tail(n_tail)
+                tail_rows = data.tail(n_tail)  # type: ignore[union-attr]
                 if tail_rows:
                     # Convert list of Row objects back to DataFrame, then to pandas
-                    tail_df = data.sparkSession.createDataFrame(tail_rows, data.schema)
+                    tail_df = data.sparkSession.createDataFrame(tail_rows, data.schema)  # type: ignore[union-attr]
                     tail_data = tail_df.toPandas()
                 else:
                     # If no tail data, create empty DataFrame with same schema
@@ -2391,14 +2414,14 @@
     tbl_schema = Schema(tbl=data)
 
     # From the table schema, get a list of tuples containing column names and data types
-    col_dtype_dict = tbl_schema.columns
+    col_dtype_list = tbl_schema.columns or []
 
     # Extract the column names from the list of tuples (first element of each tuple)
-    col_names = [col[0] for col in col_dtype_dict]
+    col_names = [col[0] for col in col_dtype_list]
 
     # Iterate over the list of tuples and create a new dictionary with the
     # column names and data types
-    col_dtype_dict = {k: v for k, v in col_dtype_dict}
+    col_dtype_dict = {k: v for k, v in col_dtype_list}
 
     # Create short versions of the data types by omitting any text in parentheses
     col_dtype_dict_short = {
@@ -2497,21 +2520,21 @@
     # Prepend a column that contains the row numbers if `show_row_numbers=True`
     if show_row_numbers or has_leading_row_num_col:
         if has_leading_row_num_col:
-            row_number_list = data["_row_num_"].to_list()
+            row_number_list = data["_row_num_"].to_list()  # type: ignore[union-attr]
 
         else:
             if df_lib_name_gt == "polars":
                 import polars as pl
 
                 row_number_series = pl.Series("_row_num_", row_number_list)
-                data = data.insert_column(0, row_number_series)
+                data = data.insert_column(0, row_number_series)  # type: ignore[union-attr]
 
             if df_lib_name_gt == "pandas":
-                data.insert(0, "_row_num_", row_number_list)
+                data.insert(0, "_row_num_", row_number_list)  # type: ignore[union-attr]
 
             if df_lib_name_gt == "pyspark":
                 # For PySpark converted to pandas, use pandas method
-                data.insert(0, "_row_num_", row_number_list)
+                data.insert(0, "_row_num_", row_number_list)  # type: ignore[union-attr]
 
     # Get the highest number in the `row_number_list` and calculate a width that will
     # safely fit a number of that magnitude
@@ -2620,7 +2643,7 @@
     return gt_tbl
 
 
-def missing_vals_tbl(data: FrameT | Any) -> GT:
+def missing_vals_tbl(data: Any) -> GT:
     """
     Display a table that shows the missing values in the input table.
 
@@ -3221,7 +3244,7 @@ def _get_column_names_safe(data: Any) -> list[str]:
         return list(data.columns)  # pragma: no cover
 
 
-def _get_column_names(data: FrameT | Any, ibis_tbl: bool, df_lib_name_gt: str) -> list[str]:
+def _get_column_names(data: Any, ibis_tbl: bool, df_lib_name_gt: str) -> list[str]:
     if ibis_tbl:
         return data.columns if df_lib_name_gt == "polars" else list(data.columns)
 
@@ -3245,12 +3268,10 @@ def _validate_columns_subset(
         )
         return columns_subset
 
-    return columns_subset.resolve(columns=col_names)
+    return columns_subset.resolve(columns=col_names)  # type: ignore[union-attr]
 
 
-def _select_columns(
-    data: FrameT | Any, resolved_columns: list[str], ibis_tbl: bool, tbl_type: str
-) -> FrameT | Any:
+def _select_columns(data: Any, resolved_columns: list[str], ibis_tbl: bool, tbl_type: str) -> Any:
     if ibis_tbl:
         return data[resolved_columns]
     if tbl_type == "polars":
@@ -3258,7 +3279,7 @@
         return data[resolved_columns]
 
 
-def get_column_count(data: FrameT | Any) -> int:
+def get_column_count(data: Any) -> int:
     """
     Get the number of columns in a table.
 
@@ -3470,7 +3491,7 @@ def _extract_enum_values(set_values: Any) -> list[Any]:
     return [set_values]
 
 
-def get_row_count(data: FrameT | Any) -> int:
+def get_row_count(data: Any) -> int:
     """
     Get the number of rows in a table.
 
@@ -3723,18 +3744,46 @@ class _ValidationInfo:
     insertion order, ensuring notes appear in a consistent sequence in reports and logs.
     """
 
+    @classmethod
+    def from_agg_validator(
+        cls,
+        assertion_type: str,
+        columns: _PBUnresolvedColumn,
+        value: float | Column | ReferenceColumn,
+        tol: Tolerance = 0,
+        thresholds: float | bool | tuple | dict | Thresholds | None = None,
+        brief: str | bool = False,
+        actions: Actions | None = None,
+        active: bool = True,
+    ) -> _ValidationInfo:
+        # This factory method creates a `_ValidationInfo` instance for aggregate
+        # methods. The reason this is created, is because all agg methods share the same
+        # signature so instead of instantiating the class directly each time, this method
+        # can be used to reduce redundancy, boilerplate and mistakes :)
+        _check_thresholds(thresholds=thresholds)
+
+        return cls(
+            assertion_type=assertion_type,
+            column=_resolve_columns(columns),
+            values={"value": value, "tol": tol},
+            thresholds=_normalize_thresholds_creation(thresholds),
+            brief=_transform_auto_brief(brief=brief),
+            actions=actions,
+            active=active,
+        )
+
     # Validation plan
     i: int | None = None
     i_o: int | None = None
     step_id: str | None = None
     sha1: str | None = None
     assertion_type: str | None = None
-    column: any | None = None
-    values: any | list[any] | tuple | None = None
+    column: Any | None = None
+    values: Any | list[Any] | tuple | None = None
     inclusive: tuple[bool, bool] | None = None
     na_pass: bool | None = None
     pre: Callable | None = None
-    segments: any | None = None
+    segments: Any | None = None
     thresholds: Thresholds | None = None
     actions: Actions | None = None
     label: str | None = None
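Note: every aggregate validator funnels through this factory, so a step such as `col_sum_eq("a", 100, tol=5)` reduces to a single `_ValidationInfo` payload. A sketch of the equivalent direct construction (field values inferred from this hunk; the helper calls are shown as comments):

```python
# Roughly what from_agg_validator("col_sum_eq", "a", 100, tol=5) builds
step = _ValidationInfo(
    assertion_type="col_sum_eq",
    column=_resolve_columns("a"),      # normalizes str/Column/selector input
    values={"value": 100, "tol": 5},   # shared payload for all agg methods
    thresholds=None,                   # via _normalize_thresholds_creation()
    brief=None,                        # via _transform_auto_brief()
    actions=None,
    active=True,
)
```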
@@ -3753,14 +3802,14 @@ class _ValidationInfo:
     error: bool | None = None
     critical: bool | None = None
     failure_text: str | None = None
-    tbl_checked: FrameT | None = None
-    extract: FrameT | None = None
-    val_info: dict[str, any] | None = None
+    tbl_checked: Any = None
+    extract: Any = None
+    val_info: dict[str, Any] | None = None
     time_processed: str | None = None
     proc_duration_s: float | None = None
     notes: dict[str, dict[str, str]] | None = None
 
-    def get_val_info(self) -> dict[str, any]:
+    def get_val_info(self) -> dict[str, Any] | None:
         return self.val_info
 
     def _add_note(self, key: str, markdown: str, text: str | None = None) -> None:
@@ -3936,7 +3985,7 @@ class _ValidationInfo:
         return self.notes is not None and len(self.notes) > 0
 
 
-def _handle_connection_errors(e: Exception, connection_string: str) -> None:
+def _handle_connection_errors(e: Exception, connection_string: str) -> NoReturn:
     """
     Shared error handling for database connection failures.
 
@@ -4777,7 +4826,8 @@ class Validate:
     when table specifications are missing or backend dependencies are not installed.
     """
 
-    data: FrameT | Any
+    data: IntoDataFrame
+    reference: IntoFrame | None = None
     tbl_name: str | None = None
     label: str | None = None
     thresholds: int | float | bool | tuple | dict | Thresholds | None = None
@@ -4791,6 +4841,10 @@ class Validate:
         # Process data through the centralized data processing pipeline
         self.data = _process_data(self.data)
 
+        # Process reference data if provided
+        if self.reference is not None:
+            self.reference = _process_data(self.reference)
+
         # Check input of the `thresholds=` argument
         _check_thresholds(thresholds=self.thresholds)
 
@@ -4835,9 +4889,107 @@ class Validate:
 
         self.validation_info = []
 
+    def _add_agg_validation(
+        self,
+        *,
+        assertion_type: str,
+        columns: str | Collection[str],
+        value,
+        tol=0,
+        thresholds=None,
+        brief=False,
+        actions=None,
+        active=True,
+    ):
+        """
+        Add an aggregation-based validation step to the validation plan.
+
+        This internal method is used by all aggregation-based column validation methods
+        (e.g., `col_sum_eq`, `col_avg_gt`, `col_sd_le`) to create and register validation
+        steps. It relies heavily on the `_ValidationInfo.from_agg_validator()` class method.
+
+        Automatic Reference Inference
+        -----------------------------
+        When `value` is None and reference data has been set on the Validate object,
+        this method automatically creates a `ReferenceColumn` pointing to the same
+        column name in the reference data. This enables a convenient shorthand:
+
+        .. code-block:: python
+
+            # Instead of writing:
+            Validate(data=df, reference=ref_df).col_sum_eq("a", ref("a"))
+
+            # You can simply write:
+            Validate(data=df, reference=ref_df).col_sum_eq("a")
+
+        If `value` is None and no reference data is set, a `ValueError` is raised
+        immediately to provide clear feedback to the user.
+
+        Parameters
+        ----------
+        assertion_type
+            The type of assertion (e.g., "col_sum_eq", "col_avg_gt").
+        columns
+            Column name or collection of column names to validate.
+        value
+            The target value to compare against. Can be:
+            - A numeric literal (int or float)
+            - A `Column` object for cross-column comparison
+            - A `ReferenceColumn` object for reference data comparison
+            - None to automatically use `ref(column)` when reference data is set
+        tol
+            Tolerance for the comparison. Defaults to 0.
+        thresholds
+            Custom thresholds for the validation step.
+        brief
+            Brief description or auto-generate flag.
+        actions
+            Actions to take based on validation results.
+        active
+            Whether this validation step is active.
+
+        Returns
+        -------
+        Validate
+            The Validate instance for method chaining.
+
+        Raises
+        ------
+        ValueError
+            If `value` is None and no reference data is set on the Validate object.
+        """
+        if isinstance(columns, str):
+            columns = [columns]
+        for column in columns:
+            # If value is None, default to referencing the same column from reference data
+            resolved_value = value
+            if value is None:
+                if self.reference is None:
+                    raise ValueError(
+                        f"The 'value' parameter is required for {assertion_type}() "
+                        "when no reference data is set. Either provide a value, or "
+                        "set reference data on the Validate object using "
+                        "Validate(data=..., reference=...)."
+                    )
+                resolved_value = ReferenceColumn(column_name=column)
+
+            val_info = _ValidationInfo.from_agg_validator(
+                assertion_type=assertion_type,
+                columns=column,
+                value=resolved_value,
+                tol=tol,
+                thresholds=self.thresholds if thresholds is None else thresholds,
+                actions=self.actions if actions is None else actions,
+                brief=self.brief if brief is None else brief,
+                active=active,
+            )
+            self._add_validation(validation_info=val_info)
+
+        return self
+
     def set_tbl(
         self,
-        tbl: FrameT | Any,
+        tbl: Any,
         tbl_name: str | None = None,
         label: str | None = None,
     ) -> Validate:
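Note: combined with the new `reference=` field on `Validate`, `_add_agg_validation()` enables aggregate checks against a second table. A hedged usage sketch (`col_sum_eq()` and the `ref()` shorthand appear in the docstring above; the exact public names are assumptions):

```python
import polars as pl
import pointblank as pb

df = pl.DataFrame({"a": [1, 2, 3]})
ref_df = pl.DataFrame({"a": [1, 2, 3]})

# Aggregate check against a literal target, with a symmetric tolerance
v1 = pb.Validate(data=df).col_sum_eq("a", 6, tol=0.5).interrogate()

# With reference data set, omitting the value compares against the same
# column of the reference table (the ReferenceColumn shorthand)
v2 = pb.Validate(data=df, reference=ref_df).col_sum_eq("a").interrogate()
```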
@@ -4980,7 +5132,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -5214,7 +5366,6 @@ class Validate:
         - Row 1: `c` is `1` and `b` is `2`.
         - Row 3: `c` is `2` and `b` is `2`.
         """
-
         assertion_type = _get_fn_name()
 
         _check_column(column=columns)
@@ -5234,14 +5385,7 @@ class Validate:
             self.thresholds if thresholds is None else _normalize_thresholds_creation(thresholds)
         )
 
-        # If `columns` is a ColumnSelector or Narwhals selector, call `col()` on it to later
-        # resolve the columns
-        if isinstance(columns, (ColumnSelector, nw.selectors.Selector)):
-            columns = col(columns)
-
-        # If `columns` is Column value or a string, place it in a list for iteration
-        if isinstance(columns, (Column, str)):
-            columns = [columns]
+        columns = _resolve_columns(columns)
 
         # Determine brief to use (global or local) and transform any shorthands of `brief=`
         brief = self.brief if brief is None else _transform_auto_brief(brief=brief)
@@ -5272,7 +5416,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -5563,7 +5707,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -5854,7 +5998,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -6143,7 +6287,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -6435,7 +6579,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -6729,7 +6873,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -7049,7 +7193,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -7366,7 +7510,7 @@ class Validate:
         set: Collection[Any],
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -7683,7 +7827,7 @@ class Validate:
         set: Collection[Any],
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -7974,7 +8118,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -8162,7 +8306,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -8347,7 +8491,7 @@ class Validate:
         columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -8590,7 +8734,7 @@ class Validate:
         columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -8836,7 +8980,7 @@ class Validate:
         inverse: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -9099,7 +9243,7 @@ class Validate:
         na_pass: bool = False,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -9379,10 +9523,10 @@ class Validate:
 
     def col_vals_expr(
         self,
-        expr: any,
+        expr: Any,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -9600,7 +9744,7 @@ class Validate:
     def col_exists(
         self,
         columns: str | list[str] | Column | ColumnSelector | ColumnSelectorNarwhals,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -10072,7 +10216,7 @@ class Validate:
         columns_subset: str | list[str] | None = None,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -10313,7 +10457,7 @@ class Validate:
         columns_subset: str | list[str] | None = None,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -10558,7 +10702,7 @@ class Validate:
         max_concurrent: int = 3,
         pre: Callable | None = None,
         segments: SegmentSpec | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -10953,7 +11097,7 @@ class Validate:
         case_sensitive_dtypes: bool = True,
         full_match_dtypes: bool = True,
         pre: Callable | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -11169,11 +11313,11 @@ class Validate:
 
     def row_count_match(
         self,
-        count: int | FrameT | Any,
+        count: int | Any,
         tol: Tolerance = 0,
         inverse: bool = False,
         pre: Callable | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -11388,10 +11532,10 @@ class Validate:
 
     def col_count_match(
         self,
-        count: int | FrameT | Any,
+        count: int | Any,
         inverse: bool = False,
         pre: Callable | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -11564,9 +11708,9 @@ class Validate:
 
     def tbl_match(
         self,
-        tbl_compare: FrameT | Any,
+        tbl_compare: Any,
         pre: Callable | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -11835,7 +11979,7 @@ class Validate:
         self,
         *exprs: Callable,
         pre: Callable | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -12083,7 +12227,7 @@ class Validate:
         self,
         expr: Callable,
         pre: Callable | None = None,
-        thresholds: int | float | bool | tuple | dict | Thresholds = None,
+        thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
         actions: Actions | None = None,
         brief: str | bool | None = None,
         active: bool = True,
@@ -12577,7 +12721,7 @@ class Validate:
             segment = validation.segments
 
             # Get compatible data types for this assertion type
-            assertion_method = ASSERTION_TYPE_METHOD_MAP[assertion_type]
+            assertion_method = ASSERTION_TYPE_METHOD_MAP.get(assertion_type, assertion_type)
             compatible_dtypes = COMPATIBLE_DTYPES.get(assertion_method, [])
 
             # Process the `brief` text for the validation step by including template variables to
@@ -12632,7 +12776,11 @@ class Validate:
 
             # Make a deep copy of the table for this step to ensure proper isolation
             # This prevents modifications from one validation step affecting others
-            data_tbl_step = _copy_dataframe(data_tbl)
+            try:
+                # TODO: This copying should be scrutinized further
+                data_tbl_step: IntoDataFrame = _copy_dataframe(data_tbl)
+            except Exception as e:  # pragma: no cover
+                data_tbl_step: IntoDataFrame = data_tbl  # pragma: no cover
 
             # Capture original table dimensions and columns before preprocessing
             # (only if preprocessing is present - we'll set these inside the preprocessing block)
@@ -13080,6 +13228,44 @@ class Validate:
                     tbl_type=tbl_type,
                )
 
+            elif is_valid_agg(assertion_type):
+                agg, comp = resolve_agg_registries(assertion_type)
+
+                # Produce a 1-column Narwhals DataFrame
+                # TODO: Should be able to take lazy too
+                vec: nw.DataFrame = nw.from_native(data_tbl_step).select(column)
+                real = agg(vec)
+
+                raw_value = value["value"]
+                tol = value["tol"]
+
+                # Handle ReferenceColumn: compute target from reference data
+                if isinstance(raw_value, ReferenceColumn):
+                    if self.reference is None:
+                        raise ValueError(
+                            f"Cannot use ref('{raw_value.column_name}') without "
+                            "setting reference data on the Validate object. "
+                            "Use Validate(data=..., reference=...) to set reference data."
+                        )
+                    ref_vec: nw.DataFrame = nw.from_native(self.reference).select(
+                        raw_value.column_name
+                    )
+                    target: float | int = agg(ref_vec)
+                else:
+                    target = raw_value
+
+                lower_diff, upper_diff = _derive_bounds(target, tol)
+
+                lower_bound = target - lower_diff
+                upper_bound = target + upper_diff
+                result_bool: bool = comp(real, lower_bound, upper_bound)
+
+                validation.all_passed = result_bool
+                validation.n = 1
+                validation.n_passed = int(result_bool)
+                validation.n_failed = 1 - result_bool
+
+                results_tbl = None
             else:
                 raise ValueError(
                     f"Unknown assertion type: {assertion_type}"
@@ -13822,12 +14008,14 @@ class Validate:
             )
 
         # Get the threshold status using the appropriate method
+        # Note: scalar=False (default) always returns a dict
+        status: dict[int, bool]
         if level == "warning":
-            status = self.warning(i=i)
+            status = self.warning(i=i)  # type: ignore[assignment]
         elif level == "error":
-            status = self.error(i=i)
-        elif level == "critical":
-            status = self.critical(i=i)
+            status = self.error(i=i)  # type: ignore[assignment]
+        else:  # level == "critical"
+            status = self.critical(i=i)  # type: ignore[assignment]
 
         # Find any steps that exceeded the threshold
         failures = []
@@ -13981,12 +14169,14 @@ class Validate:
             )
 
         # Get the threshold status using the appropriate method
+        # Note: scalar=False (default) always returns a dict
+        status: dict[int, bool]
         if level == "warning":
-            status = self.warning(i=i)
+            status = self.warning(i=i)  # type: ignore[assignment]
         elif level == "error":
-            status = self.error(i=i)
-        elif level == "critical":
-            status = self.critical(i=i)
+            status = self.error(i=i)  # type: ignore[assignment]
+        else:  # level == "critical"
+            status = self.critical(i=i)  # type: ignore[assignment]
 
         # Return True if any steps exceeded the threshold
         return any(status.values())
@@ -14759,7 +14949,7 @@ class Validate:
 
     def get_data_extracts(
         self, i: int | list[int] | None = None, frame: bool = False
-    ) -> dict[int, FrameT | None] | FrameT | None:
+    ) -> dict[int, Any] | Any:
         """
         Get the rows that failed for each validation step.
 
@@ -14782,7 +14972,7 @@ class Validate:
 
         Returns
         -------
-        dict[int, FrameT | None] | FrameT | None
+        dict[int, Any] | Any
            A dictionary of tables containing the rows that failed in every compatible validation
            step. Alternatively, it can be a DataFrame if `frame=True` and `i=` is a scalar.
 
@@ -15072,7 +15262,7 @@ class Validate:
 
         return json.dumps(report, indent=4, default=str)
 
-    def get_sundered_data(self, type="pass") -> FrameT:
+    def get_sundered_data(self, type="pass") -> Any:
         """
         Get the data that passed or failed the validation steps.
 
@@ -15108,7 +15298,7 @@ class Validate:
 
         Returns
         -------
-        FrameT
+        Any
            A table containing the data that passed or failed the validation steps.
 
         Examples
@@ -15200,6 +15390,7 @@ class Validate:
         # Get all validation step result tables and join together the `pb_is_good_` columns
         # ensuring that the columns are named uniquely (e.g., `pb_is_good_1`, `pb_is_good_2`, ...)
         # and that the index is reset
+        labeled_tbl_nw: nw.DataFrame | nw.LazyFrame | None = None
         for i, validation in enumerate(validation_info):
             results_tbl = nw.from_native(validation.tbl_checked)
 
@@ -15220,7 +15411,7 @@ class Validate:
             )
 
             # Add the results table to the list of tables
-            if i == 0:
+            if labeled_tbl_nw is None:
                 labeled_tbl_nw = results_tbl
             else:
                 labeled_tbl_nw = labeled_tbl_nw.join(results_tbl, on=index_name, how="left")
@@ -15396,10 +15587,10 @@ class Validate:
     def get_tabular_report(
         self,
         title: str | None = ":default:",
-        incl_header: bool = None,
-        incl_footer: bool = None,
-        incl_footer_timings: bool = None,
-        incl_footer_notes: bool = None,
+        incl_header: bool | None = None,
+        incl_footer: bool | None = None,
+        incl_footer_timings: bool | None = None,
+        incl_footer_notes: bool | None = None,
     ) -> GT:
         """
         Validation report as a GT table.
@@ -15767,10 +15958,16 @@ class Validate:
        elif assertion_type[i] in ["conjointly", "specially"]:
            column_text = ""
        else:
-           column_text = str(column)
+           # Handle both string columns and list columns
+           # For single-element lists like ['a'], display as 'a'
+           # For multi-element lists, display as comma-separated values
+           if isinstance(column, list):
+               column_text = ", ".join(str(c) for c in column)
+           else:
+               column_text = str(column)
 
-       # Apply underline styling for synthetic columns (using the purple color from the icon)
-       # Only apply styling if column_text is not empty and not a special marker
+       # Apply underline styling for synthetic columns; only apply styling if column_text is
+       # not empty and not a special marker
        if (
            has_synthetic_column
            and column_text
@@ -15889,6 +16086,32 @@ class Validate:
            else:  # pragma: no cover
                values_upd.append(str(value))  # pragma: no cover
 
+           # Handle aggregation methods (col_sum_gt, col_avg_eq, etc.)
+           elif is_valid_agg(assertion_type[i]):
+               # Extract the value and tolerance from the values dict
+               agg_value = value.get("value")
+               tol_value = value.get("tol", 0)
+
+               # Format the value (could be a number, Column, or ReferenceColumn)
+               if hasattr(agg_value, "__repr__"):
+                   # For Column or ReferenceColumn objects, use their repr
+                   value_str = repr(agg_value)
+               else:
+                   value_str = str(agg_value)
+
+               # Format tolerance - only show on second line if non-zero
+               if tol_value != 0:
+                   # Format tolerance based on its type
+                   if isinstance(tol_value, tuple):
+                       # Asymmetric bounds: (lower, upper)
+                       tol_str = f"tol=({tol_value[0]}, {tol_value[1]})"
+                   else:
+                       # Symmetric tolerance
+                       tol_str = f"tol={tol_value}"
+                   values_upd.append(f"{value_str}<br/>{tol_str}")
+               else:
+                   values_upd.append(value_str)
+
            # If the assertion type is not recognized, add the value as a string
            else:  # pragma: no cover
                values_upd.append(str(value))  # pragma: no cover
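Note: `hasattr(agg_value, "__repr__")` is always true in Python, so the `repr()` branch is the one that runs; for plain numbers `repr()` and `str()` agree anyway. A simplified mirror of the resulting VALUES-column strings (illustration only, not the library's function):

```python
def format_agg_value(agg_value, tol_value=0) -> str:
    # Simplified mirror of the reporting branch above
    value_str = repr(agg_value)
    if tol_value == 0:
        return value_str
    if isinstance(tol_value, tuple):  # asymmetric (lower, upper) bounds
        return f"{value_str}<br/>tol=({tol_value[0]}, {tol_value[1]})"
    return f"{value_str}<br/>tol={tol_value}"

assert format_agg_value(100, 5) == "100<br/>tol=5"
assert format_agg_value(0.5, (0.1, 0)) == "0.5<br/>tol=(0.1, 0)"
```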
@@ -16738,7 +16961,7 @@ class Validate:
16738
16961
  table = validation.pre(self.data)
16739
16962
 
16740
16963
  # Get the columns from the table as a list
16741
- columns = list(table.columns)
16964
+ columns = list(table.columns) # type: ignore[union-attr]
16742
16965
 
16743
16966
  # Evaluate the column expression
16744
16967
  if isinstance(column_expr, ColumnSelectorNarwhals):
@@ -17116,7 +17339,7 @@ def _convert_string_to_datetime(value: str) -> datetime.datetime:
17116
17339
  return datetime.datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
17117
17340
 
17118
17341
 
17119
- def _string_date_dttm_conversion(value: any) -> any:
17342
+ def _string_date_dttm_conversion(value: Any) -> Any:
17120
17343
  """
17121
17344
  Convert a string to a date or datetime object if it is in the correct format.
17122
17345
  If the value is not a string, it is returned as is.
@@ -17151,8 +17374,8 @@ def _string_date_dttm_conversion(value: any) -> any:
17151
17374
 
17152
17375
 
17153
17376
  def _conditional_string_date_dttm_conversion(
17154
- value: any, allow_regular_strings: bool = False
17155
- ) -> any:
17377
+ value: Any, allow_regular_strings: bool = False
17378
+ ) -> Any:
17156
17379
  """
17157
17380
  Conditionally convert a string to a date or datetime object if it is in the correct format. If
17158
17381
  `allow_regular_strings=` is `True`, regular strings are allowed to pass through unchanged. If
@@ -17196,9 +17419,9 @@ def _process_brief(
17196
17419
  brief: str | None,
17197
17420
  step: int,
17198
17421
  col: str | list[str] | None,
17199
- values: any | None,
17200
- thresholds: any | None,
17201
- segment: any | None,
17422
+ values: Any | None,
17423
+ thresholds: Any | None,
17424
+ segment: Any | None,
17202
17425
  ) -> str:
17203
17426
  # If there is no brief, return `None`
17204
17427
  if brief is None:
@@ -17285,7 +17508,7 @@ def _process_action_str(
17285
17508
  action_str: str,
17286
17509
  step: int,
17287
17510
  col: str | None,
17288
- value: any,
17511
+ value: Any,
17289
17512
  type: str,
17290
17513
  level: str,
17291
17514
  time: str,
@@ -17337,8 +17560,8 @@ def _process_action_str(
17337
17560
  def _create_autobrief_or_failure_text(
17338
17561
  assertion_type: str,
17339
17562
  lang: str,
17340
- column: str | None,
17341
- values: str | None,
17563
+ column: str,
17564
+ values: Any,
17342
17565
  for_failure: bool,
17343
17566
  locale: str | None = None,
17344
17567
  n_rows: int | None = None,
@@ -17490,7 +17713,7 @@ def _create_autobrief_or_failure_text(
17490
17713
  for_failure=for_failure,
17491
17714
  )
17492
17715
 
17493
- return None # pragma: no cover
17716
+ return None
17494
17717
 
17495
17718
 
17496
17719
  def _expect_failure_type(for_failure: bool) -> str:
@@ -17500,7 +17723,7 @@ def _expect_failure_type(for_failure: bool) -> str:
17500
17723
  def _create_text_comparison(
17501
17724
  assertion_type: str,
17502
17725
  lang: str,
17503
- column: str | list[str] | None,
17726
+ column: str | list[str],
17504
17727
  values: str | None,
17505
17728
  for_failure: bool = False,
17506
17729
  ) -> str:
@@ -17526,7 +17749,7 @@ def _create_text_comparison(
17526
17749
 
17527
17750
  def _create_text_between(
17528
17751
  lang: str,
17529
- column: str | None,
17752
+ column: str,
17530
17753
  value_1: str,
17531
17754
  value_2: str,
17532
17755
  not_: bool = False,
@@ -17556,7 +17779,7 @@ def _create_text_between(
17556
17779
 
17557
17780
 
17558
17781
  def _create_text_set(
17559
- lang: str, column: str | None, values: list[any], not_: bool = False, for_failure: bool = False
17782
+ lang: str, column: str, values: list[Any], not_: bool = False, for_failure: bool = False
17560
17783
  ) -> str:
17561
17784
  type_ = _expect_failure_type(for_failure=for_failure)
17562
17785
 
@@ -17578,9 +17801,7 @@ def _create_text_set(
17578
17801
  return text
17579
17802
 
17580
17803
 
17581
- def _create_text_null(
17582
- lang: str, column: str | None, not_: bool = False, for_failure: bool = False
17583
- ) -> str:
17804
+ def _create_text_null(lang: str, column: str, not_: bool = False, for_failure: bool = False) -> str:
17584
17805
  type_ = _expect_failure_type(for_failure=for_failure)
17585
17806
 
17586
17807
  column_text = _prep_column_text(column=column)
@@ -17597,9 +17818,7 @@ def _create_text_null(
17597
17818
  return text
17598
17819
 
17599
17820
 
17600
- def _create_text_regex(
17601
- lang: str, column: str | None, pattern: str | dict, for_failure: bool = False
17602
- ) -> str:
17821
+ def _create_text_regex(lang: str, column: str, pattern: str, for_failure: bool = False) -> str:
17603
17822
  type_ = _expect_failure_type(for_failure=for_failure)
17604
17823
 
17605
17824
  column_text = _prep_column_text(column=column)
@@ -17631,7 +17850,7 @@ def _create_text_expr(lang: str, for_failure: bool) -> str:
17631
17850
  return EXPECT_FAIL_TEXT[f"col_vals_expr_{type_}_text"][lang]
17632
17851
 
17633
17852
 
17634
- def _create_text_col_exists(lang: str, column: str | None, for_failure: bool = False) -> str:
17853
+ def _create_text_col_exists(lang: str, column: str, for_failure: bool = False) -> str:
17635
17854
  type_ = _expect_failure_type(for_failure=for_failure)
17636
17855
 
17637
17856
  column_text = _prep_column_text(column=column)
@@ -17681,7 +17900,7 @@ def _create_text_rows_complete(
17681
17900
  return text
17682
17901
 
17683
17902
 
17684
- def _create_text_row_count_match(lang: str, value: int, for_failure: bool = False) -> str:
17903
+ def _create_text_row_count_match(lang: str, value: dict, for_failure: bool = False) -> str:
17685
17904
  type_ = _expect_failure_type(for_failure=for_failure)
17686
17905
 
17687
17906
  values_text = _prep_values_text(value["count"], lang=lang)
@@ -17689,7 +17908,7 @@ def _create_text_row_count_match(lang: str, value: int, for_failure: bool = Fals
17689
17908
  return EXPECT_FAIL_TEXT[f"row_count_match_n_{type_}_text"][lang].format(values_text=values_text)
17690
17909
 
17691
17910
 
17692
- def _create_text_col_count_match(lang: str, value: int, for_failure: bool = False) -> str:
17911
+ def _create_text_col_count_match(lang: str, value: dict, for_failure: bool = False) -> str:
17693
17912
  type_ = _expect_failure_type(for_failure=for_failure)
17694
17913
 
17695
17914
  values_text = _prep_values_text(value["count"], lang=lang)
@@ -17826,19 +18045,13 @@ def _create_text_prompt(lang: str, prompt: str, for_failure: bool = False) -> st
  def _prep_column_text(column: str | list[str]) -> str:
  if isinstance(column, list):
  return "`" + str(column[0]) + "`"
- elif isinstance(column, str):
+ if isinstance(column, str):
  return "`" + column + "`"
- else:
- return ""
+ raise AssertionError


  def _prep_values_text(
- values: str
- | int
- | float
- | datetime.datetime
- | datetime.date
- | list[str | int | float | datetime.datetime | datetime.date],
+ values: _CompliantValue | _CompliantValues,
  lang: str,
  limit: int = 3,
  ) -> str:
@@ -17886,7 +18099,7 @@ def _prep_values_text(
  return values_str


- def _seg_expr_from_string(data_tbl: any, segments_expr: str) -> list[tuple[str, str]]:
+ def _seg_expr_from_string(data_tbl: Any, segments_expr: str) -> tuple[str, str]:
  """
  Obtain the segmentation categories from a table column.

@@ -17989,7 +18202,7 @@ def _seg_expr_from_tuple(segments_expr: tuple) -> list[tuple[str, Any]]:
  return seg_tuples


- def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
+ def _apply_segments(data_tbl: Any, segments_expr: tuple[str, str]) -> Any:
  """
  Apply the segments expression to the data table.

@@ -18053,8 +18266,26 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
  except ValueError: # pragma: no cover
  pass # pragma: no cover

- # Format 2: Datetime strings with UTC timezone like
- # "2016-01-04 00:00:01 UTC.strict_cast(...)"
+ # Format 2: Direct datetime strings like "2016-01-04 00:00:01" (Polars 1.36+)
+ # These don't have UTC suffix anymore
+ elif (
+ " " in segment_str
+ and "UTC" not in segment_str
+ and "[" not in segment_str
+ and ".alias" not in segment_str
+ ):
+ try:
+ parsed_dt = datetime.fromisoformat(segment_str)
+ # Convert midnight datetimes to dates for consistency
+ if parsed_dt.time() == datetime.min.time():
+ parsed_value = parsed_dt.date() # pragma: no cover
+ else:
+ parsed_value = parsed_dt
+ except ValueError: # pragma: no cover
+ pass # pragma: no cover
+
+ # Format 3: Datetime strings with UTC timezone like
+ # "2016-01-04 00:00:01 UTC.strict_cast(...)" (Polars < 1.36)
  elif " UTC" in segment_str:
  try:
  # Extract just the datetime part before "UTC"
@@ -18069,7 +18300,7 @@ def _apply_segments(data_tbl: any, segments_expr: tuple[str, Any]) -> any:
  except (ValueError, IndexError): # pragma: no cover
  pass # pragma: no cover

- # Format 3: Bracketed expressions like ['2016-01-04']
+ # Format 4: Bracketed expressions like ['2016-01-04']
  elif segment_str.startswith("[") and segment_str.endswith("]"):
  try: # pragma: no cover
  # Remove [' and ']
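The new "Format 2" branch above leans on `datetime.fromisoformat()`, which accepts a space-separated date/time string, and then normalizes midnight timestamps down to dates. A minimal sketch of that behavior outside the surrounding parser:

```python
from datetime import datetime

# Polars 1.36+ renders segment values as plain ISO-like strings with no
# "UTC" suffix, which fromisoformat() parses directly
for raw in ["2016-01-04 00:00:00", "2016-01-04 00:00:01"]:
    parsed_dt = datetime.fromisoformat(raw)
    # Midnight datetimes collapse to plain dates so segment labels stay consistent
    parsed_value = parsed_dt.date() if parsed_dt.time() == datetime.min.time() else parsed_dt
    print(parsed_value)  # 2016-01-04, then 2016-01-04 00:00:01
```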
@@ -18204,8 +18435,7 @@ def _validation_info_as_dict(validation_info: _ValidationInfo) -> dict:

  def _get_assertion_icon(icon: list[str], length_val: int = 30) -> list[str]:
  # For each icon, get the assertion icon SVG text from SVG_ICONS_FOR_ASSERTION_TYPES dictionary
- # TODO: No point in using `get` if we can't handle missing keys anyways
- icon_svg = [SVG_ICONS_FOR_ASSERTION_TYPES.get(icon) for icon in icon]
+ icon_svg: list[str] = [SVG_ICONS_FOR_ASSERTION_TYPES[icon] for icon in icon]

  # Replace the width and height in the SVG string
  for i in range(len(icon_svg)):
@@ -18214,11 +18444,9 @@ def _get_assertion_icon(icon: list[str], length_val: int = 30) -> list[str]:
  return icon_svg


- def _replace_svg_dimensions(svg: list[str], height_width: int | float) -> list[str]:
+ def _replace_svg_dimensions(svg: str, height_width: int | float) -> str:
  svg = re.sub(r'width="[0-9]*?px', f'width="{height_width}px', svg)
- svg = re.sub(r'height="[0-9]*?px', f'height="{height_width}px', svg)
-
- return svg
+ return re.sub(r'height="[0-9]*?px', f'height="{height_width}px', svg)


  def _get_title_text(
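A quick sketch of the substitution `_replace_svg_dimensions()` performs: both regexes target only the numeric portion of the px-valued attributes, so whatever dimensions the icon shipped with are rewritten in place.

```python
import re

# The same width/height rewrite, applied to a toy SVG string
svg = '<svg width="24px" height="24px">...</svg>'
height_width = 30
svg = re.sub(r'width="[0-9]*?px', f'width="{height_width}px', svg)
svg = re.sub(r'height="[0-9]*?px', f'height="{height_width}px', svg)
print(svg)  # <svg width="30px" height="30px">...</svg>
```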
@@ -18282,7 +18510,7 @@ def _process_title_text(title: str | None, tbl_name: str | None, lang: str) -> s
  return title_text


- def _transform_tbl_preprocessed(pre: any, seg: any, interrogation_performed: bool) -> list[str]:
+ def _transform_tbl_preprocessed(pre: Any, seg: Any, interrogation_performed: bool) -> list[str]:
  # If no interrogation was performed, return a list of empty strings
  if not interrogation_performed:
  return ["" for _ in range(len(pre))]
@@ -18304,9 +18532,7 @@ def _transform_tbl_preprocessed(pre: any, seg: any, interrogation_performed: boo

  def _get_preprocessed_table_icon(icon: list[str]) -> list[str]:
  # For each icon, get the SVG icon from the SVG_ICONS_FOR_TBL_STATUS dictionary
- icon_svg = [SVG_ICONS_FOR_TBL_STATUS.get(icon) for icon in icon]
-
- return icon_svg
+ return [SVG_ICONS_FOR_TBL_STATUS[icon] for icon in icon]


  def _transform_eval(
@@ -18384,9 +18610,9 @@ def _transform_test_units(
  return _format_single_number_with_gt(
  value, n_sigfig=3, compact=True, locale=locale, df_lib=df_lib
  )
- else:
- # Fallback to the original behavior
- return str(vals.fmt_number(value, n_sigfig=3, compact=True, locale=locale)[0])
+ formatted = vals.fmt_number(value, n_sigfig=3, compact=True, locale=locale)
+ assert isinstance(formatted, list)
+ return formatted[0]

  return [
  (
@@ -18590,22 +18816,21 @@ def _transform_assertion_str(
  return type_upd


- def _pre_processing_funcs_to_str(pre: Callable) -> str | list[str]:
+ def _pre_processing_funcs_to_str(pre: Callable) -> str | list[str] | None:
  if isinstance(pre, Callable):
  return _get_callable_source(fn=pre)
+ return None


  def _get_callable_source(fn: Callable) -> str:
- if isinstance(fn, Callable):
- try:
- source_lines, _ = inspect.getsourcelines(fn)
- source = "".join(source_lines).strip()
- # Extract the `pre` argument from the source code
- pre_arg = _extract_pre_argument(source)
- return pre_arg
- except (OSError, TypeError): # pragma: no cover
- return fn.__name__
- return fn # pragma: no cover
+ try:
+ source_lines, _ = inspect.getsourcelines(fn)
+ source = "".join(source_lines).strip()
+ # Extract the `pre` argument from the source code
+ pre_arg = _extract_pre_argument(source)
+ return pre_arg
+ except (OSError, TypeError): # pragma: no cover
+ return fn.__name__ # ty: ignore


  def _extract_pre_argument(source: str) -> str:
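The simplified `_get_callable_source()` above depends on `inspect.getsourcelines()`, which raises `OSError` or `TypeError` when source is unavailable (built-ins, some REPL sessions), hence the fallback to `fn.__name__`. A small sketch of that recovery path:

```python
import inspect

def _example_pre(df):
    # A stand-in for a `pre=` callable whose source we want to display
    return df

try:
    source_lines, _ = inspect.getsourcelines(_example_pre)
    print("".join(source_lines).strip())  # the function's own source text
except (OSError, TypeError):
    print(_example_pre.__name__)  # fallback when no source is available
```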
@@ -18631,6 +18856,7 @@ def _create_table_time_html(
  if time_start is None:
  return ""

+ assert time_end is not None # typing
  # Get the time duration (difference between `time_end` and `time_start`) in seconds
  time_duration = (time_end - time_start).total_seconds()

@@ -18845,11 +19071,11 @@ def _format_number_safe(
  locale=locale,
  df_lib=df_lib,
  )
- else:
- # Fallback to the original behavior
- return fmt_number(
- value, decimals=decimals, drop_trailing_zeros=drop_trailing_zeros, locale=locale
- )[0] # pragma: no cover
+ ints = fmt_number(
+ value, decimals=decimals, drop_trailing_zeros=drop_trailing_zeros, locale=locale
+ )
+ assert isinstance(ints, list)
+ return ints[0]


  def _format_integer_safe(value: int, locale: str = "en", df_lib=None) -> str:
@@ -18862,9 +19088,10 @@ def _format_integer_safe(value: int, locale: str = "en", df_lib=None) -> str:
  if df_lib is not None and value is not None:
  # Use GT-based formatting to avoid Pandas dependency completely
  return _format_single_integer_with_gt(value, locale=locale, df_lib=df_lib)
- else:
- # Fallback to the original behavior
- return fmt_integer(value, locale=locale)[0]
+
+ ints = fmt_integer(value, locale=locale)
+ assert isinstance(ints, list)
+ return ints[0]


  def _create_thresholds_html(thresholds: Thresholds, locale: str, df_lib=None) -> str:
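Both fallback branches rely on the fact that `great_tables`' value formatters return a list of formatted strings even for a single scalar, which is what the new asserts document for the type checker. A minimal sketch (the printed outputs are indicative, not verified against a specific locale):

```python
from great_tables.vals import fmt_integer, fmt_number

# Value formatters return a list even for one input, hence the [0] indexing
out = fmt_number(1234.567, n_sigfig=3, compact=True)
assert isinstance(out, list)
print(out[0])  # e.g. "1.23K"

print(fmt_integer(1234567)[0])  # e.g. "1,234,567"
```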
@@ -18980,7 +19207,7 @@ def _create_local_threshold_note_html(thresholds: Thresholds, locale: str = "en"
  HTML string containing the formatted threshold information.
  """
  if thresholds == Thresholds():
- return ""
+ return "" # pragma: no cover

  # Get df_lib for formatting
  df_lib = None
@@ -18988,10 +19215,10 @@ def _create_local_threshold_note_html(thresholds: Thresholds, locale: str = "en"
  import polars as pl

  df_lib = pl
- elif _is_lib_present("pandas"):
- import pandas as pd
+ elif _is_lib_present("pandas"): # pragma: no cover
+ import pandas as pd # pragma: no cover

- df_lib = pd
+ df_lib = pd # pragma: no cover

  # Helper function to format threshold values using the shared formatting functions
  def _format_threshold_value(fraction: float | None, count: int | None) -> str:
@@ -18999,10 +19226,12 @@ def _create_local_threshold_note_html(thresholds: Thresholds, locale: str = "en"
  # Format as fraction/percentage with locale formatting
  if fraction == 0:
  return "0"
- elif fraction < 0.01:
+ elif fraction < 0.01: # pragma: no cover
  # For very small fractions, show "<0.01" with locale formatting
- formatted = _format_number_safe(0.01, decimals=2, locale=locale, df_lib=df_lib)
- return f"&lt;{formatted}"
+ formatted = _format_number_safe(
+ 0.01, decimals=2, locale=locale, df_lib=df_lib
+ ) # pragma: no cover
+ return f"&lt;{formatted}" # pragma: no cover
  else:
  # Use shared formatting function with drop_trailing_zeros
  formatted = _format_number_safe(
@@ -19079,14 +19308,14 @@ def _create_local_threshold_note_text(thresholds: Thresholds) -> str:
  if fraction is not None:
  if fraction == 0:
  return "0"
- elif fraction < 0.01:
- return "<0.01"
+ elif fraction < 0.01: # pragma: no cover
+ return "<0.01" # pragma: no cover
  else:
  return f"{fraction:.2f}".rstrip("0").rstrip(".")
  elif count is not None:
  return str(count)
  else:
- return "—"
+ return "—" # pragma: no cover

  parts = []

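Worked examples of the fraction formatting used above: format to two decimals, then strip trailing zeros and any dangling decimal point.

```python
# 0.25 -> "0.25", 0.20 -> "0.2", 1.00 -> "1"
for fraction in [0.25, 0.20, 1.00]:
    print(f"{fraction:.2f}".rstrip("0").rstrip("."))
```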
@@ -19105,7 +19334,7 @@ def _create_local_threshold_note_text(thresholds: Thresholds) -> str:
  if parts:
  return "Step-specific thresholds set: " + ", ".join(parts)
  else:
- return ""
+ return "" # pragma: no cover


  def _create_threshold_reset_note_html(locale: str = "en") -> str:
@@ -19654,13 +19883,13 @@ def _create_col_schema_match_note_html(schema_info: dict, locale: str = "en") ->
  f'<span style="color:#FF3300;">✗</span> {failed_text}: ' + ", ".join(failures) + "."
  )
  else:
- summary = f'<span style="color:#FF3300;">✗</span> {failed_text}.'
+ summary = f'<span style="color:#FF3300;">✗</span> {failed_text}.' # pragma: no cover

  # Generate the step report table using the existing function
  # We'll call either _step_report_schema_in_order or _step_report_schema_any_order
  # depending on the in_order parameter
- if in_order:
- step_report_gt = _step_report_schema_in_order(
+ if in_order: # pragma: no cover
+ step_report_gt = _step_report_schema_in_order( # pragma: no cover
  step=1, schema_info=schema_info, header=None, lang=locale, debug_return_df=False
  )
  else:
@@ -19691,7 +19920,7 @@ def _create_col_schema_match_note_html(schema_info: dict, locale: str = "en") ->
  """

  # Add the settings as an additional source note to the step report
- step_report_gt = step_report_gt.tab_source_note(source_note=html(source_note_html))
+ step_report_gt = step_report_gt.tab_source_note(source_note=html(source_note_html)) # type: ignore[union-attr]

  # Extract the HTML from the GT object
  step_report_html = step_report_gt._repr_html_()
@@ -19743,12 +19972,12 @@ def _step_report_row_based(
  column: str,
  column_position: int,
  columns_subset: list[str] | None,
- values: any,
+ values: Any,
  inclusive: tuple[bool, bool] | None,
  n: int,
  n_failed: int,
  all_passed: bool,
- extract: any,
+ extract: Any,
  tbl_preview: GT,
  header: str,
  limit: int | None,
@@ -19775,10 +20004,12 @@ def _step_report_row_based(
  elif assertion_type == "col_vals_le":
  text = f"{column} &le; {values}"
  elif assertion_type == "col_vals_between":
+ assert inclusive is not None
  symbol_left = "&le;" if inclusive[0] else "&lt;"
  symbol_right = "&le;" if inclusive[1] else "&lt;"
  text = f"{values[0]} {symbol_left} {column} {symbol_right} {values[1]}"
  elif assertion_type == "col_vals_outside":
+ assert inclusive is not None
  symbol_left = "&lt;" if inclusive[0] else "&le;"
  symbol_right = "&gt;" if inclusive[1] else "&ge;"
  text = f"{column} {symbol_left} {values[0]}, {column} {symbol_right} {values[1]}"
@@ -19999,7 +20230,7 @@ def _step_report_rows_distinct(
  n: int,
  n_failed: int,
  all_passed: bool,
- extract: any,
+ extract: Any,
  tbl_preview: GT,
  header: str,
  limit: int | None,
@@ -20126,8 +20357,8 @@ def _step_report_rows_distinct(


  def _step_report_schema_in_order(
- step: int, schema_info: dict, header: str, lang: str, debug_return_df: bool = False
- ) -> GT | any:
+ step: int, schema_info: dict, header: str | None, lang: str, debug_return_df: bool = False
+ ) -> GT | Any:
  """
  This is the case for schema validation where the schema is supposed to have the same column
  order as the target table.
@@ -20195,22 +20426,22 @@ def _step_report_schema_in_order(

  # Check if this column exists in exp_columns_dict (it might not if it's a duplicate)
  # For duplicates, we need to handle them specially
- if column_name_exp_i not in exp_columns_dict:
+ if column_name_exp_i not in exp_columns_dict: # pragma: no cover
  # This is a duplicate or invalid column, mark it as incorrect
- col_exp_correct.append(CROSS_MARK_SPAN)
+ col_exp_correct.append(CROSS_MARK_SPAN) # pragma: no cover

  # For dtype, check if there's a dtype specified in the schema
- if len(expect_schema[i]) > 1:
- dtype_value = expect_schema[i][1]
- if isinstance(dtype_value, list):
- dtype_exp.append(" | ".join(dtype_value))
- else:
- dtype_exp.append(str(dtype_value))
- else:
- dtype_exp.append("&mdash;")
+ if len(expect_schema[i]) > 1: # pragma: no cover
+ dtype_value = expect_schema[i][1] # pragma: no cover
+ if isinstance(dtype_value, list): # pragma: no cover
+ dtype_exp.append(" | ".join(dtype_value)) # pragma: no cover
+ else: # pragma: no cover
+ dtype_exp.append(str(dtype_value)) # pragma: no cover
+ else: # pragma: no cover
+ dtype_exp.append("&mdash;") # pragma: no cover

- dtype_exp_correct.append("&mdash;")
- continue
+ dtype_exp_correct.append("&mdash;") # pragma: no cover
+ continue # pragma: no cover

  #
  # `col_exp_correct` values
@@ -20433,7 +20664,9 @@ def _step_report_schema_in_order(
  # Add a border below the row that terminates the target table schema
  step_report = step_report.tab_style(
  style=style.borders(sides="bottom", color="#6699CC80", style="solid", weight="1px"),
- locations=loc.body(rows=len(colnames_tgt) - 1),
+ locations=loc.body(
+ rows=len(colnames_tgt) - 1 # ty: ignore (bug in GT, should allow an int)
+ ),
  )

  # If the version of `great_tables` is `>=0.17.0` then disable Quarto table processing
@@ -20482,8 +20715,8 @@ def _step_report_schema_in_order(


  def _step_report_schema_any_order(
- step: int, schema_info: dict, header: str, lang: str, debug_return_df: bool = False
- ) -> GT | any:
+ step: int, schema_info: dict, header: str | None, lang: str, debug_return_df: bool = False
+ ) -> GT | pl.DataFrame:
  """
  This is the case for schema validation where the schema is permitted to not have to be in the
  same column order as the target table.
@@ -20902,9 +21135,7 @@ def _step_report_schema_any_order(
  header = header.format(title=title, details=details)

  # Create the header with `header` string
- step_report = step_report.tab_header(title=md(header))
-
- return step_report
+ return step_report.tab_header(title=md(header))


  def _create_label_text_html(
@@ -20993,3 +21224,321 @@ def _create_col_schema_match_params_html(
  f"{full_match_dtypes_text}"
  "</div>"
  )
+
+
+ def _generate_agg_docstring(name: str) -> str:
+ """Generate a comprehensive docstring for an aggregation validation method.
+
+ This function creates detailed documentation for dynamically generated methods like
+ `col_sum_eq()`, `col_avg_gt()`, `col_sd_le()`, etc. The docstrings follow the same
+ structure and quality as manually written validation methods like `col_vals_gt()`.
+
+ Parameters
+ ----------
+ name
+ The method name (e.g., "col_sum_eq", "col_avg_gt", "col_sd_le").
+
+ Returns
+ -------
+ str
+ A complete docstring for the method.
+ """
+ # Parse the method name to extract aggregation type and comparison operator
+ # Format: col_{agg}_{comp} (e.g., col_sum_eq, col_avg_gt, col_sd_le)
+ parts = name.split("_")
+ agg_type = parts[1] # sum, avg, sd
+ comp_type = parts[2] # eq, gt, ge, lt, le
+
+ # Human-readable names for aggregation types
+ agg_names = {
+ "sum": ("sum", "summed"),
+ "avg": ("average", "averaged"),
+ "sd": ("standard deviation", "computed for standard deviation"),
+ }
+
+ # Human-readable descriptions for comparison operators (with article for title)
+ comp_descriptions = {
+ "eq": ("equal to", "equals", "an"),
+ "gt": ("greater than", "is greater than", "a"),
+ "ge": ("greater than or equal to", "is at least", "a"),
+ "lt": ("less than", "is less than", "a"),
+ "le": ("less than or equal to", "is at most", "a"),
+ }
+
+ # Mathematical symbols for comparison operators
+ comp_symbols = {
+ "eq": "==",
+ "gt": ">",
+ "ge": ">=",
+ "lt": "<",
+ "le": "<=",
+ }
+
+ agg_name, agg_verb = agg_names[agg_type]
+ comp_desc, comp_phrase, comp_article = comp_descriptions[comp_type]
+ comp_symbol = comp_symbols[comp_type]
+
+ # Determine the appropriate example values based on the aggregation and comparison
+ if agg_type == "sum":
+ example_value = "15"
+ example_data = '{"a": [1, 2, 3, 4, 5], "b": [2, 2, 2, 2, 2]}'
+ example_sum = "15" # sum of a
+ example_ref_sum = "10" # sum of b
+ elif agg_type == "avg":
+ example_value = "3"
+ example_data = '{"a": [1, 2, 3, 4, 5], "b": [2, 2, 2, 2, 2]}'
+ example_sum = "3.0" # avg of a
+ example_ref_sum = "2.0" # avg of b
+ else: # sd
+ example_value = "2"
+ example_data = '{"a": [1, 2, 3, 4, 5], "b": [2, 2, 2, 2, 2]}'
+ example_sum = "~1.58" # sd of a
+ example_ref_sum = "0.0" # sd of b
+
+ # Build appropriate tolerance explanation based on comparison type
+ if comp_type == "eq":
+ tol_explanation = f"""The `tol=` parameter is particularly useful with `{name}()` since exact equality
+ comparisons on floating-point aggregations can be problematic due to numerical precision.
+ Setting a small tolerance (e.g., `tol=0.001`) allows for minor differences that arise from
+ floating-point arithmetic."""
+ else:
+ tol_explanation = f"""The `tol=` parameter expands the acceptable range for the comparison. For
+ `{name}()`, a tolerance of `tol=0.5` would mean the {agg_name} can be within `0.5` of the
+ target value and still pass validation."""
+
+ docstring = f"""
+ Does the column {agg_name} satisfy {comp_article} {comp_desc} comparison?
+
+ The `{name}()` validation method checks whether the {agg_name} of values in a column
+ {comp_phrase} a specified `value=`. This is an aggregation-based validation where the entire
+ column is reduced to a single {agg_name} value that is then compared against the target. The
+ comparison used in this function is `{agg_name}(column) {comp_symbol} value`.
+
+ Unlike row-level validations (e.g., `col_vals_gt()`), this method treats the entire column as
+ a single test unit. The validation either passes completely (if the aggregated value satisfies
+ the comparison) or fails completely.
+
+ Parameters
+ ----------
+ columns
+ A single column or a list of columns to validate. If multiple columns are supplied,
+ there will be a separate validation step generated for each column. The columns must
+ contain numeric data for the {agg_name} to be computed.
+ value
+ The value to compare the column {agg_name} against. This can be: (1) a numeric literal
+ (`int` or `float`), (2) a [`col()`](`pointblank.col`) object referencing another column
+ whose {agg_name} will be used for comparison, (3) a [`ref()`](`pointblank.ref`) object
+ referencing a column in reference data (when `Validate(reference=)` has been set), or (4)
+ `None` to automatically compare against the same column in reference data (shorthand for
+ `ref(column_name)` when reference data is set).
+ tol
+ A tolerance value for the comparison. The default is `0`, meaning exact comparison. When
+ set to a positive value, the comparison becomes more lenient. For example, with `tol=0.5`,
+ a {agg_name} that differs from the target by up to `0.5` will still pass. {tol_explanation}
+ thresholds
+ Failure threshold levels so that the validation step can react accordingly when
+ failing test units reach the set levels. Since this is an aggregation-based validation
+ with only one test unit, threshold values should typically be set as absolute counts
+ (e.g., `1`) to indicate pass/fail, or as proportions where any value less than `1.0`
+ means failure is acceptable.
+ brief
+ An optional brief description of the validation step that will be displayed in the
+ reporting table. You can use the templating elements like `"{{step}}"` to insert
+ the step number, or `"{{auto}}"` to include an automatically generated brief. If `True`
+ the entire brief will be automatically generated. If `None` (the default) then there
+ won't be a brief.
+ actions
+ Optional actions to take when the validation step meets or exceeds any set threshold
+ levels. If provided, the [`Actions`](`pointblank.Actions`) class should be used to
+ define the actions.
+ active
+ A boolean value indicating whether the validation step should be active. Using `False`
+ will make the validation step inactive (still reporting its presence and keeping indexes
+ for the steps unchanged).
+
+ Returns
+ -------
+ Validate
+ The `Validate` object with the added validation step.
+
+ Using Reference Data
+ --------------------
+ The `{name}()` method supports comparing column aggregations against reference data. This
+ is useful for validating that statistical properties remain consistent across different
+ versions of a dataset, or for comparing current data against historical baselines.
+
+ To use reference data, set the `reference=` parameter when creating the `Validate` object:
+
+ ```python
+ validation = (
+ pb.Validate(data=current_data, reference=baseline_data)
+ .{name}(columns="revenue") # Compares sum(current.revenue) vs sum(baseline.revenue)
+ .interrogate()
+ )
+ ```
+
+ When `value=None` and reference data is set, the method automatically compares against the
+ same column in the reference data. You can also explicitly specify reference columns using
+ the `ref()` helper:
+
+ ```python
+ .{name}(columns="revenue", value=pb.ref("baseline_revenue"))
+ ```
+
+ Understanding Tolerance
+ -----------------------
+ The `tol=` parameter allows for fuzzy comparisons, which is especially important for
+ floating-point aggregations where exact equality is often unreliable.
+
+ {tol_explanation}
+
+ For equality comparisons (`col_*_eq`), the tolerance creates a range `[value - tol, value + tol]`
+ within which the aggregation is considered valid. For inequality comparisons, the tolerance
+ shifts the comparison boundary.
+
+ Thresholds
+ ----------
+ The `thresholds=` parameter is used to set the failure-condition levels for the validation
+ step. If they are set here at the step level, these thresholds will override any thresholds
+ set at the global level in `Validate(thresholds=...)`.
+
+ There are three threshold levels: 'warning', 'error', and 'critical'. Since aggregation
+ validations operate on a single test unit (the aggregated value), threshold values are
+ typically set as absolute counts:
+
+ - `thresholds=1` means any failure triggers a 'warning'
+ - `thresholds=(1, 1, 1)` means any failure triggers all three levels
+
+ Thresholds can be defined using one of these input schemes:
+
+ 1. use the [`Thresholds`](`pointblank.Thresholds`) class (the most direct way to create
+ thresholds)
+ 2. provide a tuple of 1-3 values, where position `0` is the 'warning' level, position `1` is
+ the 'error' level, and position `2` is the 'critical' level
+ 3. create a dictionary of 1-3 value entries; the valid keys are 'warning', 'error', and
+ 'critical'
+ 4. a single integer/float value denoting absolute number or fraction of failing test units
+ for the 'warning' level only
+
+ Examples
+ --------
+ ```{{python}}
+ #| echo: false
+ #| output: false
+ import pointblank as pb
+ pb.config(report_incl_header=False, report_incl_footer=False, preview_incl_header=False)
+ ```
+ For the examples, we'll use a simple Polars DataFrame with numeric columns. The table is
+ shown below:
+
+ ```{{python}}
+ import pointblank as pb
+ import polars as pl
+
+ tbl = pl.DataFrame(
+ {{
+ "a": [1, 2, 3, 4, 5],
+ "b": [2, 2, 2, 2, 2],
+ }}
+ )
+
+ pb.preview(tbl)
+ ```
+
+ Let's validate that the {agg_name} of column `a` {comp_phrase} `{example_value}`:
+
+ ```{{python}}
+ validation = (
+ pb.Validate(data=tbl)
+ .{name}(columns="a", value={example_value})
+ .interrogate()
+ )
+
+ validation
+ ```
+
+ The validation result shows whether the {agg_name} comparison passed or failed. Since this
+ is an aggregation-based validation, there is exactly one test unit per column.
+
+ When validating multiple columns, each column gets its own validation step:
+
+ ```{{python}}
+ validation = (
+ pb.Validate(data=tbl)
+ .{name}(columns=["a", "b"], value={example_value})
+ .interrogate()
+ )
+
+ validation
+ ```
+
+ Using tolerance for flexible comparisons:
+
+ ```{{python}}
+ validation = (
+ pb.Validate(data=tbl)
+ .{name}(columns="a", value={example_value}, tol=1.0)
+ .interrogate()
+ )
+
+ validation
+ ```
+ """
+
+ return docstring.strip()
+
+
+ def make_agg_validator(name: str):
+ """Factory for dynamically generated aggregate validation methods.
+
+ Why this exists:
+ Aggregate validators all share identical behavior. The only thing that differs
+ between them is the semantic assertion type (their name). The implementation
+ of each aggregate validator is fetched from `from_agg_validator`.
+
+ Instead of copy/pasting dozens of identical methods, we generate
+ them dynamically and attach them to the Validate class. The types are generated
+ at build time with `make pyi` to allow the methods to be visible to the type checker,
+ documentation builders, and IDEs/LSPs.
+
+ The returned function is a thin adapter that forwards all arguments to
+ `_add_agg_validation`, supplying the assertion type explicitly.
+ """
+
+ def agg_validator(
+ self: Validate,
+ columns: str | Collection[str],
+ value: float | int | Column | ReferenceColumn | None = None,
+ tol: float = 0,
+ thresholds: int | float | bool | tuple | dict | Thresholds | None = None,
+ brief: str | bool | None = None,
+ actions: Actions | None = None,
+ active: bool = True,
+ ) -> Validate:
+ # Dynamically generated aggregate validator.
+ # This method is generated per assertion type and forwards all arguments
+ # to the shared aggregate validation implementation.
+ return self._add_agg_validation(
+ assertion_type=name,
+ columns=columns,
+ value=value,
+ tol=tol,
+ thresholds=thresholds,
+ brief=brief,
+ actions=actions,
+ active=active,
+ )
+
+ # Manually set function identity so this behaves like a real method.
+ # These must be set before attaching the function to the class.
+ agg_validator.__name__ = name
+ agg_validator.__qualname__ = f"Validate.{name}"
+ agg_validator.__doc__ = _generate_agg_docstring(name)
+
+ return agg_validator
+
+
+ # Finally, we grab all the valid aggregation method names and attach them to
+ # the Validate class, registering each one appropriately.
+ for method in load_validation_method_grid(): # -> `col_sum_*`, `col_mean_*`, etc.
+ setattr(Validate, method, make_agg_validator(method))
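To close the loop, a hedged sketch of what the attachment above yields in practice; it mirrors the generated docstring's own examples rather than documenting any behavior beyond what this diff shows:

```python
import polars as pl
import pointblank as pb

tbl = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [2, 2, 2, 2, 2]})

# col_sum_eq() is one of the methods attached by the loop above; per the
# generated docstring, sum(a) == 15 is a single test unit that passes here
validation = (
    pb.Validate(data=tbl)
    .col_sum_eq(columns="a", value=15)
    .interrogate()
)
```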