dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +7 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/{_internal → utils}/split_dataset.py +98 -33
  52. dataeval/utils/tensorflow/__init__.py +7 -6
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  67. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -8
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.1.dist-info/RECORD +0 -81
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.72.1"
+__version__ = "0.72.2"
 
 from importlib.util import find_spec
 
@@ -8,16 +8,16 @@ _IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("te
 
 del find_spec
 
-from . import detectors, metrics  # noqa: E402
+from dataeval import detectors, metrics  # noqa: E402
 
 __all__ = ["detectors", "metrics"]
 
 if _IS_TORCH_AVAILABLE:  # pragma: no cover
-    from . import workflows
+    from dataeval import workflows
 
     __all__ += ["workflows"]
 
 if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE:  # pragma: no cover
-    from . import utils
+    from dataeval import utils
 
     __all__ += ["utils"]
dataeval/detectors/__init__.py CHANGED
@@ -3,12 +3,13 @@ Detectors can determine if a dataset or individual images in a dataset are indic
 """
 
 from dataeval import _IS_TENSORFLOW_AVAILABLE
-
-from . import drift, linters
+from dataeval.detectors import drift, linters
 
 __all__ = ["drift", "linters"]
 
 if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
-    from . import ood
+    from dataeval.detectors import ood
 
     __all__ += ["ood"]
+
+del _IS_TENSORFLOW_AVAILABLE
dataeval/detectors/drift/__init__.py CHANGED
@@ -3,19 +3,18 @@
 """
 
 from dataeval import _IS_TORCH_AVAILABLE
-from dataeval._internal.detectors.drift.base import DriftOutput
-from dataeval._internal.detectors.drift.cvm import DriftCVM
-from dataeval._internal.detectors.drift.ks import DriftKS
-
-from . import updates
+from dataeval.detectors.drift import updates
+from dataeval.detectors.drift.base import DriftOutput
+from dataeval.detectors.drift.cvm import DriftCVM
+from dataeval.detectors.drift.ks import DriftKS
 
 __all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
 
 if _IS_TORCH_AVAILABLE:  # pragma: no cover
-    from dataeval._internal.detectors.drift.mmd import DriftMMD, DriftMMDOutput
-    from dataeval._internal.detectors.drift.torch import preprocess_drift
-    from dataeval._internal.detectors.drift.uncertainty import DriftUncertainty
+    from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
+    from dataeval.detectors.drift.torch import preprocess_drift
+    from dataeval.detectors.drift.uncertainty import DriftUncertainty
 
-    from . import kernels
+    __all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "preprocess_drift"]
 
-    __all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "kernels", "preprocess_drift"]
+del _IS_TORCH_AVAILABLE
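The relocated package layout above is the user-visible change: everything drift-related is now importable from dataeval.detectors.drift instead of dataeval._internal. A minimal usage sketch of the new import surface, assuming DriftOutput exposes an is_drift flag (the random data is purely illustrative):

import numpy as np

from dataeval.detectors.drift import DriftKS, updates

x_ref = np.random.randn(1000, 32)  # illustrative reference features
detector = DriftKS(x_ref, update_x_ref=updates.LastSeenUpdate(n=1000))
result = detector.predict(np.random.randn(200, 32))
print(result.is_drift)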
dataeval/{_internal/detectors → detectors}/drift/base.py RENAMED
@@ -8,16 +8,38 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = ["DriftOutput"]
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, Literal
+from typing import Any, Callable, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval._internal.interop import as_numpy, to_numpy
-from dataeval._internal.output import OutputMetadata, set_metadata
+from dataeval.interop import as_numpy, to_numpy
+from dataeval.output import OutputMetadata, set_metadata
+
+R = TypeVar("R")
+
+
+class UpdateStrategy(ABC):
+    """
+    Updates reference dataset for drift detector
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
+    def __init__(self, n: int) -> None:
+        self.n = n
+
+    @abstractmethod
+    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
+        """Abstract implementation of update strategy"""
 
 
 @dataclass(frozen=True)
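UpdateStrategy stays behind in base.py as the abstract contract the detectors call after each prediction, while the concrete strategies move to the new updates module further down. A sketch of a custom strategy against the __call__(x_ref, x, count) contract shown above (FirstSeenUpdate is hypothetical, not part of the package):

from typing import Any

import numpy as np
from numpy.typing import NDArray

from dataeval.detectors.drift.base import UpdateStrategy


class FirstSeenUpdate(UpdateStrategy):
    """Hypothetical strategy: keep only the first n instances ever seen."""

    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
        if x_ref.shape[0] >= self.n:
            return x_ref  # reference window is already full
        room = self.n - x_ref.shape[0]
        return np.concatenate([x_ref, x[:room]], axis=0)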
@@ -70,9 +92,11 @@ class DriftOutput(DriftBaseOutput):
     distances: NDArray[np.float32]
 
 
-def update_x_ref(fn):
+def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
+    """Decorator to update x_ref with x using selected update methodology"""
+
     @wraps(fn)
-    def _(self, x, *args, **kwargs):
+    def _(self, x, *args, **kwargs) -> R:
         output = fn(self, x, *args, **kwargs)
 
         # update reference dataset
@@ -86,9 +110,11 @@ def update_x_ref(fn):
     return _
 
 
-def preprocess_x(fn):
+def preprocess_x(fn: Callable[..., R]) -> Callable[..., R]:
+    """Decorator to run preprocess_fn on x before calling wrapped function"""
+
     @wraps(fn)
-    def _(self, x, *args, **kwargs):
+    def _(self, x, *args, **kwargs) -> R:
         if self._x_refcount == 0:
             self._x = self._preprocess(x)
         self._x_refcount += 1
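Both decorators now advertise their return type through the R TypeVar. The reference count in preprocess_x exists so that stacked decorated methods preprocess x only once; a standalone sketch of the same idea (all names here are illustrative, not from the package):

from functools import wraps
from typing import Any, Callable, TypeVar

R = TypeVar("R")


def cache_preprocessed(fn: Callable[..., R]) -> Callable[..., R]:
    # Illustrative: preprocess x on first entry, reuse it in nested decorated calls
    @wraps(fn)
    def _(self, x, *args: Any, **kwargs: Any) -> R:
        if self._x_refcount == 0:
            self._x = self._preprocess(x)
        self._x_refcount += 1
        try:
            return fn(self, self._x, *args, **kwargs)
        finally:
            self._x_refcount -= 1
            if self._x_refcount == 0:
                del self._x

    return _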
@@ -101,70 +127,6 @@ def preprocess_x(fn):
     return _
 
 
-class UpdateStrategy(ABC):
-    """
-    Updates reference dataset for :term:`drift<Drift>` detector
-
-    Parameters
-    ----------
-    n : int
-        Update with last n instances seen by the detector.
-    """
-
-    def __init__(self, n: int):
-        self.n = n
-
-    @abstractmethod
-    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
-        """Abstract implementation of update strategy"""
-
-
-class LastSeenUpdate(UpdateStrategy):
-    """
-    Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
-
-    Parameters
-    ----------
-    n : int
-        Update with last n instances seen by the detector.
-    """
-
-    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
-        x_updated = np.concatenate([x_ref, x], axis=0)
-        return x_updated[-self.n :]
-
-
-class ReservoirSamplingUpdate(UpdateStrategy):
-    """
-    Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
-
-    Parameters
-    ----------
-    n : int
-        Update with last n instances seen by the detector.
-    """
-
-    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
-        if x.shape[0] + count <= self.n:
-            return np.concatenate([x_ref, x], axis=0)
-
-        n_ref = x_ref.shape[0]
-        output_size = min(self.n, n_ref + x.shape[0])
-        shape = (output_size,) + x.shape[1:]
-        x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
-        x_reservoir[:n_ref] = x_ref
-        for item in x:
-            count += 1
-            if n_ref < self.n:
-                x_reservoir[n_ref, :] = item
-                n_ref += 1
-            else:
-                r = np.random.randint(0, count)
-                if r < self.n:
-                    x_reservoir[r, :] = item
-        return x_reservoir
-
-
 class BaseDrift:
     """
     A generic :term:`drift<Drift>` detection component for preprocessing data and applying statistical correction.
@@ -223,7 +185,7 @@ class BaseDrift:
         p_val: float = 0.05,
         x_ref_preprocessed: bool = False,
         update_x_ref: UpdateStrategy | None = None,
-        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
+        preprocess_fn: Callable[..., ArrayLike] | None = None,
         correction: Literal["bonferroni", "fdr"] = "bonferroni",
     ) -> None:
         # Type checking
@@ -235,20 +197,20 @@ class BaseDrift:
             raise ValueError("`correction` must be `bonferroni` or `fdr`.")
 
         self._x_ref = to_numpy(x_ref)
-        self.x_ref_preprocessed = x_ref_preprocessed
+        self.x_ref_preprocessed: bool = x_ref_preprocessed
 
         # Other attributes
         self.p_val = p_val
         self.update_x_ref = update_x_ref
         self.preprocess_fn = preprocess_fn
         self.correction = correction
-        self.n = len(self._x_ref)
+        self.n: int = len(self._x_ref)
 
         # Ref counter for preprocessed x
         self._x_refcount = 0
 
     @property
-    def x_ref(self) -> NDArray:
+    def x_ref(self) -> NDArray[Any]:
         """
         Retrieve the reference data, applying preprocessing if not already done.
 
@@ -313,9 +275,6 @@ class BaseDriftUnivariate(BaseDrift):
 
     Attributes
     ----------
-    _n_features : int | None
-        Number of features in the data. If not provided, it is lazily inferred from the
-        input data and any preprocessing function.
     p_val : float
         The significance level for drift detection.
     correction : str
@@ -324,17 +283,6 @@ class BaseDriftUnivariate(BaseDrift):
         Strategy for updating the reference data if applicable.
     preprocess_fn : Callable | None
         Function used for preprocessing input data before drift detection.
-
-    Methods
-    -------
-    n_features:
-        Property that returns the number of features, inferring it if necessary.
-    score(x):
-        Abstract method to compute univariate feature scores after preprocessing.
-    _apply_correction(p_vals):
-        Apply a statistical correction to p-values to account for multiple testing.
-    predict(x):
-        Predict whether drift has occurred on a batch of data, applying multivariate correction if needed.
     """
 
     def __init__(
@@ -427,7 +375,7 @@ class BaseDriftUnivariate(BaseDrift):
             return drift_pred, threshold
         elif self.correction == "fdr":
             n = p_vals.shape[0]
-            i = np.arange(n) + 1
+            i = np.arange(n) + np.int_(1)
             p_sorted = np.sort(p_vals)
             q_threshold = self.p_val * i / n
             below_threshold = p_sorted < q_threshold
@@ -439,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
 
-    @set_metadata("dataeval.detectors")
+    @set_metadata()
    @preprocess_x
     @update_x_ref
     def predict(
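The fdr branch above is a Benjamini–Hochberg step-up test: sort the per-feature p-values and compare the i-th smallest against p_val * i / n; any sorted p-value under its threshold signals drift. A worked sketch with made-up numbers:

import numpy as np

p_vals = np.array([0.003, 0.04, 0.20, 0.70])  # hypothetical per-feature p-values
q, n = 0.05, p_vals.shape[0]
i = np.arange(n) + 1
q_threshold = q * i / n                 # [0.0125, 0.025, 0.0375, 0.05]
below = np.sort(p_vals) < q_threshold   # [True, False, False, False]
print(bool(below.any()))                # True: drift flagged by the first feature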
dataeval/{_internal/detectors → detectors}/drift/cvm.py RENAMED
@@ -8,15 +8,16 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = ["DriftCVM"]
+
 from typing import Callable, Literal
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 from scipy.stats import cramervonmises_2samp
 
-from dataeval._internal.interop import to_numpy
-
-from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
+from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
+from dataeval.interop import to_numpy
 
 
 class DriftCVM(BaseDriftUnivariate):
dataeval/{_internal/detectors → detectors}/drift/ks.py RENAMED
@@ -8,15 +8,16 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = ["DriftKS"]
+
 from typing import Callable, Literal
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 from scipy.stats import ks_2samp
 
-from dataeval._internal.interop import to_numpy
-
-from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
+from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
+from dataeval.interop import to_numpy
 
 
 class DriftKS(BaseDriftUnivariate):
dataeval/{_internal/detectors → detectors}/drift/mmd.py RENAMED
@@ -8,17 +8,18 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = ["DriftMMD", "DriftMMDOutput"]
+
 from dataclasses import dataclass
 from typing import Callable
 
 import torch
 from numpy.typing import ArrayLike
 
-from dataeval._internal.interop import as_numpy
-from dataeval._internal.output import set_metadata
-
-from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
-from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
+from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
+from dataeval.detectors.drift.torch import _GaussianRBF, _mmd2_from_kernel_matrix, get_device
+from dataeval.interop import as_numpy
+from dataeval.output import set_metadata
 
 
 @dataclass(frozen=True)
@@ -70,10 +71,8 @@ class DriftMMD(BaseDrift):
     preprocess_fn : Callable | None, default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
-    kernel : Callable, default GaussianRBF
-        Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
     sigma : ArrayLike | None, default None
-        Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple
+        Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
         bandwidth values as an array. The kernel evaluation is then averaged over
         those bandwidths.
     configure_kernel_from_x_ref : bool, default True
@@ -91,41 +90,40 @@ class DriftMMD(BaseDrift):
         p_val: float = 0.05,
         x_ref_preprocessed: bool = False,
         update_x_ref: UpdateStrategy | None = None,
-        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
-        kernel: Callable = GaussianRBF,
+        preprocess_fn: Callable[..., ArrayLike] | None = None,
         sigma: ArrayLike | None = None,
         configure_kernel_from_x_ref: bool = True,
         n_permutations: int = 100,
-        device: str | None = None,
+        device: str | torch.device | None = None,
     ) -> None:
         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
 
-        self.infer_sigma = configure_kernel_from_x_ref
+        self._infer_sigma = configure_kernel_from_x_ref
         if configure_kernel_from_x_ref and sigma is not None:
-            self.infer_sigma = False
+            self._infer_sigma = False
 
         self.n_permutations = n_permutations  # nb of iterations through permutation test
 
         # set device
-        self.device = get_device(device)
+        self.device: torch.device = get_device(device)
 
         # initialize kernel
         sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
-        self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel
+        self._kernel = _GaussianRBF(sigma_tensor).to(self.device)
 
         # compute kernel matrix for the reference data
-        if self.infer_sigma or isinstance(sigma_tensor, torch.Tensor):
+        if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
             x = torch.from_numpy(self.x_ref).to(self.device)
-            self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
-            self.infer_sigma = False
+            self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
+            self._infer_sigma = False
         else:
-            self.k_xx, self.infer_sigma = None, True
+            self._k_xx, self._infer_sigma = None, True
 
     def _kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """Compute and return full kernel matrix between arrays x and y."""
-        k_xy = self.kernel(x, y, self.infer_sigma)
-        k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
-        k_yy = self.kernel(y, y)
+        k_xy = self._kernel(x, y, self._infer_sigma)
+        k_xx = self._k_xx if self._k_xx is not None and self.update_x_ref is None else self._kernel(x, x)
+        k_yy = self._kernel(y, y)
         kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
         return kernel_mat
 
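With the kernel parameter removed, DriftMMD always constructs the internal _GaussianRBF; callers steer it only through sigma and configure_kernel_from_x_ref. A construction sketch under the new signature (data and bandwidths are illustrative; requires torch):

import numpy as np

from dataeval.detectors.drift import DriftMMD

x_ref = np.random.randn(500, 16).astype(np.float32)  # illustrative reference set
detector = DriftMMD(
    x_ref,
    sigma=np.array([0.5, 1.0, 2.0]),    # kernel is averaged over these bandwidths
    configure_kernel_from_x_ref=False,  # sigma is given, so skip inferring it
    n_permutations=100,
    device="cpu",                       # plain strings now accepted for device
)
result = detector.predict(np.random.randn(200, 16).astype(np.float32))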
@@ -152,9 +150,9 @@ class DriftMMD(BaseDrift):
         n = x.shape[0]
         kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
-        mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
+        mmd2 = _mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
         mmd2_permuted = torch.Tensor(
-            [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
+            [_mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
         )
         mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
@@ -163,7 +161,7 @@ class DriftMMD(BaseDrift):
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
 
-    @set_metadata("dataeval.detectors")
+    @set_metadata()
     @preprocess_x
     @update_x_ref
     def predict(self, x: ArrayLike) -> DriftMMDOutput:
dataeval/{_internal/detectors → detectors}/drift/torch.py RENAMED
@@ -8,8 +8,10 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = []
+
 from functools import partial
-from typing import Callable
+from typing import Any, Callable
 
 import numpy as np
 import torch
@@ -42,7 +44,7 @@ def get_device(device: str | torch.device | None = None) -> torch.device:
     return torch_device
 
 
-def mmd2_from_kernel_matrix(
+def _mmd2_from_kernel_matrix(
     kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
 ) -> torch.Tensor:
     """
@@ -78,13 +80,13 @@
 
 
 def predict_batch(
-    x: NDArray | torch.Tensor,
+    x: NDArray[Any] | torch.Tensor,
     model: Callable | nn.Module | nn.Sequential,
     device: torch.device | None = None,
     batch_size: int = int(1e10),
     preprocess_fn: Callable | None = None,
     dtype: type[np.generic] | torch.dtype = np.float32,
-) -> NDArray | torch.Tensor | tuple:
+) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
     """
     Make batch predictions on a model.
 
@@ -154,13 +156,13 @@
 
 
 def preprocess_drift(
-    x: NDArray,
+    x: NDArray[Any],
     model: nn.Module,
-    device: torch.device | None = None,
+    device: str | torch.device | None = None,
     preprocess_batch_fn: Callable | None = None,
     batch_size: int = int(1e10),
     dtype: type[np.generic] | torch.dtype = np.float32,
-) -> NDArray | torch.Tensor | tuple:
+) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
     """
     Prediction function used for preprocessing step of drift detector.
 
@@ -189,7 +191,7 @@
     return predict_batch(
         x,
         model,
-        device=device,
+        device=get_device(device),
         batch_size=batch_size,
         preprocess_fn=preprocess_batch_fn,
         dtype=dtype,
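Because preprocess_drift now resolves the device itself via get_device, a plain device string flows through from user code unchanged. A sketch wiring it in as a detector's preprocess_fn via functools.partial (the encoder is a hypothetical stand-in):

from functools import partial

import numpy as np
import torch.nn as nn

from dataeval.detectors.drift import DriftKS, preprocess_drift

encoder = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 8))  # hypothetical projector
preprocess_fn = partial(preprocess_drift, model=encoder, device="cpu", batch_size=64)

x_ref = np.random.randn(256, 32, 32).astype(np.float32)  # illustrative images
detector = DriftKS(x_ref, preprocess_fn=preprocess_fn)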
@@ -197,7 +199,7 @@
 
 
 @torch.jit.script
-def squared_pairwise_distance(
+def _squared_pairwise_distance(
     x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
 ) -> torch.Tensor:  # pragma: no cover - torch.jit.script code is compiled and copied
     """
@@ -249,7 +251,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
     return sigma
 
 
-class GaussianRBF(nn.Module):
+class _GaussianRBF(nn.Module):
     """
     Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
 
@@ -303,7 +305,7 @@ class GaussianRBF(nn.Module):
         infer_sigma: bool = False,
     ) -> torch.Tensor:
         x, y = torch.as_tensor(x), torch.as_tensor(y)
-        dist = squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]
+        dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]
 
         if infer_sigma or self.init_required:
             if self.trainable and infer_sigma:
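For reference, _GaussianRBF evaluates k(x, y) = exp(-||x - y||^2 / (2 * sigma^2)) over flattened inputs. The same formula restated in numpy for a single pair of vectors (illustrative only):

import numpy as np

def rbf(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    # Gaussian RBF kernel value for one pair of flattened vectors
    sq_dist = np.sum((x - y) ** 2)
    return float(np.exp(-sq_dist / (2.0 * sigma**2)))

print(rbf(np.ones(4), np.zeros(4), sigma=1.0))  # exp(-2.0) ~ 0.135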
dataeval/{_internal/detectors → detectors}/drift/uncertainty.py RENAMED
@@ -8,6 +8,8 @@ Licensed under Apache Software License (Apache 2.0)
 
 from __future__ import annotations
 
+__all__ = ["DriftUncertainty"]
+
 from functools import partial
 from typing import Callable, Literal
 
@@ -16,16 +18,16 @@ from numpy.typing import ArrayLike, NDArray
 from scipy.special import softmax
 from scipy.stats import entropy
 
-from .base import DriftOutput, UpdateStrategy
-from .ks import DriftKS
-from .torch import get_device, preprocess_drift
+from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
+from dataeval.detectors.drift.ks import DriftKS
+from dataeval.detectors.drift.torch import get_device, preprocess_drift
 
 
 def classifier_uncertainty(
-    x: NDArray,
+    x: NDArray[np.float64],
     model_fn: Callable,
     preds_type: Literal["probs", "logits"] = "probs",
-) -> NDArray:
+) -> NDArray[np.float64]:
     """
     Evaluate model_fn on x and transform predictions to prediction uncertainties.
 
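classifier_uncertainty reduces each row of class predictions to a single entropy score, which DriftUncertainty then hands to a KS test. The core transform restated with scipy (illustrative; per the imports above, the real function also softmaxes the logits path):

import numpy as np
from scipy.stats import entropy

probs = np.array([[0.98, 0.01, 0.01],   # confident prediction -> low entropy
                  [0.34, 0.33, 0.33]])  # uncertain prediction -> high entropy
print(entropy(probs, axis=-1))          # ~[0.112, 1.098]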
dataeval/detectors/drift/updates.py ADDED
@@ -0,0 +1,61 @@
+"""
+Update strategies inform how the :term:`drift<Drift>` detector classes update the reference data when monitoring
+for drift.
+"""
+
+from __future__ import annotations
+
+__all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
+
+from typing import Any
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.detectors.drift.base import UpdateStrategy
+
+
+class LastSeenUpdate(UpdateStrategy):
+    """
+    Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
+    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
+        x_updated = np.concatenate([x_ref, x], axis=0)
+        return x_updated[-self.n :]
+
+
+class ReservoirSamplingUpdate(UpdateStrategy):
+    """
+    Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
+    def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
+        if x.shape[0] + count <= self.n:
+            return np.concatenate([x_ref, x], axis=0)
+
+        n_ref = x_ref.shape[0]
+        output_size = min(self.n, n_ref + x.shape[0])
+        shape = (output_size,) + x.shape[1:]
+        x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
+        x_reservoir[:n_ref] = x_ref
+        for item in x:
+            count += 1
+            if n_ref < self.n:
+                x_reservoir[n_ref, :] = item
+                n_ref += 1
+            else:
+                r = np.random.randint(0, count)
+                if r < self.n:
+                    x_reservoir[r, :] = item
+        return x_reservoir
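The strategies relocated here are drop-in values for the detectors' update_x_ref parameter; ReservoirSamplingUpdate keeps a bounded, uniformly sampled reference as batches stream through predict. A usage sketch (data illustrative):

import numpy as np

from dataeval.detectors.drift import DriftCVM
from dataeval.detectors.drift.updates import ReservoirSamplingUpdate

x_ref = np.random.randn(1000, 8)  # illustrative reference features
detector = DriftCVM(x_ref, update_x_ref=ReservoirSamplingUpdate(n=1000))

for _ in range(5):  # each predict() also refreshes the reference via the strategy
    batch = np.random.randn(200, 8)
    detector.predict(batch)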
dataeval/detectors/linters/__init__.py CHANGED
@@ -2,9 +2,9 @@
 Linters help identify potential issues in training and test data and are an important aspect of data cleaning.
 """
 
-from dataeval._internal.detectors.clusterer import Clusterer, ClustererOutput
-from dataeval._internal.detectors.duplicates import Duplicates, DuplicatesOutput
-from dataeval._internal.detectors.outliers import Outliers, OutliersOutput
+from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
+from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
+from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
 
 __all__ = [
     "Clusterer",