dataeval 0.72.0__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +10 -11
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +51 -102
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +9 -8
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +11 -10
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +33 -34
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +15 -13
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +12 -9
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +47 -45
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +20 -10
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +19 -26
  16. dataeval/detectors/ood/__init__.py +8 -16
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +9 -9
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +10 -30
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +27 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +27 -23
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +11 -13
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +70 -4
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +10 -8
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +54 -20
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +21 -17
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +31 -28
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +15 -16
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +8 -6
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +66 -40
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +19 -15
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +19 -17
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +12 -10
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +8 -6
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +12 -11
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +14 -13
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -4
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/utils/split_dataset.py +486 -0
  52. dataeval/utils/tensorflow/__init__.py +9 -7
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +64 -68
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +10 -9
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +18 -22
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +18 -18
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +49 -43
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +12 -141
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +42 -37
  67. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/METADATA +7 -5
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -7
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.0.dist-info/RECORD +0 -80
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.72.0"
1
+ __version__ = "0.72.2"
2
2
 
3
3
  from importlib.util import find_spec
4
4
 
@@ -8,16 +8,16 @@ _IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("te
8
8
 
9
9
  del find_spec
10
10
 
11
- from . import detectors, metrics # noqa: E402
11
+ from dataeval import detectors, metrics # noqa: E402
12
12
 
13
13
  __all__ = ["detectors", "metrics"]
14
14
 
15
15
  if _IS_TORCH_AVAILABLE: # pragma: no cover
16
- from . import workflows
16
+ from dataeval import workflows
17
17
 
18
18
  __all__ += ["workflows"]
19
19
 
20
20
  if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE: # pragma: no cover
21
- from . import utils
21
+ from dataeval import utils
22
22
 
23
23
  __all__ += ["utils"]
@@ -3,12 +3,13 @@ Detectors can determine if a dataset or individual images in a dataset are indic
3
3
  """
4
4
 
5
5
  from dataeval import _IS_TENSORFLOW_AVAILABLE
6
-
7
- from . import drift, linters
6
+ from dataeval.detectors import drift, linters
8
7
 
9
8
  __all__ = ["drift", "linters"]
10
9
 
11
10
  if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
12
- from . import ood
11
+ from dataeval.detectors import ood
13
12
 
14
13
  __all__ += ["ood"]
14
+
15
+ del _IS_TENSORFLOW_AVAILABLE
@@ -1,21 +1,20 @@
1
1
  """
2
- Drift detectors identify if the statistical properties of the data has changed.
2
+ :term:`Drift` detectors identify if the statistical properties of the data have changed.
3
3
  """
4
4
 
5
5
  from dataeval import _IS_TORCH_AVAILABLE
6
- from dataeval._internal.detectors.drift.base import DriftOutput
7
- from dataeval._internal.detectors.drift.cvm import DriftCVM
8
- from dataeval._internal.detectors.drift.ks import DriftKS
9
-
10
- from . import updates
6
+ from dataeval.detectors.drift import updates
7
+ from dataeval.detectors.drift.base import DriftOutput
8
+ from dataeval.detectors.drift.cvm import DriftCVM
9
+ from dataeval.detectors.drift.ks import DriftKS
11
10
 
12
11
  __all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
13
12
 
14
13
  if _IS_TORCH_AVAILABLE: # pragma: no cover
15
- from dataeval._internal.detectors.drift.mmd import DriftMMD, DriftMMDOutput
16
- from dataeval._internal.detectors.drift.torch import preprocess_drift
17
- from dataeval._internal.detectors.drift.uncertainty import DriftUncertainty
14
+ from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
15
+ from dataeval.detectors.drift.torch import preprocess_drift
16
+ from dataeval.detectors.drift.uncertainty import DriftUncertainty
18
17
 
19
- from . import kernels
18
+ __all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "preprocess_drift"]
20
19
 
21
- __all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "kernels", "preprocess_drift"]
20
+ del _IS_TORCH_AVAILABLE
@@ -8,16 +8,38 @@ Licensed under Apache Software License (Apache 2.0)
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
+ __all__ = ["DriftOutput"]
12
+
11
13
  from abc import ABC, abstractmethod
12
14
  from dataclasses import dataclass
13
15
  from functools import wraps
14
- from typing import Callable, Literal
16
+ from typing import Any, Callable, Literal, TypeVar
15
17
 
16
18
  import numpy as np
17
19
  from numpy.typing import ArrayLike, NDArray
18
20
 
19
- from dataeval._internal.interop import as_numpy, to_numpy
20
- from dataeval._internal.output import OutputMetadata, set_metadata
21
+ from dataeval.interop import as_numpy, to_numpy
22
+ from dataeval.output import OutputMetadata, set_metadata
23
+
24
+ R = TypeVar("R")
25
+
26
+
27
+ class UpdateStrategy(ABC):
28
+ """
29
+ Updates reference dataset for drift detector
30
+
31
+ Parameters
32
+ ----------
33
+ n : int
34
+ Update with last n instances seen by the detector.
35
+ """
36
+
37
+ def __init__(self, n: int) -> None:
38
+ self.n = n
39
+
40
+ @abstractmethod
41
+ def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
42
+ """Abstract implementation of update strategy"""
21
43
 
22
44
 
23
45
  @dataclass(frozen=True)
@@ -47,7 +69,7 @@ class DriftOutput(DriftBaseOutput):
47
69
  Attributes
48
70
  ----------
49
71
  is_drift : bool
50
- Drift prediction for the images
72
+ :term:`Drift` prediction for the images
51
73
  threshold : float
52
74
  Threshold after multivariate correction if needed
53
75
  feature_drift : NDArray
@@ -70,9 +92,11 @@ class DriftOutput(DriftBaseOutput):
70
92
  distances: NDArray[np.float32]
71
93
 
72
94
 
73
- def update_x_ref(fn):
95
+ def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
96
+ """Decorator to update x_ref with x using selected update methodology"""
97
+
74
98
  @wraps(fn)
75
- def _(self, x, *args, **kwargs):
99
+ def _(self, x, *args, **kwargs) -> R:
76
100
  output = fn(self, x, *args, **kwargs)
77
101
 
78
102
  # update reference dataset
@@ -86,9 +110,11 @@ def update_x_ref(fn):
86
110
  return _
87
111
 
88
112
 
89
- def preprocess_x(fn):
113
+ def preprocess_x(fn: Callable[..., R]) -> Callable[..., R]:
114
+ """Decorator to run preprocess_fn on x before calling wrapped function"""
115
+
90
116
  @wraps(fn)
91
- def _(self, x, *args, **kwargs):
117
+ def _(self, x, *args, **kwargs) -> R:
92
118
  if self._x_refcount == 0:
93
119
  self._x = self._preprocess(x)
94
120
  self._x_refcount += 1
@@ -101,73 +127,9 @@ def preprocess_x(fn):
101
127
  return _
102
128
 
103
129
 
104
- class UpdateStrategy(ABC):
105
- """
106
- Updates reference dataset for drift detector
107
-
108
- Parameters
109
- ----------
110
- n : int
111
- Update with last n instances seen by the detector.
112
- """
113
-
114
- def __init__(self, n: int):
115
- self.n = n
116
-
117
- @abstractmethod
118
- def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
119
- """Abstract implementation of update strategy"""
120
-
121
-
122
- class LastSeenUpdate(UpdateStrategy):
123
- """
124
- Updates reference dataset for drift detector using last seen method.
125
-
126
- Parameters
127
- ----------
128
- n : int
129
- Update with last n instances seen by the detector.
130
- """
131
-
132
- def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
133
- x_updated = np.concatenate([x_ref, x], axis=0)
134
- return x_updated[-self.n :]
135
-
136
-
137
- class ReservoirSamplingUpdate(UpdateStrategy):
138
- """
139
- Updates reference dataset for drift detector using reservoir sampling method.
140
-
141
- Parameters
142
- ----------
143
- n : int
144
- Update with last n instances seen by the detector.
145
- """
146
-
147
- def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
148
- if x.shape[0] + count <= self.n:
149
- return np.concatenate([x_ref, x], axis=0)
150
-
151
- n_ref = x_ref.shape[0]
152
- output_size = min(self.n, n_ref + x.shape[0])
153
- shape = (output_size,) + x.shape[1:]
154
- x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
155
- x_reservoir[:n_ref] = x_ref
156
- for item in x:
157
- count += 1
158
- if n_ref < self.n:
159
- x_reservoir[n_ref, :] = item
160
- n_ref += 1
161
- else:
162
- r = np.random.randint(0, count)
163
- if r < self.n:
164
- x_reservoir[r, :] = item
165
- return x_reservoir
166
-
167
-
168
130
  class BaseDrift:
169
131
  """
170
- A generic drift detection component for preprocessing data and applying statistical correction.
132
+ A generic :term:`drift<Drift>` detection component for preprocessing data and applying statistical correction.
171
133
 
172
134
  This class handles common tasks related to drift detection, such as preprocessing
173
135
  the reference data (`x_ref`), performing statistical correction (e.g., Bonferroni, FDR),
@@ -223,7 +185,7 @@ class BaseDrift:
223
185
  p_val: float = 0.05,
224
186
  x_ref_preprocessed: bool = False,
225
187
  update_x_ref: UpdateStrategy | None = None,
226
- preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
188
+ preprocess_fn: Callable[..., ArrayLike] | None = None,
227
189
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
228
190
  ) -> None:
229
191
  # Type checking
@@ -235,20 +197,20 @@ class BaseDrift:
235
197
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")
236
198
 
237
199
  self._x_ref = to_numpy(x_ref)
238
- self.x_ref_preprocessed = x_ref_preprocessed
200
+ self.x_ref_preprocessed: bool = x_ref_preprocessed
239
201
 
240
202
  # Other attributes
241
203
  self.p_val = p_val
242
204
  self.update_x_ref = update_x_ref
243
205
  self.preprocess_fn = preprocess_fn
244
206
  self.correction = correction
245
- self.n = len(self._x_ref)
207
+ self.n: int = len(self._x_ref)
246
208
 
247
209
  # Ref counter for preprocessed x
248
210
  self._x_refcount = 0
249
211
 
250
212
  @property
251
- def x_ref(self) -> NDArray:
213
+ def x_ref(self) -> NDArray[Any]:
252
214
  """
253
215
  Retrieve the reference data, applying preprocessing if not already done.
254
216
 
@@ -266,7 +228,7 @@ class BaseDrift:
266
228
 
267
229
  def _preprocess(self, x: ArrayLike) -> ArrayLike:
268
230
  """
269
- Preprocess the given data before computing the drift scores.
231
+ Preprocess the given data before computing the :term:`drift<Drift>` scores.
270
232
 
271
233
  Parameters
272
234
  ----------
@@ -285,12 +247,13 @@ class BaseDrift:
285
247
 
286
248
  class BaseDriftUnivariate(BaseDrift):
287
249
  """
288
- Base class for drift detection methods using univariate statistical tests.
250
+ Base class for :term:`drift<Drift>` detection methods using univariate statistical tests.
289
251
 
290
252
  This class inherits from `BaseDrift` and serves as a generic component for detecting
291
253
  distribution drift in univariate features. If the number of features `n_features` is greater
292
254
  than 1, a multivariate correction method (e.g., Bonferroni or FDR) is applied to control
293
- the false positive rate, ensuring it does not exceed the specified p-value.
255
+ the :term:`false positive rate<False Positive Rate (FP)>`, ensuring it does not exceed the specified
256
+ :term:`p-value<P-Value>`.
294
257
 
295
258
  Parameters
296
259
  ----------
@@ -312,28 +275,14 @@ class BaseDriftUnivariate(BaseDrift):
312
275
 
313
276
  Attributes
314
277
  ----------
315
- _n_features : int | None
316
- Number of features in the data. If not provided, it is lazily inferred from the
317
- input data and any preprocessing function.
318
278
  p_val : float
319
279
  The significance level for drift detection.
320
280
  correction : str
321
- The method for controlling the false discovery rate or applying a Bonferroni correction.
281
+ The method for controlling the :term:`False Discovery Rate (FDR)` or applying a Bonferroni correction.
322
282
  update_x_ref : UpdateStrategy | None
323
283
  Strategy for updating the reference data if applicable.
324
284
  preprocess_fn : Callable | None
325
285
  Function used for preprocessing input data before drift detection.
326
-
327
- Methods
328
- -------
329
- n_features:
330
- Property that returns the number of features, inferring it if necessary.
331
- score(x):
332
- Abstract method to compute univariate feature scores after preprocessing.
333
- _apply_correction(p_vals):
334
- Apply a statistical correction to p-values to account for multiple testing.
335
- predict(x):
336
- Predict whether drift has occurred on a batch of data, applying multivariate correction if needed.
337
286
  """
338
287
 
339
288
  def __init__(
@@ -393,19 +342,19 @@ class BaseDriftUnivariate(BaseDrift):
393
342
  Parameters
394
343
  ----------
395
344
  x : ArrayLike
396
- The batch of data to calculate univariate drift scores for each feature.
345
+ The batch of data to calculate univariate :term:`drift<Drift>` scores for each feature.
397
346
 
398
347
  Returns
399
348
  -------
400
349
  tuple[NDArray, NDArray]
401
- A tuple containing p-values and distance statistics for each feature.
350
+ A tuple containing p-values and distance :term:`statistics<Statistics>` for each feature.
402
351
  """
403
352
 
404
353
  def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
405
354
  """
406
355
  Apply the specified correction method (Bonferroni or FDR) to the p-values.
407
356
 
408
- If the correction method is Bonferroni, the threshold for detecting drift
357
+ If the correction method is Bonferroni, the threshold for detecting :term:`drift<Drift>`
409
358
  is divided by the number of features. For FDR, the correction is applied
410
359
  using the Benjamini-Hochberg procedure.
411
360
 
@@ -426,7 +375,7 @@ class BaseDriftUnivariate(BaseDrift):
426
375
  return drift_pred, threshold
427
376
  elif self.correction == "fdr":
428
377
  n = p_vals.shape[0]
429
- i = np.arange(n) + 1
378
+ i = np.arange(n) + np.int_(1)
430
379
  p_sorted = np.sort(p_vals)
431
380
  q_threshold = self.p_val * i / n
432
381
  below_threshold = p_sorted < q_threshold
@@ -438,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
438
387
  else:
439
388
  raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
440
389
 
441
- @set_metadata("dataeval.detectors")
390
+ @set_metadata()
442
391
  @preprocess_x
443
392
  @update_x_ref
444
393
  def predict(
@@ -457,8 +406,8 @@ class BaseDriftUnivariate(BaseDrift):
457
406
  Returns
458
407
  -------
459
408
  DriftOutput
460
- Dictionary containing the drift prediction and optionally the feature level
461
- p-values, threshold after multivariate correction if needed and test statistics.
409
+ Dictionary containing the :term:`drift<Drift>` prediction and optionally the feature level
410
+ p-values, threshold after multivariate correction if needed and test :term:`statistics<Statistics>`.
462
411
  """
463
412
  # compute drift scores
464
413
  p_vals, dist = self.score(x)
@@ -8,32 +8,33 @@ Licensed under Apache Software License (Apache 2.0)
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
+ __all__ = ["DriftCVM"]
12
+
11
13
  from typing import Callable, Literal
12
14
 
13
15
  import numpy as np
14
16
  from numpy.typing import ArrayLike, NDArray
15
17
  from scipy.stats import cramervonmises_2samp
16
18
 
17
- from dataeval._internal.interop import to_numpy
18
-
19
- from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
19
+ from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
20
+ from dataeval.interop import to_numpy
20
21
 
21
22
 
22
23
  class DriftCVM(BaseDriftUnivariate):
23
24
  """
24
- Drift detector employing the Cramér-von Mises (CVM) distribution test.
25
+ :term:`Drift` detector employing the :term:`Cramér-von Mises (CVM) Drift Detection` test.
25
26
 
26
27
  The CVM test detects changes in the distribution of continuous
27
28
  univariate data. For multivariate data, a separate CVM test is applied to each
28
29
  feature, and the obtained p-values are aggregated via the Bonferroni or
29
- False Discovery Rate (FDR) corrections.
30
+ :term:`False Discovery Rate (FDR)` corrections.
30
31
 
31
32
  Parameters
32
33
  ----------
33
34
  x_ref : ArrayLike
34
35
  Data used as reference distribution.
35
36
  p_val : float | None, default 0.05
36
- p-value used for significance of the statistical test for each feature.
37
+ :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
37
38
  If the FDR correction method is used, this corresponds to the acceptable
38
39
  q-value.
39
40
  x_ref_preprocessed : bool, default False
@@ -46,7 +47,7 @@ class DriftCVM(BaseDriftUnivariate):
46
47
  or via reservoir sampling with ReservoirSamplingUpdate.
47
48
  preprocess_fn : Callable | None, default None
48
49
  Function to preprocess the data before computing the data drift metrics.
49
- Typically a dimensionality reduction technique.
50
+ Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
50
51
  correction : "bonferroni" | "fdr", default "bonferroni"
51
52
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
52
53
  Discovery Rate).
@@ -79,7 +80,7 @@ class DriftCVM(BaseDriftUnivariate):
79
80
  @preprocess_x
80
81
  def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
81
82
  """
82
- Performs the two-sample Cramér-von Mises test(s), computing the p-value and
83
+ Performs the two-sample Cramér-von Mises test(s), computing the :term:`p-value<P-Value>` and
83
84
  test statistic per feature.
84
85
 
85
86
  Parameters
@@ -8,23 +8,24 @@ Licensed under Apache Software License (Apache 2.0)
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
+ __all__ = ["DriftKS"]
12
+
11
13
  from typing import Callable, Literal
12
14
 
13
15
  import numpy as np
14
16
  from numpy.typing import ArrayLike, NDArray
15
17
  from scipy.stats import ks_2samp
16
18
 
17
- from dataeval._internal.interop import to_numpy
18
-
19
- from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
19
+ from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
20
+ from dataeval.interop import to_numpy
20
21
 
21
22
 
22
23
  class DriftKS(BaseDriftUnivariate):
23
24
  """
24
- Drift detector employing the Kolmogorov-Smirnov (KS) distribution test.
25
+ :term:`Drift` detector employing the Kolmogorov-Smirnov (KS) distribution test.
25
26
 
26
27
  The KS test detects changes in the maximum distance between two data
27
- distributions with Bonferroni or False Discovery Rate (FDR) correction
28
+ distributions with Bonferroni or :term:`False Discovery Rate (FDR)` correction
28
29
  for multivariate data.
29
30
 
30
31
  Parameters
@@ -32,7 +33,7 @@ class DriftKS(BaseDriftUnivariate):
32
33
  x_ref : ArrayLike
33
34
  Data used as reference distribution.
34
35
  p_val : float | None, default 0.05
35
- p-value used for significance of the statistical test for each feature.
36
+ :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
36
37
  If the FDR correction method is used, this corresponds to the acceptable
37
38
  q-value.
38
39
  x_ref_preprocessed : bool, default False
@@ -44,8 +45,8 @@ class DriftKS(BaseDriftUnivariate):
44
45
  using the last n instances seen by the detector with LastSeenUpdate
45
46
  or via reservoir sampling with ReservoirSamplingUpdate.
46
47
  preprocess_fn : Callable | None, default None
47
- Function to preprocess the data before computing the data drift metrics.
48
- Typically a dimensionality reduction technique.
48
+ Function to preprocess the data before computing the data :term:`drift<Drift>` metrics.
49
+ Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
49
50
  correction : "bonferroni" | "fdr", default "bonferroni"
50
51
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
51
52
  Discovery Rate).
@@ -85,7 +86,7 @@ class DriftKS(BaseDriftUnivariate):
85
86
  @preprocess_x
86
87
  def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
87
88
  """
88
- Compute KS scores and statistics per feature.
89
+ Compute KS scores and :term:`Statistics` per feature.
89
90
 
90
91
  Parameters
91
92
  ----------
@@ -95,7 +96,7 @@ class DriftKS(BaseDriftUnivariate):
95
96
  Returns
96
97
  -------
97
98
  tuple[NDArray, NDArray]
98
- Feature level p-values and KS statistic
99
+ Feature level :term:`p-values<P-Value>` and KS statistic
99
100
  """
100
101
  x = to_numpy(x)
101
102
  x = x.reshape(x.shape[0], -1)
@@ -8,30 +8,31 @@ Licensed under Apache Software License (Apache 2.0)
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
+ __all__ = ["DriftMMD", "DriftMMDOutput"]
12
+
11
13
  from dataclasses import dataclass
12
14
  from typing import Callable
13
15
 
14
16
  import torch
15
17
  from numpy.typing import ArrayLike
16
18
 
17
- from dataeval._internal.interop import as_numpy
18
- from dataeval._internal.output import set_metadata
19
-
20
- from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
21
- from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
19
+ from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
20
+ from dataeval.detectors.drift.torch import _GaussianRBF, _mmd2_from_kernel_matrix, get_device
21
+ from dataeval.interop import as_numpy
22
+ from dataeval.output import set_metadata
22
23
 
23
24
 
24
25
  @dataclass(frozen=True)
25
26
  class DriftMMDOutput(DriftBaseOutput):
26
27
  """
27
- Output class for :class:`DriftMMD` drift detector
28
+ Output class for :class:`DriftMMD` :term:`drift<Drift>` detector
28
29
 
29
30
  Attributes
30
31
  ----------
31
32
  is_drift : bool
32
33
  Drift prediction for the images
33
34
  threshold : float
34
- P-value used for significance of the permutation test
35
+ :term:`P-Value` used for significance of the permutation test
35
36
  p_val : float
36
37
  P-value obtained from the permutation test
37
38
  distance : float
@@ -49,14 +50,14 @@ class DriftMMDOutput(DriftBaseOutput):
49
50
 
50
51
  class DriftMMD(BaseDrift):
51
52
  """
52
- Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
53
+ :term:`Maximum Mean Discrepancy (MMD) Drift Detection` algorithm using a permutation test.
53
54
 
54
55
  Parameters
55
56
  ----------
56
57
  x_ref : ArrayLike
57
58
  Data used as reference distribution.
58
59
  p_val : float | None, default 0.05
59
- p-value used for significance of the statistical test for each feature.
60
+ :term:`P-value<P-Value>` used for significance of the statistical test for each feature.
60
61
  If the FDR correction method is used, this corresponds to the acceptable
61
62
  q-value.
62
63
  x_ref_preprocessed : bool, default False
@@ -69,11 +70,9 @@ class DriftMMD(BaseDrift):
69
70
  or via reservoir sampling with ReservoirSamplingUpdate.
70
71
  preprocess_fn : Callable | None, default None
71
72
  Function to preprocess the data before computing the data drift metrics.
72
- Typically a dimensionality reduction technique.
73
- kernel : Callable, default GaussianRBF
74
- Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
73
+ Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
75
74
  sigma : ArrayLike | None, default None
76
- Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple
75
+ Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
77
76
  bandwidth values as an array. The kernel evaluation is then averaged over
78
77
  those bandwidths.
79
78
  configure_kernel_from_x_ref : bool, default True
@@ -91,48 +90,47 @@ class DriftMMD(BaseDrift):
91
90
  p_val: float = 0.05,
92
91
  x_ref_preprocessed: bool = False,
93
92
  update_x_ref: UpdateStrategy | None = None,
94
- preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
95
- kernel: Callable = GaussianRBF,
93
+ preprocess_fn: Callable[..., ArrayLike] | None = None,
96
94
  sigma: ArrayLike | None = None,
97
95
  configure_kernel_from_x_ref: bool = True,
98
96
  n_permutations: int = 100,
99
- device: str | None = None,
97
+ device: str | torch.device | None = None,
100
98
  ) -> None:
101
99
  super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
102
100
 
103
- self.infer_sigma = configure_kernel_from_x_ref
101
+ self._infer_sigma = configure_kernel_from_x_ref
104
102
  if configure_kernel_from_x_ref and sigma is not None:
105
- self.infer_sigma = False
103
+ self._infer_sigma = False
106
104
 
107
105
  self.n_permutations = n_permutations # nb of iterations through permutation test
108
106
 
109
107
  # set device
110
- self.device = get_device(device)
108
+ self.device: torch.device = get_device(device)
111
109
 
112
110
  # initialize kernel
113
111
  sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
114
- self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel
112
+ self._kernel = _GaussianRBF(sigma_tensor).to(self.device)
115
113
 
116
114
  # compute kernel matrix for the reference data
117
- if self.infer_sigma or isinstance(sigma_tensor, torch.Tensor):
115
+ if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
118
116
  x = torch.from_numpy(self.x_ref).to(self.device)
119
- self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
120
- self.infer_sigma = False
117
+ self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
118
+ self._infer_sigma = False
121
119
  else:
122
- self.k_xx, self.infer_sigma = None, True
120
+ self._k_xx, self._infer_sigma = None, True
123
121
 
124
122
  def _kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
125
123
  """Compute and return full kernel matrix between arrays x and y."""
126
- k_xy = self.kernel(x, y, self.infer_sigma)
127
- k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
128
- k_yy = self.kernel(y, y)
124
+ k_xy = self._kernel(x, y, self._infer_sigma)
125
+ k_xx = self._k_xx if self._k_xx is not None and self.update_x_ref is None else self._kernel(x, x)
126
+ k_yy = self._kernel(y, y)
129
127
  kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
130
128
  return kernel_mat
131
129
 
132
130
  @preprocess_x
133
131
  def score(self, x: ArrayLike) -> tuple[float, float, float]:
134
132
  """
135
- Compute the p-value resulting from a permutation test using the maximum mean
133
+ Compute the :term:`p-value<P-Value>` resulting from a permutation test using the maximum mean
136
134
  discrepancy as a distance measure between the reference data and the data to
137
135
  be tested.
138
136
 
@@ -145,25 +143,25 @@ class DriftMMD(BaseDrift):
145
143
  -------
146
144
  tuple(float, float, float)
147
145
  p-value obtained from the permutation test, MMD^2 between the reference and test set,
148
- and MMD^2 threshold above which drift is flagged
146
+ and MMD^2 threshold above which :term:`drift<Drift>` is flagged
149
147
  """
150
148
  x = as_numpy(x)
151
149
  x_ref = torch.from_numpy(self.x_ref).to(self.device)
152
150
  n = x.shape[0]
153
151
  kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
154
152
  kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
155
- mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
153
+ mmd2 = _mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
156
154
  mmd2_permuted = torch.Tensor(
157
- [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
155
+ [_mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
158
156
  )
159
157
  mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
160
158
  p_val = (mmd2 <= mmd2_permuted).float().mean()
161
159
  # compute distance threshold
162
160
  idx_threshold = int(self.p_val * len(mmd2_permuted))
163
161
  distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
164
- return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
162
+ return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
165
163
 
166
- @set_metadata("dataeval.detectors")
164
+ @set_metadata()
167
165
  @preprocess_x
168
166
  @update_x_ref
169
167
  def predict(self, x: ArrayLike) -> DriftMMDOutput:
@@ -179,7 +177,8 @@ class DriftMMD(BaseDrift):
179
177
  Returns
180
178
  -------
181
179
  DriftMMDOutput
182
- Output class containing the drift prediction, p-value, threshold and MMD metric.
180
+ Output class containing the :term:`drift<Drift>` prediction, :term:`p-value<P-Value>`,
181
+ threshold and MMD metric.
183
182
  """
184
183
  # compute drift scores
185
184
  p_val, dist, distance_threshold = self.score(x)