dataeval-0.83.0-py3-none-any.whl → dataeval-0.84.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +3 -3
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +55 -203
  5. dataeval/detectors/drift/_cvm.py +19 -30
  6. dataeval/detectors/drift/_ks.py +18 -30
  7. dataeval/detectors/drift/_mmd.py +189 -53
  8. dataeval/detectors/drift/_uncertainty.py +52 -56
  9. dataeval/detectors/drift/updates.py +13 -12
  10. dataeval/detectors/linters/duplicates.py +5 -3
  11. dataeval/detectors/linters/outliers.py +2 -2
  12. dataeval/detectors/ood/ae.py +1 -1
  13. dataeval/metrics/bias/__init__.py +11 -1
  14. dataeval/metrics/bias/_completeness.py +130 -0
  15. dataeval/metrics/stats/_base.py +28 -32
  16. dataeval/metrics/stats/_dimensionstats.py +2 -2
  17. dataeval/metrics/stats/_hashstats.py +2 -2
  18. dataeval/metrics/stats/_imagestats.py +4 -4
  19. dataeval/metrics/stats/_labelstats.py +4 -45
  20. dataeval/metrics/stats/_pixelstats.py +2 -2
  21. dataeval/metrics/stats/_visualstats.py +2 -2
  22. dataeval/outputs/__init__.py +2 -1
  23. dataeval/outputs/_bias.py +31 -22
  24. dataeval/outputs/_stats.py +2 -3
  25. dataeval/typing.py +25 -22
  26. dataeval/utils/_array.py +43 -7
  27. dataeval/utils/data/_dataset.py +8 -4
  28. dataeval/utils/data/_embeddings.py +141 -24
  29. dataeval/utils/data/_images.py +38 -15
  30. dataeval/utils/data/_metadata.py +5 -4
  31. dataeval/utils/data/_selection.py +3 -15
  32. dataeval/utils/data/_split.py +76 -129
  33. dataeval/utils/data/datasets/_base.py +7 -4
  34. dataeval/utils/data/datasets/_cifar10.py +9 -9
  35. dataeval/utils/data/datasets/_milco.py +42 -14
  36. dataeval/utils/data/datasets/_mnist.py +9 -5
  37. dataeval/utils/data/datasets/_ships.py +8 -4
  38. dataeval/utils/data/datasets/_voc.py +40 -19
  39. dataeval/utils/data/selections/__init__.py +2 -0
  40. dataeval/utils/data/selections/_classbalance.py +38 -0
  41. dataeval/utils/data/selections/_classfilter.py +14 -29
  42. dataeval/utils/data/selections/_prioritize.py +1 -1
  43. dataeval/utils/data/selections/_shuffle.py +2 -2
  44. dataeval/utils/metadata.py +1 -1
  45. dataeval/utils/torch/_internal.py +12 -35
  46. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
  47. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +49 -48
  48. dataeval/detectors/drift/_torch.py +0 -222
  49. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
  50. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.83.0"
11
+ __version__ = "0.84.1"
12
12
 
13
13
  import logging
14
14
 
dataeval/config.py CHANGED
@@ -45,13 +45,13 @@ def _todevice(device: DeviceLike) -> torch.device:
45
45
  return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
46
46
 
47
47
 
48
- def set_device(device: DeviceLike) -> None:
48
+ def set_device(device: DeviceLike | None) -> None:
49
49
  """
50
50
  Sets the default device to use when executing against a PyTorch backend.
51
51
 
52
52
  Parameters
53
53
  ----------
54
- device : DeviceLike
54
+ device : DeviceLike or None
55
55
  The default device to use. See documentation for more information.
56
56
 
57
57
  See Also
@@ -59,7 +59,7 @@ def set_device(device: DeviceLike) -> None:
59
59
  `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
60
60
  """
61
61
  global _device
62
- _device = _todevice(device)
62
+ _device = None if device is None else _todevice(device)
63
63
 
64
64
 
65
65
  def get_device(override: DeviceLike | None = None) -> torch.device:
dataeval/detectors/drift/__init__.py CHANGED
@@ -9,14 +9,14 @@ __all__ = [
9
9
  "DriftMMDOutput",
10
10
  "DriftOutput",
11
11
  "DriftUncertainty",
12
- "preprocess_drift",
12
+ "UpdateStrategy",
13
13
  "updates",
14
14
  ]
15
15
 
16
16
  from dataeval.detectors.drift import updates
17
+ from dataeval.detectors.drift._base import UpdateStrategy
17
18
  from dataeval.detectors.drift._cvm import DriftCVM
18
19
  from dataeval.detectors.drift._ks import DriftKS
19
20
  from dataeval.detectors.drift._mmd import DriftMMD
20
- from dataeval.detectors.drift._torch import preprocess_drift
21
21
  from dataeval.detectors.drift._uncertainty import DriftUncertainty
22
22
  from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
dataeval/detectors/drift/_base.py CHANGED
@@ -13,15 +13,16 @@ __all__ = []
13
13
  import math
14
14
  from abc import abstractmethod
15
15
  from functools import wraps
16
- from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
16
+ from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
17
17
 
18
18
  import numpy as np
19
19
  from numpy.typing import NDArray
20
20
 
21
21
  from dataeval.outputs import DriftOutput
22
22
  from dataeval.outputs._base import set_metadata
23
- from dataeval.typing import Array, ArrayLike
24
- from dataeval.utils._array import as_numpy, to_numpy
23
+ from dataeval.typing import Array
24
+ from dataeval.utils._array import as_numpy, flatten
25
+ from dataeval.utils.data import Embeddings
25
26
 
26
27
  R = TypeVar("R")
27
28
 
@@ -32,220 +33,88 @@ class UpdateStrategy(Protocol):
32
33
  Protocol for reference dataset update strategy for drift detectors
33
34
  """
34
35
 
35
- def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...
36
+ def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int) -> NDArray[np.float32]: ...
36
37
 
37
38
 
38
- def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
39
+ def update_strategy(fn: Callable[..., R]) -> Callable[..., R]:
39
40
  """Decorator to update x_ref with x using selected update methodology"""
40
41
 
41
42
  @wraps(fn)
42
- def _(self, x, *args, **kwargs) -> R:
43
- output = fn(self, x, *args, **kwargs)
43
+ def _(self: BaseDrift, data: Embeddings | Array, *args, **kwargs) -> R:
44
+ output = fn(self, data, *args, **kwargs)
44
45
 
45
46
  # update reference dataset
46
- if self.update_x_ref is not None:
47
- self._x_ref = self.update_x_ref(self.x_ref, x, self.n)
47
+ if self.update_strategy is not None:
48
+ self._x_ref = self.update_strategy(self.x_ref, self._encode(data), self.n)
49
+ self.n += len(data)
48
50
 
49
- # used for reservoir sampling
50
- self.n += len(x)
51
- return output
52
-
53
- return _
54
-
55
-
56
- def preprocess_x(fn: Callable[..., R]) -> Callable[..., R]:
57
- """Decorator to run preprocess_fn on x before calling wrapped function"""
58
-
59
- @wraps(fn)
60
- def _(self, x, *args, **kwargs) -> R:
61
- if self._x_refcount == 0:
62
- self._x = self._preprocess(x)
63
- self._x_refcount += 1
64
- output = fn(self, self._x, *args, **kwargs)
65
- self._x_refcount -= 1
66
- if self._x_refcount == 0:
67
- del self._x
68
51
  return output
69
52
 
70
53
  return _
71
54
 
72
55
 
73
56
  class BaseDrift:
74
- """
75
- A generic :term:`drift<Drift>` detection component for preprocessing data and applying statistical correction.
76
-
77
- This class handles common tasks related to drift detection, such as preprocessing
78
- the reference data (`x_ref`), performing statistical correction (e.g., Bonferroni, FDR),
79
- and updating the reference data if needed.
80
-
81
- Parameters
82
- ----------
83
- x_ref : ArrayLike
84
- The reference dataset used for drift detection. This is the baseline data against
85
- which new data points will be compared.
86
- p_val : float, optional
87
- The significance level for detecting drift, by default 0.05.
88
- x_ref_preprocessed : bool, optional
89
- Flag indicating whether the reference data has already been preprocessed, by default False.
90
- update_x_ref : UpdateStrategy, optional
91
- A strategy object specifying how the reference data should be updated when drift is detected,
92
- by default None.
93
- preprocess_fn : Callable[[ArrayLike], ArrayLike], optional
94
- A function to preprocess the data before drift detection, by default None.
95
- correction : {'bonferroni', 'fdr'}, optional
96
- Statistical correction method applied to p-values, by default "bonferroni".
97
-
98
- Attributes
99
- ----------
100
- _x_ref : ArrayLike
101
- The reference dataset that is either raw or preprocessed.
102
- p_val : float
103
- The significance level for drift detection.
104
- update_x_ref : UpdateStrategy or None
105
- The strategy for updating the reference data if applicable.
106
- preprocess_fn : Callable or None
107
- Function used for preprocessing input data before drift detection.
108
- correction : str
109
- Statistical correction method applied to p-values.
110
- n : int
111
- The number of samples in the reference dataset (`x_ref`).
112
- x_ref_preprocessed : bool
113
- A flag that indicates whether the reference dataset has been preprocessed.
114
- _x_refcount : int
115
- Counter for how many times the reference data has been accessed after preprocessing.
116
-
117
- Methods
118
- -------
119
- x_ref:
120
- Property that returns the reference dataset, and applies preprocessing if not already done.
121
- _preprocess(x):
122
- Preprocesses the given data using the specified `preprocess_fn` if provided.
123
- """
57
+ p_val: float
58
+ update_strategy: UpdateStrategy | None
59
+ correction: Literal["bonferroni", "fdr"]
60
+ n: int
124
61
 
125
62
  def __init__(
126
63
  self,
127
- x_ref: ArrayLike,
64
+ data: Embeddings | Array,
128
65
  p_val: float = 0.05,
129
- x_ref_preprocessed: bool = False,
130
- update_x_ref: UpdateStrategy | None = None,
131
- preprocess_fn: Callable[..., ArrayLike] | None = None,
66
+ update_strategy: UpdateStrategy | None = None,
132
67
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
133
68
  ) -> None:
134
69
  # Type checking
135
- if preprocess_fn is not None and not isinstance(preprocess_fn, Callable):
136
- raise ValueError("`preprocess_fn` is not a valid Callable.")
137
- if update_x_ref is not None and not isinstance(update_x_ref, UpdateStrategy):
138
- raise ValueError("`update_x_ref` is not a valid ReferenceUpdate class.")
70
+ if update_strategy is not None and not isinstance(update_strategy, UpdateStrategy):
71
+ raise ValueError("`update_strategy` is not a valid UpdateStrategy class.")
139
72
  if correction not in ["bonferroni", "fdr"]:
140
73
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")
141
74
 
142
- self._x_ref = x_ref
143
- self.x_ref_preprocessed: bool = x_ref_preprocessed
144
-
145
- # Other attributes
75
+ self._data = data
146
76
  self.p_val = p_val
147
- self.update_x_ref = update_x_ref
148
- self.preprocess_fn = preprocess_fn
77
+ self.update_strategy = update_strategy
149
78
  self.correction = correction
150
- self.n: int = len(x_ref)
79
+ self.n = len(data)
151
80
 
152
- # Ref counter for preprocessed x
153
- self._x_refcount = 0
81
+ self._x_ref: NDArray[np.float32] | None = None
154
82
 
155
83
  @property
156
- def x_ref(self) -> ArrayLike:
84
+ def x_ref(self) -> NDArray[np.float32]:
157
85
  """
158
- Retrieve the reference data, applying preprocessing if not already done.
86
+ Retrieve the reference data of the drift detector.
159
87
 
160
88
  Returns
161
89
  -------
162
- ArrayLike
163
- The reference dataset (`x_ref`), preprocessed if needed.
90
+ NDArray[np.float32]
91
+ The reference data as a 32-bit floating point numpy array.
164
92
  """
165
- if not self.x_ref_preprocessed:
166
- self.x_ref_preprocessed = True
167
- if self.preprocess_fn is not None:
168
- self._x_ref = self.preprocess_fn(self._x_ref)
169
-
93
+ if self._x_ref is None:
94
+ self._x_ref = self._encode(self._data)
170
95
  return self._x_ref
171
96
 
172
- def _preprocess(self, x: ArrayLike) -> ArrayLike:
173
- """
174
- Preprocess the given data before computing the :term:`drift<Drift>` scores.
175
-
176
- Parameters
177
- ----------
178
- x : ArrayLike
179
- The input data to preprocess.
180
-
181
- Returns
182
- -------
183
- ArrayLike
184
- The preprocessed input data.
185
- """
186
- if self.preprocess_fn is not None:
187
- x = self.preprocess_fn(x)
188
- return x
97
+ def _encode(self, data: Embeddings | Array) -> NDArray[np.float32]:
98
+ array = (
99
+ data.to_numpy().astype(np.float32)
100
+ if isinstance(data, Embeddings)
101
+ else self._data.new(data).to_numpy().astype(np.float32)
102
+ if isinstance(self._data, Embeddings)
103
+ else as_numpy(data).astype(np.float32)
104
+ )
105
+ return flatten(array)
189
106
 
190
107
 
191
108
  class BaseDriftUnivariate(BaseDrift):
192
- """
193
- Base class for :term:`drift<Drift>` detection methods using univariate statistical tests.
194
-
195
- This class inherits from `BaseDrift` and serves as a generic component for detecting
196
- distribution drift in univariate features. If the number of features `n_features` is greater
197
- than 1, a multivariate correction method (e.g., Bonferroni or FDR) is applied to control
198
- the :term:`false positive rate<False Positive Rate (FP)>`, ensuring it does not exceed the specified
199
- :term:`p-value<P-Value>`.
200
-
201
- Parameters
202
- ----------
203
- x_ref : ArrayLike
204
- Reference data used as the baseline to compare against when detecting drift.
205
- p_val : float, default 0.05
206
- Significance level used for detecting drift.
207
- x_ref_preprocessed : bool, default False
208
- Indicates whether the reference data has been preprocessed.
209
- update_x_ref : UpdateStrategy | None, default None
210
- Strategy for updating the reference data when drift is detected.
211
- preprocess_fn : Callable[ArrayLike] | None, default None
212
- Function used to preprocess input data before detecting drift.
213
- correction : 'bonferroni' | 'fdr', default 'bonferroni'
214
- Multivariate correction method applied to p-values.
215
- n_features : int | None, default None
216
- Number of features used in the univariate drift tests. If not provided, it will
217
- be inferred from the data.
218
-
219
- Attributes
220
- ----------
221
- p_val : float
222
- The significance level for drift detection.
223
- correction : str
224
- The method for controlling the :term:`False Discovery Rate (FDR)` or applying a Bonferroni correction.
225
- update_x_ref : UpdateStrategy | None
226
- Strategy for updating the reference data if applicable.
227
- preprocess_fn : Callable | None
228
- Function used for preprocessing input data before drift detection.
229
- """
230
-
231
109
  def __init__(
232
110
  self,
233
- x_ref: ArrayLike,
111
+ data: Embeddings | Array,
234
112
  p_val: float = 0.05,
235
- x_ref_preprocessed: bool = False,
236
- update_x_ref: UpdateStrategy | None = None,
237
- preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
113
+ update_strategy: UpdateStrategy | None = None,
238
114
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
239
115
  n_features: int | None = None,
240
116
  ) -> None:
241
- super().__init__(
242
- x_ref,
243
- p_val,
244
- x_ref_preprocessed,
245
- update_x_ref,
246
- preprocess_fn,
247
- correction,
248
- )
117
+ super().__init__(data, p_val, update_strategy, correction)
249
118
 
250
119
  self._n_features = n_features
251
120
 
@@ -255,8 +124,7 @@ class BaseDriftUnivariate(BaseDrift):
255
124
  Get the number of features in the reference data.
256
125
 
257
126
  If the number of features is not provided during initialization, it will be inferred
258
- from the reference data (``x_ref``). If a preprocessing function is provided, the number
259
- of features will be inferred after applying the preprocessing function.
127
+ from the reference data (``x_ref``).
260
128
 
261
129
  Returns
262
130
  -------
@@ -264,48 +132,36 @@ class BaseDriftUnivariate(BaseDrift):
264
132
  Number of features in the reference data.
265
133
  """
266
134
  # lazy process n_features as needed
267
- if not isinstance(self._n_features, int):
268
- # compute number of features for the univariate tests
269
- x_ref = (
270
- self.x_ref
271
- if self.preprocess_fn is None or self.x_ref_preprocessed
272
- else self.preprocess_fn(self._x_ref[0:1])
273
- )
274
- # infer features from preprocessed reference data
275
- shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
276
- self._n_features = int(math.prod(shape[1:])) # Multiplies all channel sizes after first
135
+ if self._n_features is None:
136
+ self._n_features = int(math.prod(self._data[0].shape))
277
137
 
278
138
  return self._n_features
279
139
 
280
- @preprocess_x
281
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
140
+ def score(self, data: Embeddings | Array) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
282
141
  """
283
142
  Calculates p-values and test statistics per feature.
284
143
 
285
144
  Parameters
286
145
  ----------
287
- x : ArrayLike
288
- Batch of instances
146
+ data : Embeddings or Array
147
+ Batch of instances to score.
289
148
 
290
149
  Returns
291
150
  -------
292
151
  tuple[NDArray, NDArray]
293
152
  Feature level p-values and test statistics
294
153
  """
295
- x_np = to_numpy(x)
296
- x_np = x_np.reshape(x_np.shape[0], -1)
297
- x_ref_np = as_numpy(self.x_ref)
298
- x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
154
+ x_np = self._encode(data)
299
155
  p_val = np.zeros(self.n_features, dtype=np.float32)
300
156
  dist = np.zeros_like(p_val)
301
157
  for f in range(self.n_features):
302
- dist[f], p_val[f] = self._score_fn(x_ref_np[:, f], x_np[:, f])
158
+ dist[f], p_val[f] = self._score_fn(self.x_ref[:, f], x_np[:, f])
303
159
  return p_val, dist
304
160
 
305
161
  @abstractmethod
306
162
  def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
307
163
 
308
- def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
164
+ def _apply_correction(self, p_vals: NDArray[np.float32]) -> tuple[bool, float]:
309
165
  """
310
166
  Apply the specified correction method (Bonferroni or FDR) to the p-values.
311
167
 
@@ -343,20 +199,16 @@ class BaseDriftUnivariate(BaseDrift):
343
199
  raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
344
200
 
345
201
  @set_metadata
346
- @preprocess_x
347
- @update_x_ref
348
- def predict(
349
- self,
350
- x: ArrayLike,
351
- ) -> DriftOutput:
202
+ @update_strategy
203
+ def predict(self, data: Embeddings | Array) -> DriftOutput:
352
204
  """
353
205
  Predict whether a batch of data has drifted from the reference data and update
354
206
  reference data using specified update strategy.
355
207
 
356
208
  Parameters
357
209
  ----------
358
- x : ArrayLike
359
- Batch of instances.
210
+ data : Embeddings or Array
211
+ Batch of instances to predict drift on.
360
212
 
361
213
  Returns
362
214
  -------
@@ -365,7 +217,7 @@ class BaseDriftUnivariate(BaseDrift):
365
217
  p-values, threshold after multivariate correction if needed and test :term:`statistics<Statistics>`.
366
218
  """
367
219
  # compute drift scores
368
- p_vals, dist = self.score(x)
220
+ p_vals, dist = self.score(data)
369
221
 
370
222
  feature_drift = (p_vals < self.p_val).astype(np.bool_)
371
223
  drift_pred, threshold = self._apply_correction(p_vals)
dataeval/detectors/drift/_cvm.py CHANGED
@@ -10,14 +10,15 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from typing import Callable, Literal
13
+ from typing import Literal
14
14
 
15
15
  import numpy as np
16
16
  from numpy.typing import NDArray
17
17
  from scipy.stats import cramervonmises_2samp
18
18
 
19
19
  from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
- from dataeval.typing import ArrayLike
20
+ from dataeval.typing import Array
21
+ from dataeval.utils.data._embeddings import Embeddings
21
22
 
22
23
 
23
24
  class DriftCVM(BaseDriftUnivariate):
@@ -31,40 +32,32 @@ class DriftCVM(BaseDriftUnivariate):
31
32
 
32
33
  Parameters
33
34
  ----------
34
- x_ref : ArrayLike
35
+ data : Embeddings or Array
35
36
  Data used as reference distribution.
36
- p_val : float | None, default 0.05
37
+ p_val : float or None, default 0.05
37
38
  :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
38
39
  If the FDR correction method is used, this corresponds to the acceptable
39
40
  q-value.
40
- x_ref_preprocessed : bool, default False
41
- Whether the given reference data ``x_ref`` has been preprocessed yet.
42
- If ``True``, only the test data ``x`` will be preprocessed at prediction time.
43
- If ``False``, the reference data will also be preprocessed.
44
- update_x_ref : UpdateStrategy | None, default None
41
+ update_strategy : UpdateStrategy or None, default None
45
42
  Reference data can optionally be updated using an UpdateStrategy class. Update
46
43
  using the last n instances seen by the detector with LastSeenUpdateStrategy
47
44
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
48
- preprocess_fn : Callable | None, default None
49
- Function to preprocess the data before computing the data drift metrics.
50
- Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
51
- correction : "bonferroni" | "fdr", default "bonferroni"
45
+ correction : "bonferroni" or "fdr", default "bonferroni"
52
46
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
53
47
  Discovery Rate).
54
- n_features : int | None, default None
55
- Number of features used in the statistical test. No need to pass it if no
56
- preprocessing takes place. In case of a preprocessing step, this can also
57
- be inferred automatically but could be more expensive to compute.
48
+ n_features : int or None, default None
49
+ Number of features used in the univariate drift tests. If not provided, it will
50
+ be inferred from the data.
51
+
58
52
 
59
53
  Example
60
54
  -------
61
- >>> from functools import partial
62
- >>> from dataeval.detectors.drift import preprocess_drift
55
+ >>> from dataeval.utils.data import Embeddings
63
56
 
64
- Use a preprocess function to encode images before testing for drift
57
+ Use Embeddings to encode images before testing for drift
65
58
 
66
- >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
67
- >>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
59
+ >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
60
+ >>> drift = DriftCVM(train_emb)
68
61
 
69
62
  Test incoming images for drift
70
63
 
@@ -74,20 +67,16 @@ class DriftCVM(BaseDriftUnivariate):
74
67
 
75
68
  def __init__(
76
69
  self,
77
- x_ref: ArrayLike,
70
+ data: Embeddings | Array,
78
71
  p_val: float = 0.05,
79
- x_ref_preprocessed: bool = False,
80
- update_x_ref: UpdateStrategy | None = None,
81
- preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
72
+ update_strategy: UpdateStrategy | None = None,
82
73
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
83
74
  n_features: int | None = None,
84
75
  ) -> None:
85
76
  super().__init__(
86
- x_ref=x_ref,
77
+ data=data,
87
78
  p_val=p_val,
88
- x_ref_preprocessed=x_ref_preprocessed,
89
- update_x_ref=update_x_ref,
90
- preprocess_fn=preprocess_fn,
79
+ update_strategy=update_strategy,
91
80
  correction=correction,
92
81
  n_features=n_features,
93
82
  )
dataeval/detectors/drift/_ks.py CHANGED
@@ -10,14 +10,15 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from typing import Callable, Literal
13
+ from typing import Literal
14
14
 
15
15
  import numpy as np
16
16
  from numpy.typing import NDArray
17
17
  from scipy.stats import ks_2samp
18
18
 
19
19
  from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
- from dataeval.typing import ArrayLike
20
+ from dataeval.typing import Array
21
+ from dataeval.utils.data._embeddings import Embeddings
21
22
 
22
23
 
23
24
  class DriftKS(BaseDriftUnivariate):
@@ -31,43 +32,34 @@ class DriftKS(BaseDriftUnivariate):
31
32
 
32
33
  Parameters
33
34
  ----------
34
- x_ref : ArrayLike
35
+ data : Embeddings or Array
35
36
  Data used as reference distribution.
36
- p_val : float | None, default 0.05
37
+ p_val : float or None, default 0.05
37
38
  :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
38
39
  If the FDR correction method is used, this corresponds to the acceptable
39
40
  q-value.
40
- x_ref_preprocessed : bool, default False
41
- Whether the given reference data ``x_ref`` has been preprocessed yet.
42
- If ``True``, only the test data ``x`` will be preprocessed at prediction time.
43
- If ``False``, the reference data will also be preprocessed.
44
- update_x_ref : UpdateStrategy | None, default None
41
+ update_strategy : UpdateStrategy or None, default None
45
42
  Reference data can optionally be updated using an UpdateStrategy class. Update
46
43
  using the last n instances seen by the detector with LastSeenUpdateStrategy
47
44
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
48
- preprocess_fn : Callable | None, default None
49
- Function to preprocess the data before computing the data :term:`drift<Drift>` metrics.
50
- Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
51
- correction : "bonferroni" | "fdr", default "bonferroni"
45
+ correction : "bonferroni" or "fdr", default "bonferroni"
52
46
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
53
47
  Discovery Rate).
54
- alternative : "two-sided" | "less" | "greater", default "two-sided"
48
+ alternative : "two-sided", "less" or "greater", default "two-sided"
55
49
  Defines the alternative hypothesis. Options are 'two-sided', 'less' or
56
50
  'greater'.
57
51
  n_features : int | None, default None
58
- Number of features used in the statistical test. No need to pass it if no
59
- preprocessing takes place. In case of a preprocessing step, this can also
60
- be inferred automatically but could be more expensive to compute.
52
+ Number of features used in the univariate drift tests. If not provided, it will
53
+ be inferred from the data.
61
54
 
62
55
  Example
63
56
  -------
64
- >>> from functools import partial
65
- >>> from dataeval.detectors.drift import preprocess_drift
57
+ >>> from dataeval.utils.data import Embeddings
66
58
 
67
- Use a preprocess function to encode images before testing for drift
59
+ Use Embeddings to encode images before testing for drift
68
60
 
69
- >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
70
- >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
61
+ >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
62
+ >>> drift = DriftKS(train_emb)
71
63
 
72
64
  Test incoming images for drift
73
65
 
@@ -77,21 +69,17 @@ class DriftKS(BaseDriftUnivariate):
77
69
 
78
70
  def __init__(
79
71
  self,
80
- x_ref: ArrayLike,
72
+ data: Embeddings | Array,
81
73
  p_val: float = 0.05,
82
- x_ref_preprocessed: bool = False,
83
- update_x_ref: UpdateStrategy | None = None,
84
- preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
74
+ update_strategy: UpdateStrategy | None = None,
85
75
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
86
76
  alternative: Literal["two-sided", "less", "greater"] = "two-sided",
87
77
  n_features: int | None = None,
88
78
  ) -> None:
89
79
  super().__init__(
90
- x_ref=x_ref,
80
+ data=data,
91
81
  p_val=p_val,
92
- x_ref_preprocessed=x_ref_preprocessed,
93
- update_x_ref=update_x_ref,
94
- preprocess_fn=preprocess_fn,
82
+ update_strategy=update_strategy,
95
83
  correction=correction,
96
84
  n_features=n_features,
97
85
  )