dataeval 0.65.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +24 -22
  3. dataeval/_internal/detectors/drift/base.py +206 -26
  4. dataeval/_internal/detectors/drift/cvm.py +25 -23
  5. dataeval/_internal/detectors/drift/ks.py +28 -25
  6. dataeval/_internal/detectors/drift/mmd.py +30 -29
  7. dataeval/_internal/detectors/drift/torch.py +66 -58
  8. dataeval/_internal/detectors/drift/uncertainty.py +28 -28
  9. dataeval/_internal/detectors/duplicates.py +28 -18
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +61 -43
  13. dataeval/_internal/detectors/ood/llr.py +27 -24
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +33 -27
  17. dataeval/_internal/flags.py +5 -3
  18. dataeval/_internal/interop.py +4 -2
  19. dataeval/_internal/metrics/balance.py +33 -4
  20. dataeval/_internal/metrics/ber.py +6 -4
  21. dataeval/_internal/metrics/diversity.py +45 -12
  22. dataeval/_internal/metrics/parity.py +114 -26
  23. dataeval/_internal/metrics/stats.py +154 -16
  24. dataeval/_internal/metrics/uap.py +28 -2
  25. dataeval/_internal/metrics/utils.py +20 -18
  26. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  27. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  28. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  29. dataeval/_internal/models/tensorflow/losses.py +15 -11
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  31. dataeval/_internal/models/tensorflow/trainer.py +8 -6
  32. dataeval/_internal/models/tensorflow/utils.py +21 -19
  33. dataeval/_internal/output.py +13 -10
  34. dataeval/_internal/utils.py +5 -3
  35. dataeval/_internal/workflows/sufficiency.py +42 -30
  36. dataeval/detectors/__init__.py +6 -25
  37. dataeval/detectors/drift/__init__.py +16 -0
  38. dataeval/detectors/drift/kernels/__init__.py +6 -0
  39. dataeval/detectors/drift/updates/__init__.py +3 -0
  40. dataeval/detectors/linters/__init__.py +5 -0
  41. dataeval/detectors/ood/__init__.py +11 -0
  42. dataeval/metrics/__init__.py +2 -26
  43. dataeval/metrics/bias/__init__.py +14 -0
  44. dataeval/metrics/estimators/__init__.py +9 -0
  45. dataeval/metrics/stats/__init__.py +6 -0
  46. dataeval/tensorflow/__init__.py +3 -0
  47. dataeval/tensorflow/loss/__init__.py +3 -0
  48. dataeval/tensorflow/models/__init__.py +5 -0
  49. dataeval/tensorflow/recon/__init__.py +3 -0
  50. dataeval/torch/__init__.py +3 -0
  51. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  52. dataeval/torch/trainer/__init__.py +3 -0
  53. dataeval/utils/__init__.py +3 -6
  54. dataeval/workflows/__init__.py +2 -4
  55. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  56. dataeval-0.66.0.dist-info/RECORD +72 -0
  57. dataeval/models/__init__.py +0 -15
  58. dataeval/models/tensorflow/__init__.py +0 -6
  59. dataeval-0.65.0.dist-info/RECORD +0 -60
  60. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  61. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,18 +1,22 @@
-__version__ = "0.65.0"
+__version__ = "0.66.0"
 
 from importlib.util import find_spec
 
-from . import detectors, flags, metrics
+_IS_TORCH_AVAILABLE = find_spec("torch") is not None
+_IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("tensorflow_probability") is not None
+
+del find_spec
+
+from . import detectors, flags, metrics  # noqa: E402
 
 __all__ = ["detectors", "flags", "metrics"]
 
-if find_spec("torch") is not None:  # pragma: no cover
-    from . import models, utils, workflows
+if _IS_TORCH_AVAILABLE:  # pragma: no cover
+    from . import torch, utils, workflows
 
-    __all__ += ["models", "utils", "workflows"]
-elif find_spec("tensorflow") is not None:  # pragma: no cover
-    from . import models
+    __all__ += ["torch", "utils", "workflows"]
 
-    __all__ += ["models"]
+if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
+    from . import tensorflow
 
-del find_spec
+    __all__ += ["tensorflow"]
dataeval/_internal/detectors/clusterer.py CHANGED
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast
+from typing import Iterable, NamedTuple, cast
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -26,10 +28,10 @@ class ClustererOutput(OutputMetadata):
         Groups of indices which are not exact but closely related data points
     """
 
-    outliers: List[int]
-    potential_outliers: List[int]
-    duplicates: List[List[int]]
-    potential_duplicates: List[List[int]]
+    outliers: list[int]
+    potential_outliers: list[int]
+    duplicates: list[list[int]]
+    potential_duplicates: list[list[int]]
 
 
 def extend_linkage(link_arr: NDArray) -> NDArray:
@@ -59,7 +61,7 @@ def extend_linkage(link_arr: NDArray) -> NDArray:
 class Cluster:
     __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"
 
-    def __init__(self, merged: int, samples: NDArray, sample_dist: Union[float, NDArray], is_copy: bool = False):
+    def __init__(self, merged: int, samples: NDArray, sample_dist: float | NDArray, is_copy: bool = False):
         self.merged = merged
         self.samples = np.array(samples, dtype=np.int32)
         self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -81,7 +83,7 @@ class Cluster:
         self.out1 = dist > out1
         self.out2 = dist > out2
 
-    def copy(self) -> "Cluster":
+    def copy(self) -> Cluster:
         return Cluster(False, self.samples, self.sample_dist, True)
 
     def __repr__(self) -> str:
@@ -94,7 +96,7 @@
         return f"{self.__class__.__name__}(**{repr(_params)})"
 
 
-class Clusters(Dict[int, Dict[int, Cluster]]):
+class Clusters(dict[int, dict[int, Cluster]]):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.max_level: int = 1
@@ -116,10 +118,10 @@ class ClusterMergeEntry:
         self.inner_cluster = inner_cluster
         self.status = status
 
-    def __lt__(self, value: "ClusterMergeEntry") -> bool:
+    def __lt__(self, value: ClusterMergeEntry) -> bool:
         return self.level.__lt__(value.level)
 
-    def __gt__(self, value: "ClusterMergeEntry") -> bool:
+    def __gt__(self, value: ClusterMergeEntry) -> bool:
         return self.level.__gt__(value.level)
 
 
@@ -184,7 +186,7 @@ class Clusterer:
         return self._clusters
 
     @property
-    def last_good_merge_levels(self) -> Dict[int, int]:
+    def last_good_merge_levels(self) -> dict[int, int]:
         if self._last_good_merge_levels is None:
             self._last_good_merge_levels = self._get_last_merge_levels()
         return self._last_good_merge_levels
@@ -208,7 +210,7 @@
     def _create_clusters(self) -> Clusters:
         """Generates clusters based on linkage matrix"""
         next_cluster_id = 0
-        cluster_map: Dict[int, ClusterPosition] = {}  # Dictionary to associate new cluster ids with actual clusters
+        cluster_map: dict[int, ClusterPosition] = {}  # Dictionary to associate new cluster ids with actual clusters
         clusters: Clusters = Clusters()
 
         # Walking through the linkage array to generate clusters
@@ -236,7 +238,7 @@
                 # Update clusters to include previously skipped levels
                 clusters = self._fill_levels(clusters, left, right)
             elif left or right:
-                child, other_id = cast(Tuple[ClusterPosition, int], (left, right_id) if left else (right, left_id))
+                child, other_id = cast(tuple[ClusterPosition, int], (left, right_id) if left else (right, left_id))
                 cc = clusters[child.level][child.cid]
                 samples = np.concatenate([cc.samples, [other_id]])
                 sample_dist = np.concatenate([cc.sample_dist, sample_dist])
@@ -285,7 +287,7 @@
 
         return cluster_matrix
 
-    def _calc_merge_indices(self, merge_mean: List[NDArray], intra_max: List[float]) -> NDArray:
+    def _calc_merge_indices(self, merge_mean: list[NDArray], intra_max: list[float]) -> NDArray:
         """
         Determine what clusters should be merged and return their indices
         """
@@ -308,7 +310,7 @@
         mask2 = mask2_vals < one_std_check
         return np.logical_or(desired_merge, mask2)
 
-    def _generate_merge_list(self, cluster_matrix: NDArray) -> List[ClusterMergeEntry]:
+    def _generate_merge_list(self, cluster_matrix: NDArray) -> list[ClusterMergeEntry]:
         """
         Runs through the clusters dictionary determining when clusters merge,
         and how close are those clusters when they merge.
@@ -325,7 +327,7 @@
         """
         intra_max = []
         merge_mean = []
-        merge_list: List[ClusterMergeEntry] = []
+        merge_list: list[ClusterMergeEntry] = []
 
         for level, cluster_set in self.clusters.items():
             for outer_cluster, cluster in cluster_set.items():
@@ -363,7 +365,7 @@
 
         return merge_list
 
-    def _get_last_merge_levels(self) -> Dict[int, int]:
+    def _get_last_merge_levels(self) -> dict[int, int]:
         """
         Creates a dictionary for important cluster ids mapped to their last good merge level
 
@@ -372,7 +374,7 @@
         Dict[int, int]
             A mapping of a cluster id to its last good merge level
         """
-        last_merge_levels: Dict[int, int] = {}
+        last_merge_levels: dict[int, int] = {}
 
         if self._max_clusters <= 1:
             last_merge_levels = {0: int(self._num_samples * 0.1)}
@@ -395,7 +397,7 @@
 
         return last_merge_levels
 
-    def find_outliers(self, last_merge_levels: Dict[int, int]) -> Tuple[List[int], List[int]]:
+    def find_outliers(self, last_merge_levels: dict[int, int]) -> tuple[list[int], list[int]]:
         """
         Retrieves outliers based on when the sample was added to the cluster
         and how far it was from the cluster when it was added
@@ -439,9 +441,9 @@
 
         return sorted(outliers), sorted(possible_outliers)
 
-    def _sorted_union_find(self, index_groups: Iterable[Iterable[int]]) -> List[List[int]]:
+    def _sorted_union_find(self, index_groups: Iterable[Iterable[int]]) -> list[list[int]]:
         """Merges and sorts groups of indices that share any common index"""
-        groups: List[List[int]] = []
+        groups: list[list[int]] = []
         for indices in zip(*index_groups):
             indices = set(indices)
             temp = []
@@ -454,7 +456,7 @@
             groups = temp
         return sorted(groups)
 
-    def find_duplicates(self, last_merge_levels: Dict[int, int]) -> Tuple[List[List[int]], List[List[int]]]:
+    def find_duplicates(self, last_merge_levels: dict[int, int]) -> tuple[list[list[int]], list[list[int]]]:
         """
         Finds duplicate and near duplicate data based on the last good merge levels when building the cluster
 
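Every typing change in this file follows one pattern, enabled by the new `from __future__ import annotations` at the top: with postponed evaluation (PEP 563), builtin generics (PEP 585), `X | None` unions (PEP 604), and unquoted self-references are all legal in annotations even on Python versions whose runtime predates them. A minimal self-contained illustration:

from __future__ import annotations  # postpone evaluation of annotations


class Cluster:
    def copy(self) -> Cluster:  # self-reference no longer needs quoting
        return self


# builtin generics and | unions replace typing.Dict/List/Tuple/Optional/Union
def find_outliers(last_merge_levels: dict[int, int]) -> tuple[list[int], list[int]]:
    return [], []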
dataeval/_internal/detectors/drift/base.py CHANGED
@@ -6,10 +6,12 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, Literal, Optional, Tuple
+from typing import Callable, Literal
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -19,27 +21,40 @@ from dataeval._internal.output import OutputMetadata, set_metadata
 
 
 @dataclass(frozen=True)
-class DriftOutput(OutputMetadata):
+class DriftBaseOutput(OutputMetadata):
+    """
+    Output class for Drift
+
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    """
+
     is_drift: bool
     threshold: float
 
 
 @dataclass(frozen=True)
-class DriftUnivariateOutput(DriftOutput):
+class DriftOutput(DriftBaseOutput):
     """
+    Output class for DriftCVM and DriftKS
+
     Attributes
     ----------
     is_drift : bool
         Drift prediction for the images
     threshold : float
         Threshold after multivariate correction if needed
-    feature_drift : NDArray[np.bool_]
+    feature_drift : NDArray
         Feature-level array of images detected to have drifted
     feature_threshold : float
         Feature-level threshold to determine drift
-    p_vals : NDArray[np.float32]
+    p_vals : NDArray
         Feature-level p-values
-    distances : NDArray[np.float32]
+    distances : NDArray
         Feature-level distances
     """
 
@@ -83,6 +98,15 @@ def preprocess_x(fn):
 
 
 class UpdateStrategy(ABC):
+    """
+    Updates reference dataset for drift detector
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
     def __init__(self, n: int):
         self.n = n
 
@@ -113,7 +137,7 @@ class ReservoirSamplingUpdate(UpdateStrategy):
     Parameters
     ----------
     n : int
-        Update with reservoir sampling of size n.
+        Update with last n instances seen by the detector.
     """
 
     def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
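For reference, illustrative implementations of the two update strategies named in this module, written against the __call__ signature shown above (these are sketches; the package's actual implementations may differ in detail):

import numpy as np
from numpy.typing import NDArray


def last_seen_update(x_ref: NDArray, x: NDArray, count: int, n: int) -> NDArray:
    # keep only the n most recent instances seen by the detector
    return np.concatenate([x_ref, x], axis=0)[-n:]


def reservoir_sampling_update(x_ref: NDArray, x: NDArray, count: int, n: int) -> NDArray:
    # classic reservoir sampling: each instance seen so far ends up in the
    # size-n reference set with equal probability
    reservoir = x_ref.copy()
    for k, instance in enumerate(x):
        seen = count + k + 1  # total instances observed so far
        if reservoir.shape[0] < n:
            reservoir = np.concatenate([reservoir, instance[None]], axis=0)
        else:
            j = np.random.randint(0, seen)
            if j < n:
                reservoir[j] = instance
    return reservoir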
@@ -138,15 +162,64 @@ class ReservoirSamplingUpdate(UpdateStrategy):
 
 
 class BaseDrift:
-    """Generic drift detector component handling preprocessing of data and correction"""
+    """
+    A generic drift detection component for preprocessing data and applying statistical correction.
+
+    This class handles common tasks related to drift detection, such as preprocessing
+    the reference data (`x_ref`), performing statistical correction (e.g., Bonferroni, FDR),
+    and updating the reference data if needed.
+
+    Parameters
+    ----------
+    x_ref : ArrayLike
+        The reference dataset used for drift detection. This is the baseline data against
+        which new data points will be compared.
+    p_val : float, optional
+        The significance level for detecting drift, by default 0.05.
+    x_ref_preprocessed : bool, optional
+        Flag indicating whether the reference data has already been preprocessed, by default False.
+    update_x_ref : UpdateStrategy, optional
+        A strategy object specifying how the reference data should be updated when drift is detected,
+        by default None.
+    preprocess_fn : Callable[[ArrayLike], ArrayLike], optional
+        A function to preprocess the data before drift detection, by default None.
+    correction : {'bonferroni', 'fdr'}, optional
+        Statistical correction method applied to p-values, by default "bonferroni".
+
+    Attributes
+    ----------
+    _x_ref : ArrayLike
+        The reference dataset that is either raw or preprocessed.
+    p_val : float
+        The significance level for drift detection.
+    update_x_ref : UpdateStrategy or None
+        The strategy for updating the reference data if applicable.
+    preprocess_fn : Callable or None
+        Function used for preprocessing input data before drift detection.
+    correction : str
+        Statistical correction method applied to p-values.
+    n : int
+        The number of samples in the reference dataset (`x_ref`).
+    x_ref_preprocessed : bool
+        A flag that indicates whether the reference dataset has been preprocessed.
+    _x_refcount : int
+        Counter for how many times the reference data has been accessed after preprocessing.
+
+    Methods
+    -------
+    x_ref:
+        Property that returns the reference dataset, and applies preprocessing if not already done.
+    _preprocess(x):
+        Preprocesses the given data using the specified `preprocess_fn` if provided.
+    """
 
     def __init__(
         self,
         x_ref: ArrayLike,
         p_val: float = 0.05,
         x_ref_preprocessed: bool = False,
-        update_x_ref: Optional[UpdateStrategy] = None,
-        preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
+        update_x_ref: UpdateStrategy | None = None,
+        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
         correction: Literal["bonferroni", "fdr"] = "bonferroni",
     ) -> None:
         # Type checking
@@ -172,6 +245,14 @@ class BaseDrift:
 
     @property
     def x_ref(self) -> NDArray:
+        """
+        Retrieve the reference data, applying preprocessing if not already done.
+
+        Returns
+        -------
+        NDArray
+            The reference dataset (`x_ref`), preprocessed if needed.
+        """
         if not self.x_ref_preprocessed:
             self.x_ref_preprocessed = True
             if self.preprocess_fn is not None:
@@ -181,7 +262,19 @@ class BaseDrift:
         return self._x_ref
 
     def _preprocess(self, x: ArrayLike) -> ArrayLike:
-        """Data preprocessing before computing the drift scores."""
+        """
+        Preprocess the given data before computing the drift scores.
+
+        Parameters
+        ----------
+        x : ArrayLike
+            The input data to preprocess.
+
+        Returns
+        -------
+        ArrayLike
+            The preprocessed input data.
+        """
         if self.preprocess_fn is not None:
             x = self.preprocess_fn(x)
         return x
@@ -189,10 +282,55 @@
 
 
 class BaseDriftUnivariate(BaseDrift):
     """
-    Generic drift detector component which serves as a base class for methods using
-    univariate tests. If n_features > 1, a multivariate correction is applied such
-    that the false positive rate is upper bounded by the specified p-value, with
-    equality in the case of independent features.
+    Base class for drift detection methods using univariate statistical tests.
+
+    This class inherits from `BaseDrift` and serves as a generic component for detecting
+    distribution drift in univariate features. If the number of features `n_features` is greater
+    than 1, a multivariate correction method (e.g., Bonferroni or FDR) is applied to control
+    the false positive rate, ensuring it does not exceed the specified p-value.
+
+    Parameters
+    ----------
+    x_ref : ArrayLike
+        Reference data used as the baseline to compare against when detecting drift.
+    p_val : float, default 0.05
+        Significance level used for detecting drift.
+    x_ref_preprocessed : bool, default False
+        Indicates whether the reference data has been preprocessed.
+    update_x_ref : UpdateStrategy | None, default None
+        Strategy for updating the reference data when drift is detected.
+    preprocess_fn : Callable[ArrayLike] | None, default None
+        Function used to preprocess input data before detecting drift.
+    correction : 'bonferroni' | 'fdr', default 'bonferroni'
+        Multivariate correction method applied to p-values.
+    n_features : int | None, default None
+        Number of features used in the univariate drift tests. If not provided, it will
+        be inferred from the data.
+
+    Attributes
+    ----------
+    _n_features : int | None
+        Number of features in the data. If not provided, it is lazily inferred from the
+        input data and any preprocessing function.
+    p_val : float
+        The significance level for drift detection.
+    correction : str
+        The method for controlling the false discovery rate or applying a Bonferroni correction.
+    update_x_ref : UpdateStrategy | None
+        Strategy for updating the reference data if applicable.
+    preprocess_fn : Callable | None
+        Function used for preprocessing input data before drift detection.
+
+    Methods
+    -------
+    n_features:
+        Property that returns the number of features, inferring it if necessary.
+    score(x):
+        Abstract method to compute univariate feature scores after preprocessing.
+    _apply_correction(p_vals):
+        Apply a statistical correction to p-values to account for multiple testing.
+    predict(x):
+        Predict whether drift has occurred on a batch of data, applying multivariate correction if needed.
     """
 
     def __init__(
@@ -200,10 +338,10 @@ class BaseDriftUnivariate(BaseDrift):
         x_ref: ArrayLike,
         p_val: float = 0.05,
         x_ref_preprocessed: bool = False,
-        update_x_ref: Optional[UpdateStrategy] = None,
-        preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
+        update_x_ref: UpdateStrategy | None = None,
+        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
         correction: Literal["bonferroni", "fdr"] = "bonferroni",
-        n_features: Optional[int] = None,
+        n_features: int | None = None,
     ) -> None:
         super().__init__(
             x_ref,
@@ -218,6 +356,18 @@
 
     @property
     def n_features(self) -> int:
+        """
+        Get the number of features in the reference data.
+
+        If the number of features is not provided during initialization, it will be inferred
+        from the reference data (``x_ref``). If a preprocessing function is provided, the number
+        of features will be inferred after applying the preprocessing function.
+
+        Returns
+        -------
+        int
+            Number of features in the reference data.
+        """
         # lazy process n_features as needed
         if not isinstance(self._n_features, int):
             # compute number of features for the univariate tests
@@ -233,10 +383,40 @@
 
     @preprocess_x
     @abstractmethod
-    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
-        """Abstract method to calculate feature score after preprocessing"""
+    def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
+        """
+        Abstract method to calculate feature scores after preprocessing.
+
+        Parameters
+        ----------
+        x : ArrayLike
+            The batch of data to calculate univariate drift scores for each feature.
+
+        Returns
+        -------
+        tuple[NDArray, NDArray]
+            A tuple containing p-values and distance statistics for each feature.
+        """
 
-    def _apply_correction(self, p_vals: NDArray) -> Tuple[bool, float]:
+    def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
+        """
+        Apply the specified correction method (Bonferroni or FDR) to the p-values.
+
+        If the correction method is Bonferroni, the threshold for detecting drift
+        is divided by the number of features. For FDR, the correction is applied
+        using the Benjamini-Hochberg procedure.
+
+        Parameters
+        ----------
+        p_vals : NDArray
+            Array of p-values from the univariate tests for each feature.
+
+        Returns
+        -------
+        tuple[bool, float]
+            A tuple containing a boolean indicating if drift was detected and the
+            threshold after correction.
+        """
         if self.correction == "bonferroni":
             threshold = self.p_val / self.n_features
             drift_pred = bool((p_vals < threshold).any())
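A worked numeric sketch of the two corrections the docstring describes (illustrative; the Bonferroni branch mirrors the code shown, the FDR branch is a standard Benjamini-Hochberg check):

import numpy as np

p_vals = np.array([0.001, 0.020, 0.040, 0.300])  # per-feature p-values
alpha = 0.05
n = p_vals.size

# Bonferroni: compare every p-value against alpha / n_features
bonferroni_threshold = alpha / n  # 0.0125
drift_bonferroni = bool((p_vals < bonferroni_threshold).any())  # True: 0.001 < 0.0125

# Benjamini-Hochberg (FDR): compare sorted p-values against i * alpha / n
sorted_p = np.sort(p_vals)
bh_limits = np.arange(1, n + 1) / n * alpha  # [0.0125, 0.025, 0.0375, 0.05]
passed = sorted_p <= bh_limits  # [True, True, False, False]
drift_fdr = bool(passed.any())  # drift if any feature clears its BH limit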
@@ -261,7 +441,7 @@
     def predict(
         self,
         x: ArrayLike,
-    ) -> DriftUnivariateOutput:
+    ) -> DriftOutput:
         """
         Predict whether a batch of data has drifted from the reference data and update
         reference data using specified update strategy.
@@ -273,13 +453,13 @@
 
         Returns
        -------
-        Dictionary containing the drift prediction and optionally the feature level
-        p-values, threshold after multivariate correction if needed and test
-        statistics.
+        DriftOutput
+            Dictionary containing the drift prediction and optionally the feature level
+            p-values, threshold after multivariate correction if needed and test statistics.
         """
         # compute drift scores
         p_vals, dist = self.score(x)
 
         feature_drift = (p_vals < self.p_val).astype(np.bool_)
         drift_pred, threshold = self._apply_correction(p_vals)
-        return DriftUnivariateOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
+        return DriftOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
dataeval/_internal/detectors/drift/cvm.py CHANGED
@@ -6,7 +6,9 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
-from typing import Callable, Literal, Optional, Tuple
+from __future__ import annotations
+
+from typing import Callable, Literal
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -19,37 +21,36 @@ from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 class DriftCVM(BaseDriftUnivariate):
     """
-    Cramér-von Mises (CVM) data drift detector, which tests for any change in the
-    distribution of continuous univariate data. For multivariate data, a separate
-    CVM test is applied to each feature, and the obtained p-values are aggregated
-    via the Bonferroni or False Discovery Rate (FDR) corrections.
+    Drift detector employing the Cramér-von Mises (CVM) distribution test.
+
+    The CVM test detects changes in the distribution of continuous
+    univariate data. For multivariate data, a separate CVM test is applied to each
+    feature, and the obtained p-values are aggregated via the Bonferroni or
+    False Discovery Rate (FDR) corrections.
 
     Parameters
     ----------
     x_ref : ArrayLike
         Data used as reference distribution.
-    p_val : float, default 0.05
+    p_val : float | None, default 0.05
         p-value used for significance of the statistical test for each feature.
         If the FDR correction method is used, this corresponds to the acceptable
         q-value.
     x_ref_preprocessed : bool, default False
-        Whether the given reference data `x_ref` has been preprocessed yet. If
-        `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
-        prediction time. If `x_ref_preprocessed=False`, the reference data will also
-        be preprocessed.
-    update_x_ref : Optional[UpdateStrategy], default None
+        Whether the given reference data ``x_ref`` has been preprocessed yet.
+        If ``True``, only the test data ``x`` will be preprocessed at prediction time.
+        If ``False``, the reference data will also be preprocessed.
+    update_x_ref : UpdateStrategy | None, default None
         Reference data can optionally be updated using an UpdateStrategy class. Update
-        using the last n instances seen by the detector with
-        :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
-        or via reservoir sampling with
-        :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
-    preprocess_fn : Optional[Callable[[ArrayLike], ArrayLike]], default None
+        using the last n instances seen by the detector with LastSeenUpdateStrategy
+        or via reservoir sampling with ReservoirSamplingUpdateStrategy.
+    preprocess_fn : Callable | None, default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a dimensionality reduction technique.
-    correction : Literal["bonferroni", "fdr"], default "bonferroni"
+    correction : "bonferroni" | "fdr", default "bonferroni"
         Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
         Discovery Rate).
-    n_features
+    n_features : int | None, default None
         Number of features used in the statistical test. No need to pass it if no
         preprocessing takes place. In case of a preprocessing step, this can also
         be inferred automatically but could be more expensive to compute.
@@ -60,10 +61,10 @@ class DriftCVM(BaseDriftUnivariate):
         x_ref: ArrayLike,
         p_val: float = 0.05,
         x_ref_preprocessed: bool = False,
-        update_x_ref: Optional[UpdateStrategy] = None,
-        preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
+        update_x_ref: UpdateStrategy | None = None,
+        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
         correction: Literal["bonferroni", "fdr"] = "bonferroni",
-        n_features: Optional[int] = None,
+        n_features: int | None = None,
     ) -> None:
         super().__init__(
             x_ref=x_ref,
@@ -76,7 +77,7 @@
         )
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
+    def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Performs the two-sample Cramér-von Mises test(s), computing the p-value and
         test statistic per feature.
@@ -88,7 +89,8 @@
 
         Returns
         -------
-        Feature level p-values and CVM statistics.
+        tuple[NDArray, NDArray]
+            Feature level p-values and CVM statistic
         """
         x_np = to_numpy(x)
         x_np = x_np.reshape(x_np.shape[0], -1)
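Putting this file's pieces together, a hypothetical end-to-end use of the detector. The dataeval.detectors.drift import path matches the new subpackage in the file list above, but its exact exports are an assumption:

import numpy as np
from dataeval.detectors.drift import DriftCVM  # assumed re-export

rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=(500, 8)).astype(np.float32)  # reference data
x_new = rng.normal(0.5, 1.0, size=(200, 8)).astype(np.float32)  # shifted batch

detector = DriftCVM(x_ref, p_val=0.05, correction="bonferroni")
result = detector.predict(x_new)  # returns a DriftOutput per the diff above

print(result.is_drift)  # overall prediction after multivariate correction
print(result.p_vals)    # per-feature CVM p-values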