dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. dataeval/__init__.py +1 -1
  2. dataeval/data/__init__.py +19 -0
  3. dataeval/data/_embeddings.py +345 -0
  4. dataeval/{utils/data → data}/_images.py +2 -2
  5. dataeval/{utils/data → data}/_metadata.py +8 -7
  6. dataeval/{utils/data → data}/_selection.py +22 -9
  7. dataeval/{utils/data → data}/_split.py +1 -1
  8. dataeval/data/selections/__init__.py +19 -0
  9. dataeval/data/selections/_classbalance.py +37 -0
  10. dataeval/data/selections/_classfilter.py +109 -0
  11. dataeval/{utils/data → data}/selections/_indices.py +1 -1
  12. dataeval/{utils/data → data}/selections/_limit.py +1 -1
  13. dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
  14. dataeval/{utils/data → data}/selections/_reverse.py +1 -1
  15. dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
  16. dataeval/detectors/drift/__init__.py +2 -2
  17. dataeval/detectors/drift/_base.py +55 -203
  18. dataeval/detectors/drift/_cvm.py +19 -30
  19. dataeval/detectors/drift/_ks.py +18 -30
  20. dataeval/detectors/drift/_mmd.py +189 -53
  21. dataeval/detectors/drift/_uncertainty.py +52 -56
  22. dataeval/detectors/drift/updates.py +13 -12
  23. dataeval/detectors/linters/duplicates.py +6 -4
  24. dataeval/detectors/linters/outliers.py +3 -3
  25. dataeval/detectors/ood/ae.py +1 -1
  26. dataeval/metadata/_distance.py +1 -1
  27. dataeval/metadata/_ood.py +4 -4
  28. dataeval/metrics/bias/_balance.py +1 -1
  29. dataeval/metrics/bias/_diversity.py +1 -1
  30. dataeval/metrics/bias/_parity.py +1 -1
  31. dataeval/metrics/stats/_base.py +7 -7
  32. dataeval/metrics/stats/_dimensionstats.py +2 -2
  33. dataeval/metrics/stats/_hashstats.py +2 -2
  34. dataeval/metrics/stats/_imagestats.py +4 -4
  35. dataeval/metrics/stats/_labelstats.py +2 -2
  36. dataeval/metrics/stats/_pixelstats.py +2 -2
  37. dataeval/metrics/stats/_visualstats.py +2 -2
  38. dataeval/outputs/_bias.py +1 -1
  39. dataeval/typing.py +53 -19
  40. dataeval/utils/__init__.py +2 -2
  41. dataeval/utils/_array.py +18 -7
  42. dataeval/utils/data/__init__.py +5 -20
  43. dataeval/utils/data/_dataset.py +6 -4
  44. dataeval/utils/data/collate.py +2 -0
  45. dataeval/utils/datasets/__init__.py +17 -0
  46. dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
  47. dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
  48. dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
  49. dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
  50. dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
  51. dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
  52. dataeval/utils/torch/_internal.py +12 -35
  53. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
  54. dataeval-1.0.0.dist-info/RECORD +107 -0
  55. dataeval/detectors/drift/_torch.py +0 -222
  56. dataeval/utils/data/_embeddings.py +0 -186
  57. dataeval/utils/data/datasets/__init__.py +0 -17
  58. dataeval/utils/data/selections/__init__.py +0 -17
  59. dataeval/utils/data/selections/_classfilter.py +0 -59
  60. dataeval-0.84.0.dist-info/RECORD +0 -106
  61. /dataeval/{utils/data → data}/_targets.py +0 -0
  62. /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
  63. /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
  64. /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
  65. /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
  66. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
  67. {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
@@ -10,14 +10,15 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from typing import Callable, Literal
13
+ from typing import Literal
14
14
 
15
15
  import numpy as np
16
16
  from numpy.typing import NDArray
17
17
  from scipy.stats import ks_2samp
18
18
 
19
+ from dataeval.data._embeddings import Embeddings
19
20
  from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
- from dataeval.typing import ArrayLike
21
+ from dataeval.typing import Array
21
22
 
22
23
 
23
24
  class DriftKS(BaseDriftUnivariate):
@@ -31,43 +32,34 @@ class DriftKS(BaseDriftUnivariate):
31
32
 
32
33
  Parameters
33
34
  ----------
34
- x_ref : ArrayLike
35
+ data : Embeddings or Array
35
36
  Data used as reference distribution.
36
- p_val : float | None, default 0.05
37
+ p_val : float or None, default 0.05
37
38
  :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
38
39
  If the FDR correction method is used, this corresponds to the acceptable
39
40
  q-value.
40
- x_ref_preprocessed : bool, default False
41
- Whether the given reference data ``x_ref`` has been preprocessed yet.
42
- If ``True``, only the test data ``x`` will be preprocessed at prediction time.
43
- If ``False``, the reference data will also be preprocessed.
44
- update_x_ref : UpdateStrategy | None, default None
41
+ update_strategy : UpdateStrategy or None, default None
45
42
  Reference data can optionally be updated using an UpdateStrategy class. Update
46
43
  using the last n instances seen by the detector with LastSeenUpdateStrategy
47
44
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
48
- preprocess_fn : Callable | None, default None
49
- Function to preprocess the data before computing the data :term:`drift<Drift>` metrics.
50
- Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
51
- correction : "bonferroni" | "fdr", default "bonferroni"
45
+ correction : "bonferroni" or "fdr", default "bonferroni"
52
46
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
53
47
  Discovery Rate).
54
- alternative : "two-sided" | "less" | "greater", default "two-sided"
48
+ alternative : "two-sided", "less" or "greater", default "two-sided"
55
49
  Defines the alternative hypothesis. Options are 'two-sided', 'less' or
56
50
  'greater'.
57
51
  n_features : int | None, default None
58
- Number of features used in the statistical test. No need to pass it if no
59
- preprocessing takes place. In case of a preprocessing step, this can also
60
- be inferred automatically but could be more expensive to compute.
52
+ Number of features used in the univariate drift tests. If not provided, it will
53
+ be inferred from the data.
61
54
 
62
55
  Example
63
56
  -------
64
- >>> from functools import partial
65
- >>> from dataeval.detectors.drift import preprocess_drift
57
+ >>> from dataeval.data import Embeddings
66
58
 
67
- Use a preprocess function to encode images before testing for drift
59
+ Use Embeddings to encode images before testing for drift
68
60
 
69
- >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
70
- >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
61
+ >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
62
+ >>> drift = DriftKS(train_emb)
71
63
 
72
64
  Test incoming images for drift
73
65
 
@@ -77,21 +69,17 @@ class DriftKS(BaseDriftUnivariate):
77
69
 
78
70
  def __init__(
79
71
  self,
80
- x_ref: ArrayLike,
72
+ data: Embeddings | Array,
81
73
  p_val: float = 0.05,
82
- x_ref_preprocessed: bool = False,
83
- update_x_ref: UpdateStrategy | None = None,
84
- preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
74
+ update_strategy: UpdateStrategy | None = None,
85
75
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
86
76
  alternative: Literal["two-sided", "less", "greater"] = "two-sided",
87
77
  n_features: int | None = None,
88
78
  ) -> None:
89
79
  super().__init__(
90
- x_ref=x_ref,
80
+ data=data,
91
81
  p_val=p_val,
92
- x_ref_preprocessed=x_ref_preprocessed,
93
- update_x_ref=update_x_ref,
94
- preprocess_fn=preprocess_fn,
82
+ update_strategy=update_strategy,
95
83
  correction=correction,
96
84
  n_features=n_features,
97
85
  )
@@ -10,16 +10,16 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from typing import Callable
13
+ from typing import Any, Callable
14
14
 
15
15
  import torch
16
16
 
17
17
  from dataeval.config import DeviceLike, get_device
18
- from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
19
- from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
18
+ from dataeval.data._embeddings import Embeddings
19
+ from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, update_strategy
20
20
  from dataeval.outputs import DriftMMDOutput
21
21
  from dataeval.outputs._base import set_metadata
22
- from dataeval.typing import ArrayLike
22
+ from dataeval.typing import Array
23
23
 
24
24
 
25
25
  class DriftMMD(BaseDrift):
@@ -29,29 +29,20 @@ class DriftMMD(BaseDrift):
29
29
 
30
30
  Parameters
31
31
  ----------
32
- x_ref : ArrayLike
32
+ data : Embeddings or Array
33
33
  Data used as reference distribution.
34
34
  p_val : float or None, default 0.05
35
35
  :term:`P-value` used for significance of the statistical test for each feature.
36
36
  If the FDR correction method is used, this corresponds to the acceptable
37
37
  q-value.
38
- x_ref_preprocessed : bool, default False
39
- Whether the given reference data ``x_ref`` has been preprocessed yet.
40
- If ``True``, only the test data ``x`` will be preprocessed at prediction time.
41
- If ``False``, the reference data will also be preprocessed.
42
- update_x_ref : UpdateStrategy or None, default None
38
+ update_strategy : UpdateStrategy or None, default None
43
39
  Reference data can optionally be updated using an UpdateStrategy class. Update
44
40
  using the last n instances seen by the detector with LastSeenUpdateStrategy
45
41
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
46
- preprocess_fn : Callable or None, default None
47
- Function to preprocess the data before computing the data drift metrics.
48
- Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
49
- sigma : ArrayLike or None, default None
42
+ sigma : Array or None, default None
50
43
  Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
51
44
  bandwidth values as an array. The kernel evaluation is then averaged over
52
45
  those bandwidths.
53
- configure_kernel_from_x_ref : bool, default True
54
- Whether to already configure the kernel bandwidth from the reference data.
55
46
  n_permutations : int, default 100
56
47
  Number of permutations used in the permutation test.
57
48
  device : DeviceLike or None, default None
@@ -60,13 +51,12 @@ class DriftMMD(BaseDrift):
60
51
 
61
52
  Example
62
53
  -------
63
- >>> from functools import partial
64
- >>> from dataeval.detectors.drift import preprocess_drift
54
+ >>> from dataeval.data import Embeddings
65
55
 
66
- Use a preprocess function to encode images before testing for drift
56
+ Use Embeddings to encode images before testing for drift
67
57
 
68
- >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
69
- >>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
58
+ >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
59
+ >>> drift = DriftMMD(train_emb)
70
60
 
71
61
  Test incoming images for drift
72
62
 
@@ -76,21 +66,14 @@ class DriftMMD(BaseDrift):
76
66
 
77
67
  def __init__(
78
68
  self,
79
- x_ref: ArrayLike,
69
+ data: Embeddings | Array,
80
70
  p_val: float = 0.05,
81
- x_ref_preprocessed: bool = False,
82
- update_x_ref: UpdateStrategy | None = None,
83
- preprocess_fn: Callable[..., ArrayLike] | None = None,
84
- sigma: ArrayLike | None = None,
85
- configure_kernel_from_x_ref: bool = True,
71
+ update_strategy: UpdateStrategy | None = None,
72
+ sigma: Array | None = None,
86
73
  n_permutations: int = 100,
87
74
  device: DeviceLike | None = None,
88
75
  ) -> None:
89
- super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
90
-
91
- self._infer_sigma = configure_kernel_from_x_ref
92
- if configure_kernel_from_x_ref and sigma is not None:
93
- self._infer_sigma = False
76
+ super().__init__(data, p_val, update_strategy)
94
77
 
95
78
  self.n_permutations = n_permutations # nb of iterations through permutation test
96
79
 
@@ -102,23 +85,20 @@ class DriftMMD(BaseDrift):
102
85
  self._kernel = GaussianRBF(sigma_tensor).to(self.device)
103
86
 
104
87
  # compute kernel matrix for the reference data
105
- if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
106
- x = torch.as_tensor(self.x_ref, device=self.device)
107
- self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
108
- self._infer_sigma = False
88
+ if isinstance(sigma_tensor, torch.Tensor):
89
+ self._k_xx = self._kernel(self.x_ref, self.x_ref)
109
90
  else:
110
- self._k_xx, self._infer_sigma = None, True
91
+ self._k_xx = None
111
92
 
112
- def _kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
93
+ def _kernel_matrix(self, x: Array, y: Array) -> torch.Tensor:
113
94
  """Compute and return full kernel matrix between arrays x and y."""
114
- k_xy = self._kernel(x, y, self._infer_sigma)
115
- k_xx = self._k_xx if self._k_xx is not None and self.update_x_ref is None else self._kernel(x, x)
95
+ k_xy = self._kernel(x, y)
96
+ k_xx = self._k_xx if self._k_xx is not None and self.update_strategy is None else self._kernel(x, x)
116
97
  k_yy = self._kernel(y, y)
117
98
  kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
118
99
  return kernel_mat
119
100
 
120
- @preprocess_x
121
- def score(self, x: ArrayLike) -> tuple[float, float, float]:
101
+ def score(self, data: Embeddings | Array) -> tuple[float, float, float]:
122
102
  """
123
103
  Compute the :term:`p-value<P-Value>` resulting from a permutation test using the maximum mean
124
104
  discrepancy as a distance measure between the reference data and the data to
@@ -126,8 +106,8 @@ class DriftMMD(BaseDrift):
126
106
 
127
107
  Parameters
128
108
  ----------
129
- x : ArrayLike
130
- Batch of instances.
109
+ data : Embeddings or Array
110
+ Batch of instances to score.
131
111
 
132
112
  Returns
133
113
  -------
@@ -135,10 +115,9 @@ class DriftMMD(BaseDrift):
135
115
  p-value obtained from the permutation test, MMD^2 between the reference and test set,
136
116
  and MMD^2 threshold above which :term:`drift<Drift>` is flagged
137
117
  """
138
- x_ref = torch.as_tensor(self.x_ref, device=self.device)
139
- x_test = torch.as_tensor(x, device=self.device)
118
+ x_test = self._encode(data)
140
119
  n = x_test.shape[0]
141
- kernel_mat = self._kernel_matrix(x_ref, x_test)
120
+ kernel_mat = self._kernel_matrix(self.x_ref, x_test)
142
121
  kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
143
122
  mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
144
123
  mmd2_permuted = torch.tensor(
@@ -152,17 +131,16 @@ class DriftMMD(BaseDrift):
152
131
  return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())
153
132
 
154
133
  @set_metadata
155
- @preprocess_x
156
- @update_x_ref
157
- def predict(self, x: ArrayLike) -> DriftMMDOutput:
134
+ @update_strategy
135
+ def predict(self, data: Embeddings | Array) -> DriftMMDOutput:
158
136
  """
159
137
  Predict whether a batch of data has drifted from the reference data and then
160
138
  updates reference data using specified strategy.
161
139
 
162
140
  Parameters
163
141
  ----------
164
- x : ArrayLike
165
- Batch of instances.
142
+ data : Embeddings or Array
143
+ Batch of instances to predict drift on.
166
144
 
167
145
  Returns
168
146
  -------
@@ -171,8 +149,166 @@ class DriftMMD(BaseDrift):
171
149
  threshold and MMD metric.
172
150
  """
173
151
  # compute drift scores
174
- p_val, dist, distance_threshold = self.score(x)
152
+ p_val, dist, distance_threshold = self.score(data)
175
153
  drift_pred = bool(p_val < self.p_val)
176
154
 
177
155
  # populate drift dict
178
156
  return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
157
+
158
+
159
+ @torch.jit.script
160
+ def _squared_pairwise_distance(
161
+ x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
162
+ ) -> torch.Tensor: # pragma: no cover - torch.jit.script code is compiled and copied
163
+ """
164
+ PyTorch pairwise squared Euclidean distance between samples x and y.
165
+
166
+ Parameters
167
+ ----------
168
+ x : torch.Tensor
169
+ Batch of instances of shape [Nx, features].
170
+ y : torch.Tensor
171
+ Batch of instances of shape [Ny, features].
172
+ a_min : float
173
+ Lower bound to clip distance values.
174
+
175
+ Returns
176
+ -------
177
+ torch.Tensor
178
+ Pairwise squared Euclidean distance [Nx, Ny].
179
+ """
180
+ x2 = x.pow(2).sum(dim=-1, keepdim=True)
181
+ y2 = y.pow(2).sum(dim=-1, keepdim=True)
182
+ dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2)
183
+ return dist.clamp_min_(a_min)
184
+
185
+
186
+ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
187
+ """
188
+ Bandwidth estimation using the median heuristic `Gretton2012`
189
+
190
+ Parameters
191
+ ----------
192
+ x : torch.Tensor
193
+ Tensor of instances with dimension [Nx, features].
194
+ y : torch.Tensor
195
+ Tensor of instances with dimension [Ny, features].
196
+ dist : torch.Tensor
197
+ Tensor with dimensions [Nx, Ny], containing the pairwise distances
198
+ between `x` and `y`.
199
+
200
+ Returns
201
+ -------
202
+ torch.Tensor
203
+ The computed bandwidth, `sigma`.
204
+ """
205
+ n = min(x.shape[0], y.shape[0])
206
+ n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
207
+ n_median = n + (torch.prod(torch.as_tensor(dist.shape)) - n) // 2 - 1
208
+ sigma = (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
209
+ return sigma
210
+
211
+
212
+ class GaussianRBF(torch.nn.Module):
213
+ """
214
+ Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
215
+
216
+ A forward pass takes a batch of instances x [Nx, features] and
217
+ y [Ny, features] and returns the kernel matrix [Nx, Ny].
218
+
219
+ Parameters
220
+ ----------
221
+ sigma : torch.Tensor | None, default None
222
+ Bandwidth used for the kernel. Needn't be specified if being inferred or
223
+ trained. Can pass multiple values to eval kernel with and then average.
224
+ init_sigma_fn : Callable | None, default None
225
+ Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
226
+ inferred. The function's signature should take in the tensors ``x``, ``y`` and
227
+ ``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
228
+ trainable : bool, default False
229
+ Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
230
+ """
231
+
232
+ def __init__(
233
+ self,
234
+ sigma: torch.Tensor | None = None,
235
+ init_sigma_fn: Callable | None = None,
236
+ trainable: bool = False,
237
+ ) -> None:
238
+ super().__init__()
239
+ init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
240
+ self.config: dict[str, Any] = {
241
+ "sigma": sigma,
242
+ "trainable": trainable,
243
+ "init_sigma_fn": init_sigma_fn,
244
+ }
245
+ if sigma is None:
246
+ self.log_sigma: torch.nn.Parameter = torch.nn.Parameter(torch.empty(1), requires_grad=trainable)
247
+ self.init_required: bool = True
248
+ else:
249
+ sigma = sigma.reshape(-1) # [Ns,]
250
+ self.log_sigma: torch.nn.Parameter = torch.nn.Parameter(sigma.log(), requires_grad=trainable)
251
+ self.init_required: bool = False
252
+ self.init_sigma_fn = init_sigma_fn
253
+ self.trainable = trainable
254
+
255
+ @property
256
+ def sigma(self) -> torch.Tensor:
257
+ return self.log_sigma.exp()
258
+
259
+ def forward(
260
+ self,
261
+ x: Array,
262
+ y: Array,
263
+ infer_sigma: bool = False,
264
+ ) -> torch.Tensor:
265
+ x, y = torch.as_tensor(x), torch.as_tensor(y)
266
+ dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1)) # [Nx, Ny]
267
+
268
+ if infer_sigma or self.init_required:
269
+ if self.trainable and infer_sigma:
270
+ raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
271
+ sigma = self.init_sigma_fn(x, y, dist)
272
+ with torch.no_grad():
273
+ self.log_sigma.copy_(sigma.log().clone())
274
+ self.init_required: bool = False
275
+
276
+ gamma = 1.0 / (2.0 * self.sigma**2) # [Ns,]
277
+ # TODO: do matrix multiplication after all?
278
+ kernel_mat = torch.exp(-torch.cat([(g * dist)[None, :, :] for g in gamma], dim=0)) # [Ns, Nx, Ny]
279
+ return kernel_mat.mean(dim=0) # [Nx, Ny]
280
+
281
+
282
+ def mmd2_from_kernel_matrix(
283
+ kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
284
+ ) -> torch.Tensor:
285
+ """
286
+ Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the
287
+ full kernel matrix between the samples.
288
+
289
+ Parameters
290
+ ----------
291
+ kernel_mat : torch.Tensor
292
+ Kernel matrix between samples x and y.
293
+ m : int
294
+ Number of instances in y.
295
+ permute : bool, default False
296
+ Whether to permute the row indices. Used for permutation tests.
297
+ zero_diag : bool, default True
298
+ Whether to zero out the diagonal of the kernel matrix.
299
+
300
+ Returns
301
+ -------
302
+ torch.Tensor
303
+ MMD^2 between the samples from the kernel matrix.
304
+ """
305
+ n = kernel_mat.shape[0] - m
306
+ if zero_diag:
307
+ kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())
308
+ if permute:
309
+ idx = torch.randperm(kernel_mat.shape[0])
310
+ kernel_mat = kernel_mat[idx][:, idx]
311
+ k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
312
+ c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
313
+ mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
314
+ return mmd2
@@ -10,33 +10,32 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from functools import partial
14
- from typing import Callable, Literal
13
+ from typing import Literal, Sequence, cast
15
14
 
16
15
  import numpy as np
17
- from numpy.typing import NDArray
16
+ import torch
18
17
  from scipy.special import softmax
19
18
  from scipy.stats import entropy
20
19
 
21
- from dataeval.config import get_device
22
- from dataeval.detectors.drift._base import UpdateStrategy
20
+ from dataeval.config import DeviceLike, get_device
21
+ from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy
23
22
  from dataeval.detectors.drift._ks import DriftKS
24
- from dataeval.detectors.drift._torch import preprocess_drift
25
23
  from dataeval.outputs import DriftOutput
26
- from dataeval.typing import ArrayLike
24
+ from dataeval.typing import Array, Transform
25
+ from dataeval.utils._array import as_numpy
26
+ from dataeval.utils.torch._internal import predict_batch
27
27
 
28
28
 
29
29
  def classifier_uncertainty(
30
- x: NDArray[np.float64],
31
- model_fn: Callable,
30
+ preds: Array,
32
31
  preds_type: Literal["probs", "logits"] = "probs",
33
- ) -> NDArray[np.float64]:
32
+ ) -> torch.Tensor:
34
33
  """
35
34
  Evaluate model_fn on x and transform predictions to prediction uncertainties.
36
35
 
37
36
  Parameters
38
37
  ----------
39
- x : np.ndarray
38
+ x : Array
40
39
  Batch of instances.
41
40
  model_fn : Callable
42
41
  Function that evaluates a :term:`classification<Classification>` model on x in a single call (contains
@@ -50,23 +49,21 @@ def classifier_uncertainty(
50
49
  NDArray
51
50
  A scalar indication of uncertainty of the model on each instance in x.
52
51
  """
53
-
54
- preds = model_fn(x)
55
-
52
+ preds_np = as_numpy(preds)
56
53
  if preds_type == "probs":
57
- if np.abs(1 - np.sum(preds, axis=-1)).mean() > 1e-6:
54
+ if np.abs(1 - np.sum(preds_np, axis=-1)).mean() > 1e-6:
58
55
  raise ValueError("Probabilities across labels should sum to 1")
59
- probs = preds
56
+ probs = preds_np
60
57
  elif preds_type == "logits":
61
- probs = softmax(preds, axis=-1)
58
+ probs = softmax(preds_np, axis=-1)
62
59
  else:
63
60
  raise NotImplementedError("Only prediction types 'probs' and 'logits' supported.")
64
61
 
65
- uncertainties = entropy(probs, axis=-1)
66
- return uncertainties[:, None] # Detectors expect N x d # type: ignore
62
+ uncertainties = cast(np.ndarray, entropy(probs, axis=-1))
63
+ return torch.as_tensor(uncertainties[:, None])
67
64
 
68
65
 
69
- class DriftUncertainty:
66
+ class DriftUncertainty(BaseDrift):
70
67
  """
71
68
  Test for a change in the number of instances falling into regions on which \
72
69
  the model is uncertain.
@@ -75,29 +72,27 @@ class DriftUncertainty:
75
72
 
76
73
  Parameters
77
74
  ----------
78
- x_ref : ArrayLike
75
+ data : Array
79
76
  Data used as reference distribution.
80
77
  model : Callable
81
78
  :term:`Classification` model outputting class probabilities (or logits)
82
79
  p_val : float, default 0.05
83
80
  :term:`P-Value` used for the significance of the test.
84
- x_ref_preprocessed : bool, default False
85
- Whether the given reference data ``x_ref`` has been preprocessed yet.
86
- If ``True``, only the test data ``x`` will be preprocessed at prediction time.
87
- If ``False``, the reference data will also be preprocessed.
88
- update_x_ref : UpdateStrategy or None, default None
81
+ update_strategy : UpdateStrategy or None, default None
89
82
  Reference data can optionally be updated using an UpdateStrategy class. Update
90
83
  using the last n instances seen by the detector with LastSeenUpdateStrategy
91
84
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
85
+ correction : "bonferroni" or "fdr", default "bonferroni"
86
+ Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
87
+ Discovery Rate).
92
88
  preds_type : "probs" or "logits", default "probs"
93
89
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
94
90
  'logits' (in [-inf,inf]).
95
91
  batch_size : int, default 32
96
92
  Batch size used to evaluate model. Only relevant when backend has been
97
93
  specified for batch prediction.
98
- preprocess_batch_fn : Callable or None, default None
99
- Optional batch preprocessing function. For example to convert a list of
100
- objects to a batch which can be processed by the model.
94
+ transforms : Transform, Sequence[Transform] or None, default None
95
+ Transform(s) to apply to the data.
101
96
  device : DeviceLike or None, default None
102
97
  Device type used. The default None tries to use the GPU and falls back on
103
98
  CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
@@ -120,46 +115,47 @@ class DriftUncertainty:
120
115
 
121
116
  def __init__(
122
117
  self,
123
- x_ref: ArrayLike,
124
- model: Callable,
118
+ data: Array,
119
+ model: torch.nn.Module,
125
120
  p_val: float = 0.05,
126
- x_ref_preprocessed: bool = False,
127
- update_x_ref: UpdateStrategy | None = None,
121
+ update_strategy: UpdateStrategy | None = None,
122
+ correction: Literal["bonferroni", "fdr"] = "bonferroni",
128
123
  preds_type: Literal["probs", "logits"] = "probs",
129
124
  batch_size: int = 32,
130
- preprocess_batch_fn: Callable | None = None,
131
- device: str | None = None,
125
+ transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
126
+ device: DeviceLike | None = None,
132
127
  ) -> None:
133
- def model_fn(x: NDArray) -> NDArray:
134
- return preprocess_drift(
135
- x,
136
- model, # type: ignore
137
- batch_size=batch_size,
138
- preprocess_batch_fn=preprocess_batch_fn,
139
- device=get_device(device),
140
- )
141
-
142
- preprocess_fn = partial(
143
- classifier_uncertainty,
144
- model_fn=model_fn,
145
- preds_type=preds_type,
146
- )
128
+ self.model: torch.nn.Module = model
129
+ self.device: torch.device = get_device(device)
130
+ self.batch_size: int = batch_size
131
+ self.preds_type: Literal["probs", "logits"] = preds_type
147
132
 
133
+ self._transforms = (
134
+ [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
135
+ )
148
136
  self._detector = DriftKS(
149
- x_ref=x_ref,
137
+ data=self._preprocess(data).cpu().numpy(),
150
138
  p_val=p_val,
151
- x_ref_preprocessed=x_ref_preprocessed,
152
- update_x_ref=update_x_ref,
153
- preprocess_fn=preprocess_fn, # type: ignore
139
+ update_strategy=update_strategy,
140
+ correction=correction,
154
141
  )
155
142
 
156
- def predict(self, x: ArrayLike) -> DriftOutput:
143
+ def _transform(self, x: torch.Tensor) -> torch.Tensor:
144
+ for transform in self._transforms:
145
+ x = transform(x)
146
+ return x
147
+
148
+ def _preprocess(self, x: Array) -> torch.Tensor:
149
+ preds = predict_batch(x, self.model, self.device, self.batch_size, self._transform)
150
+ return classifier_uncertainty(preds, self.preds_type)
151
+
152
+ def predict(self, x: Array) -> DriftOutput:
157
153
  """
158
154
  Predict whether a batch of data has drifted from the reference data.
159
155
 
160
156
  Parameters
161
157
  ----------
162
- x : ArrayLike
158
+ x : Array
163
159
  Batch of instances.
164
160
 
165
161
  Returns
@@ -168,4 +164,4 @@ class DriftUncertainty:
168
164
  Dictionary containing the drift prediction, :term:`p-value<P-Value>`, and threshold
169
165
  statistics.
170
166
  """
171
- return self._detector.predict(x)
167
+ return self._detector.predict(self._preprocess(x).cpu().numpy())