dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +63 -49
  3. dataeval/_internal/detectors/drift/base.py +248 -51
  4. dataeval/_internal/detectors/drift/cvm.py +28 -26
  5. dataeval/_internal/detectors/drift/ks.py +31 -28
  6. dataeval/_internal/detectors/drift/mmd.py +62 -42
  7. dataeval/_internal/detectors/drift/torch.py +69 -60
  8. dataeval/_internal/detectors/drift/uncertainty.py +32 -32
  9. dataeval/_internal/detectors/duplicates.py +67 -31
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +86 -47
  13. dataeval/_internal/detectors/ood/llr.py +34 -31
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
  17. dataeval/_internal/flags.py +44 -21
  18. dataeval/_internal/interop.py +5 -3
  19. dataeval/_internal/metrics/balance.py +42 -5
  20. dataeval/_internal/metrics/ber.py +11 -8
  21. dataeval/_internal/metrics/coverage.py +15 -8
  22. dataeval/_internal/metrics/divergence.py +41 -7
  23. dataeval/_internal/metrics/diversity.py +57 -19
  24. dataeval/_internal/metrics/parity.py +141 -66
  25. dataeval/_internal/metrics/stats.py +330 -313
  26. dataeval/_internal/metrics/uap.py +33 -4
  27. dataeval/_internal/metrics/utils.py +79 -40
  28. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  29. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  30. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  31. dataeval/_internal/models/tensorflow/losses.py +17 -13
  32. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  33. dataeval/_internal/models/tensorflow/trainer.py +10 -7
  34. dataeval/_internal/models/tensorflow/utils.py +23 -20
  35. dataeval/_internal/output.py +85 -0
  36. dataeval/_internal/utils.py +5 -3
  37. dataeval/_internal/workflows/sufficiency.py +122 -121
  38. dataeval/detectors/__init__.py +6 -25
  39. dataeval/detectors/drift/__init__.py +16 -0
  40. dataeval/detectors/drift/kernels/__init__.py +6 -0
  41. dataeval/detectors/drift/updates/__init__.py +3 -0
  42. dataeval/detectors/linters/__init__.py +5 -0
  43. dataeval/detectors/ood/__init__.py +11 -0
  44. dataeval/flags/__init__.py +2 -2
  45. dataeval/metrics/__init__.py +2 -26
  46. dataeval/metrics/bias/__init__.py +14 -0
  47. dataeval/metrics/estimators/__init__.py +9 -0
  48. dataeval/metrics/stats/__init__.py +6 -0
  49. dataeval/tensorflow/__init__.py +3 -0
  50. dataeval/tensorflow/loss/__init__.py +3 -0
  51. dataeval/tensorflow/models/__init__.py +5 -0
  52. dataeval/tensorflow/recon/__init__.py +3 -0
  53. dataeval/torch/__init__.py +3 -0
  54. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  55. dataeval/torch/trainer/__init__.py +3 -0
  56. dataeval/utils/__init__.py +3 -6
  57. dataeval/workflows/__init__.py +2 -4
  58. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  59. dataeval-0.66.0.dist-info/RECORD +72 -0
  60. dataeval/_internal/metrics/base.py +0 -10
  61. dataeval/models/__init__.py +0 -15
  62. dataeval/models/tensorflow/__init__.py +0 -6
  63. dataeval-0.64.0.dist-info/RECORD +0 -60
  64. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  65. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -6,14 +6,64 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from abc import ABC, abstractmethod
12
+ from dataclasses import dataclass
10
13
  from functools import wraps
11
- from typing import Callable, Dict, Literal, Optional, Tuple, Union
14
+ from typing import Callable, Literal
12
15
 
13
16
  import numpy as np
14
- from numpy.typing import ArrayLike
17
+ from numpy.typing import ArrayLike, NDArray
15
18
 
16
19
  from dataeval._internal.interop import to_numpy
20
+ from dataeval._internal.output import OutputMetadata, set_metadata
21
+
22
+
23
+ @dataclass(frozen=True)
24
+ class DriftBaseOutput(OutputMetadata):
25
+ """
26
+ Output class for Drift
27
+
28
+ Attributes
29
+ ----------
30
+ is_drift : bool
31
+ Drift prediction for the images
32
+ threshold : float
33
+ Threshold after multivariate correction if needed
34
+ """
35
+
36
+ is_drift: bool
37
+ threshold: float
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class DriftOutput(DriftBaseOutput):
42
+ """
43
+ Output class for DriftCVM and DriftKS
44
+
45
+ Attributes
46
+ ----------
47
+ is_drift : bool
48
+ Drift prediction for the images
49
+ threshold : float
50
+ Threshold after multivariate correction if needed
51
+ feature_drift : NDArray
52
+ Feature-level array of images detected to have drifted
53
+ feature_threshold : float
54
+ Feature-level threshold to determine drift
55
+ p_vals : NDArray
56
+ Feature-level p-values
57
+ distances : NDArray
58
+ Feature-level distances
59
+ """
60
+
61
+ # is_drift: bool
62
+ # threshold: float
63
+ feature_drift: NDArray[np.bool_]
64
+ feature_threshold: float
65
+ p_vals: NDArray[np.float32]
66
+ distances: NDArray[np.float32]
17
67
 
18
68
 
19
69
  def update_x_ref(fn):
@@ -48,11 +98,20 @@ def preprocess_x(fn):
48
98
 
49
99
 
50
100
  class UpdateStrategy(ABC):
101
+ """
102
+ Updates reference dataset for drift detector
103
+
104
+ Parameters
105
+ ----------
106
+ n : int
107
+ Update with last n instances seen by the detector.
108
+ """
109
+
51
110
  def __init__(self, n: int):
52
111
  self.n = n
53
112
 
54
113
  @abstractmethod
55
- def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
114
+ def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
56
115
  """Abstract implementation of update strategy"""
57
116
 
58
117
 
@@ -66,7 +125,7 @@ class LastSeenUpdate(UpdateStrategy):
66
125
  Update with last n instances seen by the detector.
67
126
  """
68
127
 
69
- def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
128
+ def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
70
129
  x_updated = np.concatenate([x_ref, x], axis=0)
71
130
  return x_updated[-self.n :]
72
131
 
@@ -78,10 +137,10 @@ class ReservoirSamplingUpdate(UpdateStrategy):
78
137
  Parameters
79
138
  ----------
80
139
  n : int
81
- Update with reservoir sampling of size n.
140
+ Update with last n instances seen by the detector.
82
141
  """
83
142
 
84
- def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
143
+ def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
85
144
  if x.shape[0] + count <= self.n:
86
145
  return np.concatenate([x_ref, x], axis=0)
87
146
 
@@ -103,15 +162,64 @@ class ReservoirSamplingUpdate(UpdateStrategy):
103
162
 
104
163
 
105
164
  class BaseDrift:
106
- """Generic drift detector component handling preprocessing of data and correction"""
165
+ """
166
+ A generic drift detection component for preprocessing data and applying statistical correction.
167
+
168
+ This class handles common tasks related to drift detection, such as preprocessing
169
+ the reference data (`x_ref`), performing statistical correction (e.g., Bonferroni, FDR),
170
+ and updating the reference data if needed.
171
+
172
+ Parameters
173
+ ----------
174
+ x_ref : ArrayLike
175
+ The reference dataset used for drift detection. This is the baseline data against
176
+ which new data points will be compared.
177
+ p_val : float, optional
178
+ The significance level for detecting drift, by default 0.05.
179
+ x_ref_preprocessed : bool, optional
180
+ Flag indicating whether the reference data has already been preprocessed, by default False.
181
+ update_x_ref : UpdateStrategy, optional
182
+ A strategy object specifying how the reference data should be updated when drift is detected,
183
+ by default None.
184
+ preprocess_fn : Callable[[ArrayLike], ArrayLike], optional
185
+ A function to preprocess the data before drift detection, by default None.
186
+ correction : {'bonferroni', 'fdr'}, optional
187
+ Statistical correction method applied to p-values, by default "bonferroni".
188
+
189
+ Attributes
190
+ ----------
191
+ _x_ref : ArrayLike
192
+ The reference dataset that is either raw or preprocessed.
193
+ p_val : float
194
+ The significance level for drift detection.
195
+ update_x_ref : UpdateStrategy or None
196
+ The strategy for updating the reference data if applicable.
197
+ preprocess_fn : Callable or None
198
+ Function used for preprocessing input data before drift detection.
199
+ correction : str
200
+ Statistical correction method applied to p-values.
201
+ n : int
202
+ The number of samples in the reference dataset (`x_ref`).
203
+ x_ref_preprocessed : bool
204
+ A flag that indicates whether the reference dataset has been preprocessed.
205
+ _x_refcount : int
206
+ Counter for how many times the reference data has been accessed after preprocessing.
207
+
208
+ Methods
209
+ -------
210
+ x_ref:
211
+ Property that returns the reference dataset, and applies preprocessing if not already done.
212
+ _preprocess(x):
213
+ Preprocesses the given data using the specified `preprocess_fn` if provided.
214
+ """
107
215
 
108
216
  def __init__(
109
217
  self,
110
218
  x_ref: ArrayLike,
111
219
  p_val: float = 0.05,
112
220
  x_ref_preprocessed: bool = False,
113
- update_x_ref: Optional[UpdateStrategy] = None,
114
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
221
+ update_x_ref: UpdateStrategy | None = None,
222
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
115
223
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
116
224
  ) -> None:
117
225
  # Type checking
@@ -136,7 +244,15 @@ class BaseDrift:
136
244
  self._x_refcount = 0
137
245
 
138
246
  @property
139
- def x_ref(self) -> np.ndarray:
247
+ def x_ref(self) -> NDArray:
248
+ """
249
+ Retrieve the reference data, applying preprocessing if not already done.
250
+
251
+ Returns
252
+ -------
253
+ NDArray
254
+ The reference dataset (`x_ref`), preprocessed if needed.
255
+ """
140
256
  if not self.x_ref_preprocessed:
141
257
  self.x_ref_preprocessed = True
142
258
  if self.preprocess_fn is not None:
@@ -146,18 +262,75 @@ class BaseDrift:
146
262
  return self._x_ref
147
263
 
148
264
  def _preprocess(self, x: ArrayLike) -> ArrayLike:
149
- """Data preprocessing before computing the drift scores."""
265
+ """
266
+ Preprocess the given data before computing the drift scores.
267
+
268
+ Parameters
269
+ ----------
270
+ x : ArrayLike
271
+ The input data to preprocess.
272
+
273
+ Returns
274
+ -------
275
+ ArrayLike
276
+ The preprocessed input data.
277
+ """
150
278
  if self.preprocess_fn is not None:
151
279
  x = self.preprocess_fn(x)
152
280
  return x
153
281
 
154
282
 
155
- class BaseUnivariateDrift(BaseDrift):
283
+ class BaseDriftUnivariate(BaseDrift):
156
284
  """
157
- Generic drift detector component which serves as a base class for methods using
158
- univariate tests. If n_features > 1, a multivariate correction is applied such
159
- that the false positive rate is upper bounded by the specified p-value, with
160
- equality in the case of independent features.
285
+ Base class for drift detection methods using univariate statistical tests.
286
+
287
+ This class inherits from `BaseDrift` and serves as a generic component for detecting
288
+ distribution drift in univariate features. If the number of features `n_features` is greater
289
+ than 1, a multivariate correction method (e.g., Bonferroni or FDR) is applied to control
290
+ the false positive rate, ensuring it does not exceed the specified p-value.
291
+
292
+ Parameters
293
+ ----------
294
+ x_ref : ArrayLike
295
+ Reference data used as the baseline to compare against when detecting drift.
296
+ p_val : float, default 0.05
297
+ Significance level used for detecting drift.
298
+ x_ref_preprocessed : bool, default False
299
+ Indicates whether the reference data has been preprocessed.
300
+ update_x_ref : UpdateStrategy | None, default None
301
+ Strategy for updating the reference data when drift is detected.
302
+ preprocess_fn : Callable[ArrayLike] | None, default None
303
+ Function used to preprocess input data before detecting drift.
304
+ correction : 'bonferroni' | 'fdr', default 'bonferroni'
305
+ Multivariate correction method applied to p-values.
306
+ n_features : int | None, default None
307
+ Number of features used in the univariate drift tests. If not provided, it will
308
+ be inferred from the data.
309
+
310
+ Attributes
311
+ ----------
312
+ _n_features : int | None
313
+ Number of features in the data. If not provided, it is lazily inferred from the
314
+ input data and any preprocessing function.
315
+ p_val : float
316
+ The significance level for drift detection.
317
+ correction : str
318
+ The method for controlling the false discovery rate or applying a Bonferroni correction.
319
+ update_x_ref : UpdateStrategy | None
320
+ Strategy for updating the reference data if applicable.
321
+ preprocess_fn : Callable | None
322
+ Function used for preprocessing input data before drift detection.
323
+
324
+ Methods
325
+ -------
326
+ n_features:
327
+ Property that returns the number of features, inferring it if necessary.
328
+ score(x):
329
+ Abstract method to compute univariate feature scores after preprocessing.
330
+ _apply_correction(p_vals):
331
+ Apply a statistical correction to p-values to account for multiple testing.
332
+ predict(x):
333
+ Predict whether drift has occurred on a batch of data, applying multivariate correction if needed.
161
334
  """
162
335
 
163
336
  def __init__(
@@ -165,10 +338,10 @@ class BaseUnivariateDrift(BaseDrift):
165
338
  x_ref: ArrayLike,
166
339
  p_val: float = 0.05,
167
340
  x_ref_preprocessed: bool = False,
168
- update_x_ref: Optional[UpdateStrategy] = None,
169
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
341
+ update_x_ref: UpdateStrategy | None = None,
342
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
170
343
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
171
- n_features: Optional[int] = None,
344
+ n_features: int | None = None,
172
345
  ) -> None:
173
346
  super().__init__(
174
347
  x_ref,
@@ -183,6 +356,18 @@ class BaseUnivariateDrift(BaseDrift):
183
356
 
184
357
  @property
185
358
  def n_features(self) -> int:
359
+ """
360
+ Get the number of features in the reference data.
361
+
362
+ If the number of features is not provided during initialization, it will be inferred
363
+ from the reference data (``x_ref``). If a preprocessing function is provided, the number
364
+ of features will be inferred after applying the preprocessing function.
365
+
366
+ Returns
367
+ -------
368
+ int
369
+ Number of features in the reference data.
370
+ """
186
371
  # lazy process n_features as needed
187
372
  if not isinstance(self._n_features, int):
188
373
  # compute number of features for the univariate tests
@@ -198,13 +383,43 @@ class BaseUnivariateDrift(BaseDrift):
198
383
 
199
384
  @preprocess_x
200
385
  @abstractmethod
201
- def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
202
- """Abstract method to calculate feature score after preprocessing"""
386
+ def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
387
+ """
388
+ Abstract method to calculate feature scores after preprocessing.
389
+
390
+ Parameters
391
+ ----------
392
+ x : ArrayLike
393
+ The batch of data to calculate univariate drift scores for each feature.
203
394
 
204
- def _apply_correction(self, p_vals: np.ndarray) -> Tuple[int, float]:
395
+ Returns
396
+ -------
397
+ tuple[NDArray, NDArray]
398
+ A tuple containing p-values and distance statistics for each feature.
399
+ """
400
+
401
+ def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
402
+ """
403
+ Apply the specified correction method (Bonferroni or FDR) to the p-values.
404
+
405
+ If the correction method is Bonferroni, the threshold for detecting drift
406
+ is divided by the number of features. For FDR, the correction is applied
407
+ using the Benjamini-Hochberg procedure.
408
+
409
+ Parameters
410
+ ----------
411
+ p_vals : NDArray
412
+ Array of p-values from the univariate tests for each feature.
413
+
414
+ Returns
415
+ -------
416
+ tuple[bool, float]
417
+ A tuple containing a boolean indicating if drift was detected and the
418
+ threshold after correction.
419
+ """
205
420
  if self.correction == "bonferroni":
206
421
  threshold = self.p_val / self.n_features
207
- drift_pred = int((p_vals < threshold).any())
422
+ drift_pred = bool((p_vals < threshold).any())
208
423
  return drift_pred, threshold
209
424
  elif self.correction == "fdr":
210
425
  n = p_vals.shape[0]
@@ -215,18 +430,18 @@ class BaseUnivariateDrift(BaseDrift):
215
430
  try:
216
431
  idx_threshold = int(np.where(below_threshold)[0].max())
217
432
  except ValueError: # sorted p-values not below thresholds
218
- return int(below_threshold.any()), q_threshold.min()
219
- return int(below_threshold.any()), q_threshold[idx_threshold]
433
+ return bool(below_threshold.any()), q_threshold.min()
434
+ return bool(below_threshold.any()), q_threshold[idx_threshold]
220
435
  else:
221
436
  raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
222
437
 
438
+ @set_metadata("dataeval.detectors")
223
439
  @preprocess_x
224
440
  @update_x_ref
225
441
  def predict(
226
442
  self,
227
443
  x: ArrayLike,
228
- drift_type: Literal["batch", "feature"] = "batch",
229
- ) -> Dict[str, Union[int, float, np.ndarray]]:
444
+ ) -> DriftOutput:
230
445
  """
231
446
  Predict whether a batch of data has drifted from the reference data and update
232
447
  reference data using specified update strategy.
@@ -235,34 +450,16 @@ class BaseUnivariateDrift(BaseDrift):
235
450
  ----------
236
451
  x : ArrayLike
237
452
  Batch of instances.
238
- drift_type : Literal["batch", "feature"], default "batch"
239
- Predict drift at the 'feature' or 'batch' level. For 'batch', the test
240
- statistics for each feature are aggregated using the Bonferroni or False
241
- Discovery Rate correction (if n_features>1).
242
453
 
243
454
  Returns
244
455
  -------
245
- Dictionary containing the drift prediction and optionally the feature level
246
- p-values, threshold after multivariate correction if needed and test
247
- statistics.
456
+ DriftOutput
457
+ Dictionary containing the drift prediction and optionally the feature level
458
+ p-values, threshold after multivariate correction if needed and test statistics.
248
459
  """
249
460
  # compute drift scores
250
461
  p_vals, dist = self.score(x)
251
462
 
252
- # TODO: return both feature-level and batch-level drift predictions by default
253
- # values below p-value threshold are drift
254
- if drift_type == "feature":
255
- drift_pred = (p_vals < self.p_val).astype(int)
256
- threshold = self.p_val
257
- elif drift_type == "batch":
258
- drift_pred, threshold = self._apply_correction(p_vals)
259
- else:
260
- raise ValueError("`drift_type` needs to be either `feature` or `batch`.")
261
-
262
- # populate drift dict
263
- return {
264
- "is_drift": drift_pred,
265
- "p_val": p_vals,
266
- "threshold": threshold,
267
- "distance": dist,
268
- }
463
+ feature_drift = (p_vals < self.p_val).astype(np.bool_)
464
+ drift_pred, threshold = self._apply_correction(p_vals)
465
+ return DriftOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
@@ -6,50 +6,51 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
- from typing import Callable, Literal, Optional, Tuple
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable, Literal
10
12
 
11
13
  import numpy as np
12
- from numpy.typing import ArrayLike
14
+ from numpy.typing import ArrayLike, NDArray
13
15
  from scipy.stats import cramervonmises_2samp
14
16
 
15
17
  from dataeval._internal.interop import to_numpy
16
18
 
17
- from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
19
+ from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
18
20
 
19
21
 
20
- class DriftCVM(BaseUnivariateDrift):
22
+ class DriftCVM(BaseDriftUnivariate):
21
23
  """
22
- Cramér-von Mises (CVM) data drift detector, which tests for any change in the
23
- distribution of continuous univariate data. For multivariate data, a separate
24
- CVM test is applied to each feature, and the obtained p-values are aggregated
25
- via the Bonferroni or False Discovery Rate (FDR) corrections.
24
+ Drift detector employing the Cramér-von Mises (CVM) distribution test.
25
+
26
+ The CVM test detects changes in the distribution of continuous
27
+ univariate data. For multivariate data, a separate CVM test is applied to each
28
+ feature, and the obtained p-values are aggregated via the Bonferroni or
29
+ False Discovery Rate (FDR) corrections.
26
30
 
27
31
  Parameters
28
32
  ----------
29
33
  x_ref : ArrayLike
30
34
  Data used as reference distribution.
31
- p_val : float, default 0.05
35
+ p_val : float | None, default 0.05
32
36
  p-value used for significance of the statistical test for each feature.
33
37
  If the FDR correction method is used, this corresponds to the acceptable
34
38
  q-value.
35
39
  x_ref_preprocessed : bool, default False
36
- Whether the given reference data `x_ref` has been preprocessed yet. If
37
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
38
- prediction time. If `x_ref_preprocessed=False`, the reference data will also
39
- be preprocessed.
40
- update_x_ref : Optional[UpdateStrategy], default None
40
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
41
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
42
+ If ``False``, the reference data will also be preprocessed.
43
+ update_x_ref : UpdateStrategy | None, default None
41
44
  Reference data can optionally be updated using an UpdateStrategy class. Update
42
- using the last n instances seen by the detector with
43
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
44
- or via reservoir sampling with
45
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
46
- preprocess_fn : Optional[Callable[[ArrayLike], ArrayLike]], default None
45
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
46
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
47
+ preprocess_fn : Callable | None, default None
47
48
  Function to preprocess the data before computing the data drift metrics.
48
49
  Typically a dimensionality reduction technique.
49
- correction : Literal["bonferroni", "fdr"], default "bonferroni"
50
+ correction : "bonferroni" | "fdr", default "bonferroni"
50
51
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
51
52
  Discovery Rate).
52
- n_features
53
+ n_features : int | None, default None
53
54
  Number of features used in the statistical test. No need to pass it if no
54
55
  preprocessing takes place. In case of a preprocessing step, this can also
55
56
  be inferred automatically but could be more expensive to compute.
@@ -60,10 +61,10 @@ class DriftCVM(BaseUnivariateDrift):
60
61
  x_ref: ArrayLike,
61
62
  p_val: float = 0.05,
62
63
  x_ref_preprocessed: bool = False,
63
- update_x_ref: Optional[UpdateStrategy] = None,
64
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
64
+ update_x_ref: UpdateStrategy | None = None,
65
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
65
66
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
66
- n_features: Optional[int] = None,
67
+ n_features: int | None = None,
67
68
  ) -> None:
68
69
  super().__init__(
69
70
  x_ref=x_ref,
@@ -76,7 +77,7 @@ class DriftCVM(BaseUnivariateDrift):
76
77
  )
77
78
 
78
79
  @preprocess_x
79
- def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
80
+ def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
80
81
  """
81
82
  Performs the two-sample Cramér-von Mises test(s), computing the p-value and
82
83
  test statistic per feature.
@@ -88,7 +89,8 @@ class DriftCVM(BaseUnivariateDrift):
88
89
 
89
90
  Returns
90
91
  -------
91
- Feature level p-values and CVM statistics.
92
+ tuple[NDArray, NDArray]
93
+ Feature level p-values and CVM statistic
92
94
  """
93
95
  x_np = to_numpy(x)
94
96
  x_np = x_np.reshape(x_np.shape[0], -1)
@@ -6,51 +6,53 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
- from typing import Callable, Literal, Optional, Tuple
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable, Literal
10
12
 
11
13
  import numpy as np
12
- from numpy.typing import ArrayLike
14
+ from numpy.typing import ArrayLike, NDArray
13
15
  from scipy.stats import ks_2samp
14
16
 
15
17
  from dataeval._internal.interop import to_numpy
16
18
 
17
- from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
19
+ from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
18
20
 
19
21
 
20
- class DriftKS(BaseUnivariateDrift):
22
+ class DriftKS(BaseDriftUnivariate):
21
23
  """
22
- Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
23
- Rate (FDR) correction for multivariate data.
24
+ Drift detector employing the Kolmogorov-Smirnov (KS) distribution test.
25
+
26
+ The KS test detects changes in the maximum distance between two data
27
+ distributions with Bonferroni or False Discovery Rate (FDR) correction
28
+ for multivariate data.
24
29
 
25
30
  Parameters
26
31
  ----------
27
- x_ref : np.ndarray
32
+ x_ref : ArrayLike
28
33
  Data used as reference distribution.
29
- p_val : float, default 0.05
34
+ p_val : float | None, default 0.05
30
35
  p-value used for significance of the statistical test for each feature.
31
36
  If the FDR correction method is used, this corresponds to the acceptable
32
37
  q-value.
33
38
  x_ref_preprocessed : bool, default False
34
- Whether the given reference data `x_ref` has been preprocessed yet. If
35
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
36
- prediction time. If `x_ref_preprocessed=False`, the reference data will also
37
- be preprocessed.
38
- update_x_ref : Optional[UpdateStrategy], default None
39
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
40
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
41
+ If ``False``, the reference data will also be preprocessed.
42
+ update_x_ref : UpdateStrategy | None, default None
39
43
  Reference data can optionally be updated using an UpdateStrategy class. Update
40
- using the last n instances seen by the detector with
41
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
42
- or via reservoir sampling with
43
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
44
- preprocess_fn : Optional[Callable[[np.ndarray], np.ndarray]], default None
44
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
45
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
46
+ preprocess_fn : Callable | None, default None
45
47
  Function to preprocess the data before computing the data drift metrics.
46
48
  Typically a dimensionality reduction technique.
47
- correction : Literal["bonferroni", "fdr"], default "bonferroni"
49
+ correction : "bonferroni" | "fdr", default "bonferroni"
48
50
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
49
51
  Discovery Rate).
50
- alternative : Literal["two-sided", "less", "greater"], default "two-sided"
52
+ alternative : "two-sided" | "less" | "greater", default "two-sided"
51
53
  Defines the alternative hypothesis. Options are 'two-sided', 'less' or
52
54
  'greater'.
53
- n_features
55
+ n_features : int | None, default None
54
56
  Number of features used in the statistical test. No need to pass it if no
55
57
  preprocessing takes place. In case of a preprocessing step, this can also
56
58
  be inferred automatically but could be more expensive to compute.
@@ -61,11 +63,11 @@ class DriftKS(BaseUnivariateDrift):
61
63
  x_ref: ArrayLike,
62
64
  p_val: float = 0.05,
63
65
  x_ref_preprocessed: bool = False,
64
- update_x_ref: Optional[UpdateStrategy] = None,
65
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
66
+ update_x_ref: UpdateStrategy | None = None,
67
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
66
68
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
67
69
  alternative: Literal["two-sided", "less", "greater"] = "two-sided",
68
- n_features: Optional[int] = None,
70
+ n_features: int | None = None,
69
71
  ) -> None:
70
72
  super().__init__(
71
73
  x_ref=x_ref,
@@ -81,18 +83,19 @@ class DriftKS(BaseUnivariateDrift):
81
83
  self.alternative = alternative
82
84
 
83
85
  @preprocess_x
84
- def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
86
+ def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
85
87
  """
86
- Compute K-S scores and statistics per feature.
88
+ Compute KS scores and statistics per feature.
87
89
 
88
90
  Parameters
89
91
  ----------
90
- x
92
+ x : ArrayLike
91
93
  Batch of instances.
92
94
 
93
95
  Returns
94
96
  -------
95
- Feature level p-values and K-S statistics.
97
+ tuple[NDArray, NDArray]
98
+ Feature level p-values and KS statistic
96
99
  """
97
100
  x = to_numpy(x)
98
101
  x = x.reshape(x.shape[0], -1)