dataeval 0.64.0__py3-none-any.whl → 0.65.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. dataeval/__init__.py +2 -2
  2. dataeval/_internal/detectors/clusterer.py +46 -34
  3. dataeval/_internal/detectors/drift/base.py +52 -35
  4. dataeval/_internal/detectors/drift/cvm.py +4 -4
  5. dataeval/_internal/detectors/drift/ks.py +6 -6
  6. dataeval/_internal/detectors/drift/mmd.py +35 -16
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -7
  9. dataeval/_internal/detectors/duplicates.py +55 -29
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/base.py +36 -15
  12. dataeval/_internal/detectors/ood/llr.py +7 -7
  13. dataeval/_internal/flags.py +42 -21
  14. dataeval/_internal/interop.py +2 -2
  15. dataeval/_internal/metrics/balance.py +10 -2
  16. dataeval/_internal/metrics/ber.py +6 -5
  17. dataeval/_internal/metrics/coverage.py +15 -8
  18. dataeval/_internal/metrics/divergence.py +41 -7
  19. dataeval/_internal/metrics/diversity.py +17 -12
  20. dataeval/_internal/metrics/parity.py +30 -43
  21. dataeval/_internal/metrics/stats.py +196 -317
  22. dataeval/_internal/metrics/uap.py +5 -2
  23. dataeval/_internal/metrics/utils.py +70 -33
  24. dataeval/_internal/models/tensorflow/losses.py +3 -3
  25. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  26. dataeval/_internal/models/tensorflow/utils.py +4 -3
  27. dataeval/_internal/output.py +82 -0
  28. dataeval/_internal/workflows/sufficiency.py +96 -107
  29. dataeval/flags/__init__.py +2 -2
  30. dataeval/metrics/__init__.py +3 -3
  31. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  32. dataeval-0.65.0.dist-info/RECORD +60 -0
  33. dataeval/_internal/metrics/base.py +0 -10
  34. dataeval-0.64.0.dist-info/RECORD +0 -60
  35. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  36. {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,9 +1,9 @@
+__version__ = "0.65.0"
+
 from importlib.util import find_spec
 
 from . import detectors, flags, metrics
 
-__version__ = "0.64.0"
-
 __all__ = ["detectors", "flags", "metrics"]
 
 if find_spec("torch") is not None:  # pragma: no cover

dataeval/_internal/detectors/clusterer.py CHANGED
@@ -1,26 +1,50 @@
+from dataclasses import dataclass
 from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import flatten
+from dataeval._internal.output import OutputMetadata, set_metadata
 
 
-def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
+@dataclass(frozen=True)
+class ClustererOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    outliers : List[int]
+        Indices that do not fall within a cluster
+    potential_outliers : List[int]
+        Indices which are near the border between belonging in the cluster and being an outlier
+    duplicates : List[List[int]]
+        Groups of indices that are exact duplicates
+    potential_duplicates : List[List[int]]
+        Groups of indices which are not exact but closely related data points
+    """
+
+    outliers: List[int]
+    potential_outliers: List[int]
+    duplicates: List[List[int]]
+    potential_duplicates: List[List[int]]
+
+
+def extend_linkage(link_arr: NDArray) -> NDArray:
     """
     Adds a column to the linkage matrix link_arr that tracks the new id assigned
     to each row
 
     Parameters
     ----------
-    link_arr : np.ndarray
+    link_arr : NDArray
         linkage matrix
 
     Returns
     -------
-    np.ndarray
+    NDArray
         linkage matrix with adjusted shape, new shape (link_arr.shape[0], link_arr.shape[1]+1)
     """
     # Adjusting linkage matrix to accommodate renumbering
@@ -35,7 +59,7 @@ def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
 class Cluster:
     __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"
 
-    def __init__(self, merged: int, samples: np.ndarray, sample_dist: Union[float, np.ndarray], is_copy: bool = False):
+    def __init__(self, merged: int, samples: NDArray, sample_dist: Union[float, NDArray], is_copy: bool = False):
         self.merged = merged
         self.samples = np.array(samples, dtype=np.int32)
         self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -130,13 +154,13 @@ class Clusterer:
         self._on_init(dataset)
 
     def _on_init(self, dataset: ArrayLike):
-        self._data: np.ndarray = to_numpy(dataset)
+        self._data: NDArray = flatten(to_numpy(dataset))
         self._validate_data(self._data)
         self._num_samples = len(self._data)
 
-        self._darr: np.ndarray = pdist(self._data, metric="euclidean")
-        self._sqdmat: np.ndarray = squareform(self._darr)
-        self._larr: np.ndarray = extend_linkage(linkage(self._darr))
+        self._darr: NDArray = pdist(self._data, metric="euclidean")
+        self._sqdmat: NDArray = squareform(self._darr)
+        self._larr: NDArray = extend_linkage(linkage(self._darr))
         self._max_clusters: int = np.count_nonzero(self._larr[:, 3] == 2)
 
         min_num = int(self._num_samples * 0.05)
@@ -146,7 +170,7 @@ class Clusterer:
         self._last_good_merge_levels = None
 
     @property
-    def data(self) -> np.ndarray:
+    def data(self) -> NDArray:
         return self._data
 
     @data.setter
@@ -166,10 +190,10 @@ class Clusterer:
         return self._last_good_merge_levels
 
     @classmethod
-    def _validate_data(cls, x: np.ndarray):
+    def _validate_data(cls, x: NDArray):
         """Checks that the data has the correct size, shape, and format"""
         if not isinstance(x, np.ndarray):
-            raise TypeError(f"Data should be of type np.ndarray; got {type(x)}")
+            raise TypeError(f"Data should be of type NDArray; got {type(x)}")
 
         if x.ndim != 2:
             raise ValueError(
@@ -240,7 +264,7 @@ class Clusterer:
             clusters[level_id].setdefault(cid, cluster)
         return clusters
 
-    def _get_cluster_distances(self) -> np.ndarray:
+    def _get_cluster_distances(self) -> NDArray:
         """Calculates the minimum distances between clusters are each level"""
         # Cluster distance matrix
         max_level = self.clusters.max_level
@@ -261,7 +285,7 @@ class Clusterer:
 
         return cluster_matrix
 
-    def _calc_merge_indices(self, merge_mean: List[np.ndarray], intra_max: List[float]) -> np.ndarray:
+    def _calc_merge_indices(self, merge_mean: List[NDArray], intra_max: List[float]) -> NDArray:
         """
         Determine what clusters should be merged and return their indices
         """
@@ -284,7 +308,7 @@ class Clusterer:
         mask2 = mask2_vals < one_std_check
         return np.logical_or(desired_merge, mask2)
 
-    def _generate_merge_list(self, cluster_matrix: np.ndarray) -> List[ClusterMergeEntry]:
+    def _generate_merge_list(self, cluster_matrix: NDArray) -> List[ClusterMergeEntry]:
         """
         Runs through the clusters dictionary determining when clusters merge,
         and how close are those clusters when they merge.
@@ -464,35 +488,23 @@ class Clusterer:
 
         return exact_dupes, near_dupes
 
-    def evaluate(self):
+    # TODO: Move data input to evaluate from class
+    @set_metadata("dataeval.detectors", ["data"])
+    def evaluate(self) -> ClustererOutput:
         """Finds and flags indices of the data for outliers and duplicates
 
         Returns
         -------
-        Dict[str, List[int]]
-            outliers :
-                List of indices that do not fall within a cluster
-            potential_outliers :
-                List of indices which are near the border between belonging in the cluster and being an outlier
-            duplicates :
-                List of groups of indices that are exact duplicates
-            potential_duplicates :
-                List of groups of indices which are not exact but closely related data points
+        ClustererOutput
+            The outliers and duplicate indices found in the data
 
         Example
        -------
         >>> cluster.evaluate()
-        {'outliers': [18, 21, 34, 35, 45], 'potential_outliers': [13, 15, 42], 'duplicates': [[9, 24], [23, 48]], 'potential_duplicates': [[1, 11]]}
+        ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
         """  # noqa: E501
 
         outliers, potential_outliers = self.find_outliers(self.last_good_merge_levels)
         duplicates, potential_duplicates = self.find_duplicates(self.last_good_merge_levels)
 
-        ret = {
-            "outliers": outliers,
-            "potential_outliers": potential_outliers,
-            "duplicates": duplicates,
-            "potential_duplicates": potential_duplicates,
-        }
-
-        return ret
+        return ClustererOutput(outliers, potential_outliers, duplicates, potential_duplicates)
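
Note: Clusterer.evaluate() now returns a frozen ClustererOutput dataclass instead of a dict. A minimal consumer-side sketch follows; it assumes Clusterer is re-exported from dataeval.detectors, and `images` is a placeholder for any ArrayLike dataset.

# sketch only: import path and the `images` dataset are assumptions, not verified against 0.65.0
from dataeval.detectors import Clusterer

clusterer = Clusterer(images)
results = clusterer.evaluate()

# 0.64.0 returned a dict (results["outliers"]); 0.65.0 uses attribute access
print(results.outliers)              # e.g. [18, 21, 34, 35, 45]
print(results.potential_duplicates)  # e.g. [[1, 11]]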

dataeval/_internal/detectors/drift/base.py CHANGED
@@ -7,13 +7,48 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, Dict, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class DriftOutput(OutputMetadata):
+    is_drift: bool
+    threshold: float
+
+
+@dataclass(frozen=True)
+class DriftUnivariateOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    feature_drift : NDArray[np.bool_]
+        Feature-level array of images detected to have drifted
+    feature_threshold : float
+        Feature-level threshold to determine drift
+    p_vals : NDArray[np.float32]
+        Feature-level p-values
+    distances : NDArray[np.float32]
+        Feature-level distances
+    """
+
+    # is_drift: bool
+    # threshold: float
+    feature_drift: NDArray[np.bool_]
+    feature_threshold: float
+    p_vals: NDArray[np.float32]
+    distances: NDArray[np.float32]
 
 
 def update_x_ref(fn):
@@ -52,7 +87,7 @@ class UpdateStrategy(ABC):
         self.n = n
 
     @abstractmethod
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         """Abstract implementation of update strategy"""
 
 
@@ -66,7 +101,7 @@ class LastSeenUpdate(UpdateStrategy):
     Update with last n instances seen by the detector.
     """
 
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         x_updated = np.concatenate([x_ref, x], axis=0)
         return x_updated[-self.n :]
 
@@ -81,7 +116,7 @@ class ReservoirSamplingUpdate(UpdateStrategy):
     Update with reservoir sampling of size n.
     """
 
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         if x.shape[0] + count <= self.n:
             return np.concatenate([x_ref, x], axis=0)
 
@@ -136,7 +171,7 @@ class BaseDrift:
         self._x_refcount = 0
 
     @property
-    def x_ref(self) -> np.ndarray:
+    def x_ref(self) -> NDArray:
         if not self.x_ref_preprocessed:
             self.x_ref_preprocessed = True
             if self.preprocess_fn is not None:
@@ -152,7 +187,7 @@ class BaseDrift:
         return x
 
 
-class BaseUnivariateDrift(BaseDrift):
+class BaseDriftUnivariate(BaseDrift):
     """
     Generic drift detector component which serves as a base class for methods using
     univariate tests. If n_features > 1, a multivariate correction is applied such
@@ -198,13 +233,13 @@ class BaseUnivariateDrift(BaseDrift):
 
     @preprocess_x
     @abstractmethod
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """Abstract method to calculate feature score after preprocessing"""
 
-    def _apply_correction(self, p_vals: np.ndarray) -> Tuple[int, float]:
+    def _apply_correction(self, p_vals: NDArray) -> Tuple[bool, float]:
         if self.correction == "bonferroni":
             threshold = self.p_val / self.n_features
-            drift_pred = int((p_vals < threshold).any())
+            drift_pred = bool((p_vals < threshold).any())
             return drift_pred, threshold
         elif self.correction == "fdr":
             n = p_vals.shape[0]
@@ -215,18 +250,18 @@ class BaseUnivariateDrift(BaseDrift):
             try:
                 idx_threshold = int(np.where(below_threshold)[0].max())
             except ValueError:  # sorted p-values not below thresholds
-                return int(below_threshold.any()), q_threshold.min()
-            return int(below_threshold.any()), q_threshold[idx_threshold]
+                return bool(below_threshold.any()), q_threshold.min()
+            return bool(below_threshold.any()), q_threshold[idx_threshold]
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
 
+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
     def predict(
         self,
         x: ArrayLike,
-        drift_type: Literal["batch", "feature"] = "batch",
-    ) -> Dict[str, Union[int, float, np.ndarray]]:
+    ) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data and update
         reference data using specified update strategy.
@@ -235,10 +270,6 @@ class BaseUnivariateDrift(BaseDrift):
         ----------
         x : ArrayLike
             Batch of instances.
-        drift_type : Literal["batch", "feature"], default "batch"
-            Predict drift at the 'feature' or 'batch' level. For 'batch', the test
-            statistics for each feature are aggregated using the Bonferroni or False
-            Discovery Rate correction (if n_features>1).
 
         Returns
         -------
@@ -249,20 +280,6 @@ class BaseUnivariateDrift(BaseDrift):
         # compute drift scores
         p_vals, dist = self.score(x)
 
-        # TODO: return both feature-level and batch-level drift predictions by default
-        # values below p-value threshold are drift
-        if drift_type == "feature":
-            drift_pred = (p_vals < self.p_val).astype(int)
-            threshold = self.p_val
-        elif drift_type == "batch":
-            drift_pred, threshold = self._apply_correction(p_vals)
-        else:
-            raise ValueError("`drift_type` needs to be either `feature` or `batch`.")
-
-        # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_vals,
-            "threshold": threshold,
-            "distance": dist,
-        }
+        feature_drift = (p_vals < self.p_val).astype(np.bool_)
+        drift_pred, threshold = self._apply_correction(p_vals)
+        return DriftUnivariateOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
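
Note: predict() on univariate detectors no longer takes a drift_type argument; the batch-level prediction (after Bonferroni or FDR correction) and the per-feature predictions are now both carried on the DriftUnivariateOutput dataclass. A rough consumer-side sketch, assuming DriftKS is re-exported from dataeval.detectors and x_ref / x_test are placeholder arrays:

# sketch only: names and import path are assumptions, not verified against 0.65.0
from dataeval.detectors import DriftKS

detector = DriftKS(x_ref)
out = detector.predict(x_test)      # the drift_type="batch"/"feature" argument is gone

print(out.is_drift)                 # batch-level prediction after multivariate correction
print(out.feature_drift)            # per-feature booleans, replacing drift_type="feature"
print(out.p_vals, out.distances)    # formerly out["p_val"], out["distance"]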

dataeval/_internal/detectors/drift/cvm.py CHANGED
@@ -9,15 +9,15 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import cramervonmises_2samp
 
 from dataeval._internal.interop import to_numpy
 
-from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 
-class DriftCVM(BaseUnivariateDrift):
+class DriftCVM(BaseDriftUnivariate):
     """
     Cramér-von Mises (CVM) data drift detector, which tests for any change in the
     distribution of continuous univariate data. For multivariate data, a separate
@@ -76,7 +76,7 @@ class DriftCVM(BaseUnivariateDrift):
         )
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Performs the two-sample Cramér-von Mises test(s), computing the p-value and
         test statistic per feature.
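
Note: the base-class rename (BaseUnivariateDrift to BaseDriftUnivariate) and the tightened score() annotation affect any custom detector built on the internal base. A rough sketch of what such a subclass might look like under 0.65.0; the class name MyDrift, the Mann-Whitney test choice, and the reliance on self.x_ref / self.n_features are illustrative assumptions only.

# sketch only: assumes the internal base API (x_ref property, n_features, preprocess_x)
from typing import Tuple

import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy.stats import mannwhitneyu

from dataeval._internal.detectors.drift.base import BaseDriftUnivariate, preprocess_x


class MyDrift(BaseDriftUnivariate):  # was BaseUnivariateDrift in 0.64.0
    @preprocess_x
    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
        # per-feature p-values and test statistics, now annotated as float32 arrays
        x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
        x_np = np.asarray(x).reshape(len(x), -1)
        p_vals, stats = [], []
        for f in range(self.n_features):
            result = mannwhitneyu(x_ref[:, f], x_np[:, f])
            p_vals.append(result.pvalue)
            stats.append(result.statistic)
        return np.array(p_vals, dtype=np.float32), np.array(stats, dtype=np.float32)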

dataeval/_internal/detectors/drift/ks.py CHANGED
@@ -9,22 +9,22 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import ks_2samp
 
 from dataeval._internal.interop import to_numpy
 
-from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 
-class DriftKS(BaseUnivariateDrift):
+class DriftKS(BaseDriftUnivariate):
     """
     Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
     Rate (FDR) correction for multivariate data.
 
     Parameters
     ----------
-    x_ref : np.ndarray
+    x_ref : NDArray
         Data used as reference distribution.
     p_val : float, default 0.05
         p-value used for significance of the statistical test for each feature.
@@ -41,7 +41,7 @@ class DriftKS(BaseUnivariateDrift):
         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
         or via reservoir sampling with
         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
-    preprocess_fn : Optional[Callable[[np.ndarray], np.ndarray]], default None
+    preprocess_fn : Optional[Callable[[NDArray], NDArray]], default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a dimensionality reduction technique.
     correction : Literal["bonferroni", "fdr"], default "bonferroni"
@@ -81,7 +81,7 @@ class DriftKS(BaseUnivariateDrift):
         self.alternative = alternative
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Compute K-S scores and statistics per feature.
 

dataeval/_internal/detectors/drift/mmd.py CHANGED
@@ -6,17 +6,43 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
-from typing import Callable, Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple
 
 import torch
 from numpy.typing import ArrayLike
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import set_metadata
 
-from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+from .base import BaseDrift, DriftOutput, UpdateStrategy, preprocess_x, update_x_ref
 from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
 
 
+@dataclass(frozen=True)
+class DriftMMDOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        P-value used for significance of the permutation test
+    p_val : float
+        P-value obtained from the permutation test
+    distance : float
+        MMD^2 between the reference and test set
+    distance_threshold : float
+        MMD^2 threshold above which drift is flagged
+    """
+
+    # is_drift: bool
+    # threshold: float
+    p_val: float
+    distance: float
+    distance_threshold: float
+
+
 class DriftMMD(BaseDrift):
     """
     Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
@@ -129,19 +155,17 @@ class DriftMMD(BaseDrift):
         mmd2_permuted = torch.Tensor(
             [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
         )
-        mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
+        mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
         # compute distance threshold
         idx_threshold = int(self.p_val * len(mmd2_permuted))
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
 
+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
-    def predict(
-        self,
-        x: ArrayLike,
-    ) -> Dict[str, Union[int, float]]:
+    def predict(self, x: ArrayLike) -> DriftMMDOutput:
         """
         Predict whether a batch of data has drifted from the reference data and then
         updates reference data using specified strategy.
@@ -153,17 +177,12 @@ class DriftMMD(BaseDrift):
 
         Returns
         -------
-        Dictionary containing the drift prediction, p-value, threshold and MMD metric.
+        DriftMMDOutput
+            Output class containing the drift prediction, p-value, threshold and MMD metric.
         """
         # compute drift scores
         p_val, dist, distance_threshold = self.score(x)
-        drift_pred = int(p_val < self.p_val)
+        drift_pred = bool(p_val < self.p_val)
 
         # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_val,
-            "threshold": self.p_val,
-            "distance": dist,
-            "distance_threshold": distance_threshold,
-        }
+        return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
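
Note: DriftMMD.predict() now returns a DriftMMDOutput dataclass with a boolean is_drift rather than a dict with a 0/1 flag. A minimal sketch, assuming DriftMMD is re-exported from dataeval.detectors and x_ref / x_test are placeholder arrays:

# sketch only: import path and constructor arguments are assumptions
from dataeval.detectors import DriftMMD

detector = DriftMMD(x_ref)
out = detector.predict(x_test)

if out.is_drift:                      # bool now, previously out["is_drift"] as int
    print(out.p_val, out.distance, out.distance_threshold)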

dataeval/_internal/detectors/drift/torch.py CHANGED
@@ -12,6 +12,7 @@ from typing import Callable, Optional, Type, Union
 import numpy as np
 import torch
 import torch.nn as nn
+from numpy.typing import NDArray
 
 
 def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
@@ -74,13 +75,13 @@ def mmd2_from_kernel_matrix(
 
 
 def predict_batch(
-    x: Union[np.ndarray, torch.Tensor],
+    x: Union[NDArray, torch.Tensor],
     model: Union[Callable, nn.Module, nn.Sequential],
     device: Optional[torch.device] = None,
    batch_size: int = int(1e10),
     preprocess_fn: Optional[Callable] = None,
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[np.ndarray, torch.Tensor, tuple]:
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Make batch predictions on a model.
 
@@ -138,7 +139,7 @@ def predict_batch(
         else:
             raise TypeError(
                 f"Model output type {type(preds_tmp)} not supported. The model \
-                output type needs to be one of list, tuple, np.ndarray or \
+                output type needs to be one of list, tuple, NDArray or \
                 torch.Tensor."
             )
     concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
@@ -149,13 +150,13 @@ def predict_batch(
 
 
 def preprocess_drift(
-    x: np.ndarray,
+    x: NDArray,
     model: nn.Module,
     device: Optional[torch.device] = None,
     preprocess_batch_fn: Optional[Callable] = None,
     batch_size: int = int(1e10),
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[np.ndarray, torch.Tensor, tuple]:
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Prediction function used for preprocessing step of drift detector.
 

dataeval/_internal/detectors/drift/uncertainty.py CHANGED
@@ -7,23 +7,23 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from functools import partial
-from typing import Callable, Dict, Literal, Optional, Union
+from typing import Callable, Literal, Optional
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.special import softmax
 from scipy.stats import entropy
 
-from .base import UpdateStrategy
+from .base import DriftUnivariateOutput, UpdateStrategy
 from .ks import DriftKS
 from .torch import get_device, preprocess_drift
 
 
 def classifier_uncertainty(
-    x: np.ndarray,
+    x: NDArray,
     model_fn: Callable,
     preds_type: Literal["probs", "logits"] = "probs",
-) -> np.ndarray:
+) -> NDArray:
     """
     Evaluate model_fn on x and transform predictions to prediction uncertainties.
 
@@ -111,7 +111,7 @@ class DriftUncertainty:
         preprocess_batch_fn: Optional[Callable] = None,
         device: Optional[str] = None,
     ) -> None:
-        def model_fn(x: np.ndarray) -> np.ndarray:
+        def model_fn(x: NDArray) -> NDArray:
             return preprocess_drift(
                 x,
                 model,  # type: ignore
@@ -134,7 +134,7 @@ class DriftUncertainty:
             preprocess_fn=preprocess_fn,  # type: ignore
         )
 
-    def predict(self, x: ArrayLike) -> Dict[str, Union[int, float, np.ndarray]]:
+    def predict(self, x: ArrayLike) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data.
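
Note: DriftUncertainty.predict() is now annotated to return the same DriftUnivariateOutput dataclass as the DriftKS detector it wraps internally. A minimal sketch; `detector` stands for an already constructed DriftUncertainty instance and x_test is a placeholder batch, neither verified against 0.65.0.

# sketch only: construction of the DriftUncertainty instance is omitted
out = detector.predict(x_test)        # DriftUnivariateOutput, not a dict
print(out.is_drift, out.threshold)
print(out.p_vals)                     # formerly out["p_val"]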