dataeval 0.63.0__py3-none-any.whl → 0.65.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. dataeval/__init__.py +4 -4
  2. dataeval/_internal/detectors/clusterer.py +47 -34
  3. dataeval/_internal/detectors/drift/base.py +53 -35
  4. dataeval/_internal/detectors/drift/cvm.py +5 -4
  5. dataeval/_internal/detectors/drift/ks.py +7 -6
  6. dataeval/_internal/detectors/drift/mmd.py +39 -19
  7. dataeval/_internal/detectors/drift/torch.py +6 -5
  8. dataeval/_internal/detectors/drift/uncertainty.py +7 -8
  9. dataeval/_internal/detectors/duplicates.py +57 -30
  10. dataeval/_internal/detectors/linter.py +40 -24
  11. dataeval/_internal/detectors/ood/ae.py +2 -1
  12. dataeval/_internal/detectors/ood/aegmm.py +2 -1
  13. dataeval/_internal/detectors/ood/base.py +37 -15
  14. dataeval/_internal/detectors/ood/llr.py +9 -8
  15. dataeval/_internal/detectors/ood/vae.py +2 -1
  16. dataeval/_internal/detectors/ood/vaegmm.py +2 -1
  17. dataeval/_internal/flags.py +42 -21
  18. dataeval/_internal/interop.py +3 -12
  19. dataeval/_internal/metrics/balance.py +188 -0
  20. dataeval/_internal/metrics/ber.py +123 -48
  21. dataeval/_internal/metrics/coverage.py +90 -74
  22. dataeval/_internal/metrics/divergence.py +101 -67
  23. dataeval/_internal/metrics/diversity.py +211 -0
  24. dataeval/_internal/metrics/parity.py +287 -155
  25. dataeval/_internal/metrics/stats.py +198 -317
  26. dataeval/_internal/metrics/uap.py +40 -29
  27. dataeval/_internal/metrics/utils.py +430 -0
  28. dataeval/_internal/models/tensorflow/losses.py +3 -3
  29. dataeval/_internal/models/tensorflow/trainer.py +3 -2
  30. dataeval/_internal/models/tensorflow/utils.py +4 -3
  31. dataeval/_internal/output.py +82 -0
  32. dataeval/_internal/utils.py +64 -0
  33. dataeval/_internal/workflows/sufficiency.py +96 -107
  34. dataeval/flags/__init__.py +2 -2
  35. dataeval/metrics/__init__.py +26 -7
  36. dataeval/utils/__init__.py +9 -0
  37. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
  38. dataeval-0.65.0.dist-info/RECORD +60 -0
  39. dataeval/_internal/functional/__init__.py +0 -0
  40. dataeval/_internal/functional/ber.py +0 -63
  41. dataeval/_internal/functional/coverage.py +0 -75
  42. dataeval/_internal/functional/divergence.py +0 -16
  43. dataeval/_internal/functional/hash.py +0 -79
  44. dataeval/_internal/functional/metadata.py +0 -136
  45. dataeval/_internal/functional/metadataparity.py +0 -190
  46. dataeval/_internal/functional/uap.py +0 -6
  47. dataeval/_internal/functional/utils.py +0 -158
  48. dataeval/_internal/maite/__init__.py +0 -0
  49. dataeval/_internal/maite/utils.py +0 -30
  50. dataeval/_internal/metrics/base.py +0 -92
  51. dataeval/_internal/metrics/metadata.py +0 -610
  52. dataeval/_internal/metrics/metadataparity.py +0 -67
  53. dataeval-0.63.0.dist-info/RECORD +0 -68
  54. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
  55. {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -1,15 +1,15 @@
+__version__ = "0.65.0"
+
 from importlib.util import find_spec
 
 from . import detectors, flags, metrics
 
-__version__ = "0.63.0"
-
 __all__ = ["detectors", "flags", "metrics"]
 
 if find_spec("torch") is not None: # pragma: no cover
-    from . import models, workflows
+    from . import models, utils, workflows
 
-    __all__ += ["models", "workflows"]
+    __all__ += ["models", "utils", "workflows"]
 
 elif find_spec("tensorflow") is not None: # pragma: no cover
     from . import models
 
dataeval/_internal/detectors/clusterer.py CHANGED
@@ -1,25 +1,50 @@
+from dataclasses import dataclass
 from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform
 
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import flatten
+from dataeval._internal.output import OutputMetadata, set_metadata
 
 
-def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
+@dataclass(frozen=True)
+class ClustererOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    outliers : List[int]
+        Indices that do not fall within a cluster
+    potential_outliers : List[int]
+        Indices which are near the border between belonging in the cluster and being an outlier
+    duplicates : List[List[int]]
+        Groups of indices that are exact duplicates
+    potential_duplicates : List[List[int]]
+        Groups of indices which are not exact but closely related data points
+    """
+
+    outliers: List[int]
+    potential_outliers: List[int]
+    duplicates: List[List[int]]
+    potential_duplicates: List[List[int]]
+
+
+def extend_linkage(link_arr: NDArray) -> NDArray:
     """
     Adds a column to the linkage matrix link_arr that tracks the new id assigned
     to each row
 
     Parameters
     ----------
-    link_arr : np.ndarray
+    link_arr : NDArray
         linkage matrix
 
     Returns
     -------
-    np.ndarray
+    NDArray
         linkage matrix with adjusted shape, new shape (link_arr.shape[0], link_arr.shape[1]+1)
     """
     # Adjusting linkage matrix to accommodate renumbering
@@ -34,7 +59,7 @@ def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
 class Cluster:
     __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"
 
-    def __init__(self, merged: int, samples: np.ndarray, sample_dist: Union[float, np.ndarray], is_copy: bool = False):
+    def __init__(self, merged: int, samples: NDArray, sample_dist: Union[float, NDArray], is_copy: bool = False):
         self.merged = merged
         self.samples = np.array(samples, dtype=np.int32)
         self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -129,13 +154,13 @@ class Clusterer:
         self._on_init(dataset)
 
     def _on_init(self, dataset: ArrayLike):
-        self._data: np.ndarray = to_numpy(dataset)
+        self._data: NDArray = flatten(to_numpy(dataset))
         self._validate_data(self._data)
         self._num_samples = len(self._data)
 
-        self._darr: np.ndarray = pdist(self._data, metric="euclidean")
-        self._sqdmat: np.ndarray = squareform(self._darr)
-        self._larr: np.ndarray = extend_linkage(linkage(self._darr))
+        self._darr: NDArray = pdist(self._data, metric="euclidean")
+        self._sqdmat: NDArray = squareform(self._darr)
+        self._larr: NDArray = extend_linkage(linkage(self._darr))
         self._max_clusters: int = np.count_nonzero(self._larr[:, 3] == 2)
 
         min_num = int(self._num_samples * 0.05)
@@ -145,7 +170,7 @@ class Clusterer:
         self._last_good_merge_levels = None
 
     @property
-    def data(self) -> np.ndarray:
+    def data(self) -> NDArray:
         return self._data
 
     @data.setter
@@ -165,10 +190,10 @@ class Clusterer:
         return self._last_good_merge_levels
 
     @classmethod
-    def _validate_data(cls, x: np.ndarray):
+    def _validate_data(cls, x: NDArray):
         """Checks that the data has the correct size, shape, and format"""
         if not isinstance(x, np.ndarray):
-            raise TypeError(f"Data should be of type np.ndarray; got {type(x)}")
+            raise TypeError(f"Data should be of type NDArray; got {type(x)}")
 
         if x.ndim != 2:
             raise ValueError(
@@ -239,7 +264,7 @@ class Clusterer:
             clusters[level_id].setdefault(cid, cluster)
         return clusters
 
-    def _get_cluster_distances(self) -> np.ndarray:
+    def _get_cluster_distances(self) -> NDArray:
         """Calculates the minimum distances between clusters are each level"""
         # Cluster distance matrix
         max_level = self.clusters.max_level
@@ -260,7 +285,7 @@ class Clusterer:
 
         return cluster_matrix
 
-    def _calc_merge_indices(self, merge_mean: List[np.ndarray], intra_max: List[float]) -> np.ndarray:
+    def _calc_merge_indices(self, merge_mean: List[NDArray], intra_max: List[float]) -> NDArray:
         """
         Determine what clusters should be merged and return their indices
         """
@@ -283,7 +308,7 @@ class Clusterer:
         mask2 = mask2_vals < one_std_check
         return np.logical_or(desired_merge, mask2)
 
-    def _generate_merge_list(self, cluster_matrix: np.ndarray) -> List[ClusterMergeEntry]:
+    def _generate_merge_list(self, cluster_matrix: NDArray) -> List[ClusterMergeEntry]:
         """
         Runs through the clusters dictionary determining when clusters merge,
         and how close are those clusters when they merge.
@@ -463,35 +488,23 @@ class Clusterer:
 
         return exact_dupes, near_dupes
 
-    def evaluate(self):
+    # TODO: Move data input to evaluate from class
+    @set_metadata("dataeval.detectors", ["data"])
+    def evaluate(self) -> ClustererOutput:
         """Finds and flags indices of the data for outliers and duplicates
 
         Returns
         -------
-        Dict[str, List[int]]
-            outliers :
-                List of indices that do not fall within a cluster
-            potential_outliers :
-                List of indices which are near the border between belonging in the cluster and being an outlier
-            duplicates :
-                List of groups of indices that are exact duplicates
-            potential_duplicates :
-                List of groups of indices which are not exact but closely related data points
+        ClustererOutput
+            The outliers and duplicate indices found in the data
 
         Example
         -------
         >>> cluster.evaluate()
-        {'outliers': [18, 21, 34, 35, 45], 'potential_outliers': [13, 15, 42], 'duplicates': [[9, 24], [23, 48]], 'potential_duplicates': [[1, 11]]}
+        ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
         """ # noqa: E501
 
         outliers, potential_outliers = self.find_outliers(self.last_good_merge_levels)
         duplicates, potential_duplicates = self.find_duplicates(self.last_good_merge_levels)
 
-        ret = {
-            "outliers": outliers,
-            "potential_outliers": potential_outliers,
-            "duplicates": duplicates,
-            "potential_duplicates": potential_duplicates,
-        }
-
-        return ret
+        return ClustererOutput(outliers, potential_outliers, duplicates, potential_duplicates)
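
Note: `evaluate` now returns a frozen `ClustererOutput` dataclass rather than a dict. A minimal usage sketch of the new API, assuming `Clusterer` is re-exported under `dataeval.detectors` as the `set_metadata("dataeval.detectors", ...)` namespace suggests (the data below is illustrative):

    import numpy as np
    from dataeval.detectors import Clusterer

    embeddings = np.random.rand(50, 16)  # hypothetical dataset: 50 samples x 16 features
    output = Clusterer(embeddings).evaluate()
    # results are attributes instead of dict keys
    print(output.outliers, output.potential_outliers)
    print(output.duplicates, output.potential_duplicates)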
dataeval/_internal/detectors/drift/base.py CHANGED
@@ -7,12 +7,48 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, Dict, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class DriftOutput(OutputMetadata):
+    is_drift: bool
+    threshold: float
+
+
+@dataclass(frozen=True)
+class DriftUnivariateOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    feature_drift : NDArray[np.bool_]
+        Feature-level array of images detected to have drifted
+    feature_threshold : float
+        Feature-level threshold to determine drift
+    p_vals : NDArray[np.float32]
+        Feature-level p-values
+    distances : NDArray[np.float32]
+        Feature-level distances
+    """
+
+    # is_drift: bool
+    # threshold: float
+    feature_drift: NDArray[np.bool_]
+    feature_threshold: float
+    p_vals: NDArray[np.float32]
+    distances: NDArray[np.float32]
 
 
 def update_x_ref(fn):
@@ -51,7 +87,7 @@ class UpdateStrategy(ABC):
         self.n = n
 
     @abstractmethod
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         """Abstract implementation of update strategy"""
 
 
@@ -65,7 +101,7 @@ class LastSeenUpdate(UpdateStrategy):
     Update with last n instances seen by the detector.
     """
 
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         x_updated = np.concatenate([x_ref, x], axis=0)
         return x_updated[-self.n :]
 
@@ -80,7 +116,7 @@ class ReservoirSamplingUpdate(UpdateStrategy):
     Update with reservoir sampling of size n.
     """
 
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         if x.shape[0] + count <= self.n:
             return np.concatenate([x_ref, x], axis=0)
 
@@ -135,7 +171,7 @@ class BaseDrift:
         self._x_refcount = 0
 
     @property
-    def x_ref(self) -> np.ndarray:
+    def x_ref(self) -> NDArray:
         if not self.x_ref_preprocessed:
             self.x_ref_preprocessed = True
             if self.preprocess_fn is not None:
@@ -151,7 +187,7 @@ class BaseDrift:
         return x
 
 
-class BaseUnivariateDrift(BaseDrift):
+class BaseDriftUnivariate(BaseDrift):
     """
     Generic drift detector component which serves as a base class for methods using
     univariate tests. If n_features > 1, a multivariate correction is applied such
@@ -197,13 +233,13 @@ class BaseUnivariateDrift(BaseDrift):
 
     @preprocess_x
     @abstractmethod
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """Abstract method to calculate feature score after preprocessing"""
 
-    def _apply_correction(self, p_vals: np.ndarray) -> Tuple[int, float]:
+    def _apply_correction(self, p_vals: NDArray) -> Tuple[bool, float]:
         if self.correction == "bonferroni":
             threshold = self.p_val / self.n_features
-            drift_pred = int((p_vals < threshold).any())
+            drift_pred = bool((p_vals < threshold).any())
             return drift_pred, threshold
         elif self.correction == "fdr":
             n = p_vals.shape[0]
@@ -214,18 +250,18 @@ class BaseUnivariateDrift(BaseDrift):
             try:
                 idx_threshold = int(np.where(below_threshold)[0].max())
             except ValueError: # sorted p-values not below thresholds
-                return int(below_threshold.any()), q_threshold.min()
-            return int(below_threshold.any()), q_threshold[idx_threshold]
+                return bool(below_threshold.any()), q_threshold.min()
+            return bool(below_threshold.any()), q_threshold[idx_threshold]
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
 
+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
     def predict(
         self,
         x: ArrayLike,
-        drift_type: Literal["batch", "feature"] = "batch",
-    ) -> Dict[str, Union[int, float, np.ndarray]]:
+    ) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data and update
         reference data using specified update strategy.
@@ -234,10 +270,6 @@ class BaseUnivariateDrift(BaseDrift):
         ----------
         x : ArrayLike
            Batch of instances.
-        drift_type : Literal["batch", "feature"], default "batch"
-            Predict drift at the 'feature' or 'batch' level. For 'batch', the test
-            statistics for each feature are aggregated using the Bonferroni or False
-            Discovery Rate correction (if n_features>1).
 
         Returns
         -------
@@ -248,20 +280,6 @@ class BaseUnivariateDrift(BaseDrift):
         # compute drift scores
         p_vals, dist = self.score(x)
 
-        # TODO: return both feature-level and batch-level drift predictions by default
-        # values below p-value threshold are drift
-        if drift_type == "feature":
-            drift_pred = (p_vals < self.p_val).astype(int)
-            threshold = self.p_val
-        elif drift_type == "batch":
-            drift_pred, threshold = self._apply_correction(p_vals)
-        else:
-            raise ValueError("`drift_type` needs to be either `feature` or `batch`.")
-
-        # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_vals,
-            "threshold": threshold,
-            "distance": dist,
-        }
+        feature_drift = (p_vals < self.p_val).astype(np.bool_)
+        drift_pred, threshold = self._apply_correction(p_vals)
+        return DriftUnivariateOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
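
For reference, the two corrections implemented by `_apply_correction` can be sketched standalone against a plain array of per-feature p-values (the function below is illustrative, not part of the package):

    import numpy as np

    def apply_correction(p_vals, p_val=0.05, correction="bonferroni"):
        if correction == "bonferroni":
            # split the significance level evenly across all features
            threshold = p_val / len(p_vals)
            return bool((p_vals < threshold).any()), threshold
        # "fdr" (Benjamini-Hochberg): compare sorted p-values to a rising threshold line
        n = len(p_vals)
        q_threshold = p_val * np.arange(1, n + 1) / n
        below = np.sort(p_vals) < q_threshold
        if not below.any():
            return False, float(q_threshold.min())
        return True, float(q_threshold[int(np.where(below)[0].max())])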
dataeval/_internal/detectors/drift/cvm.py CHANGED
@@ -9,14 +9,15 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import cramervonmises_2samp
 
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 
-from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 
-class DriftCVM(BaseUnivariateDrift):
+class DriftCVM(BaseDriftUnivariate):
     """
     Cramér-von Mises (CVM) data drift detector, which tests for any change in the
     distribution of continuous univariate data. For multivariate data, a separate
@@ -75,7 +76,7 @@ class DriftCVM(BaseUnivariateDrift):
         )
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Performs the two-sample Cramér-von Mises test(s), computing the p-value and
         test statistic per feature.
dataeval/_internal/detectors/drift/ks.py CHANGED
@@ -9,21 +9,22 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import ks_2samp
 
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
 
-from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 
-class DriftKS(BaseUnivariateDrift):
+class DriftKS(BaseDriftUnivariate):
     """
     Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
     Rate (FDR) correction for multivariate data.
 
     Parameters
     ----------
-    x_ref : np.ndarray
+    x_ref : NDArray
         Data used as reference distribution.
     p_val : float, default 0.05
         p-value used for significance of the statistical test for each feature.
@@ -40,7 +41,7 @@ class DriftKS(BaseUnivariateDrift):
         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
         or via reservoir sampling with
         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
-    preprocess_fn : Optional[Callable[[np.ndarray], np.ndarray]], default None
+    preprocess_fn : Optional[Callable[[NDArray], NDArray]], default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a dimensionality reduction technique.
     correction : Literal["bonferroni", "fdr"], default "bonferroni"
@@ -80,7 +81,7 @@ class DriftKS(BaseUnivariateDrift):
         self.alternative = alternative
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Compute K-S scores and statistics per feature.
 
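
With the `drift_type` argument removed from the base class, `predict` now always returns both the corrected batch-level decision and the per-feature results in a single `DriftUnivariateOutput`. A usage sketch, assuming `DriftKS` is exposed from `dataeval.detectors` (the data is synthetic):

    import numpy as np
    from dataeval.detectors import DriftKS

    x_ref = np.random.randn(500, 8).astype(np.float32)        # reference batch
    x_new = np.random.randn(200, 8).astype(np.float32) + 0.5  # shifted test batch
    out = DriftKS(x_ref).predict(x_new)
    print(out.is_drift, out.threshold)    # batch-level decision after correction
    print(out.feature_drift, out.p_vals)  # per-feature results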
dataeval/_internal/detectors/drift/mmd.py CHANGED
@@ -6,16 +6,43 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
-from typing import Callable, Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple
 
 import torch
+from numpy.typing import ArrayLike
 
-from dataeval._internal.interop import ArrayLike, to_numpy
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import set_metadata
 
-from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+from .base import BaseDrift, DriftOutput, UpdateStrategy, preprocess_x, update_x_ref
 from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
 
 
+@dataclass(frozen=True)
+class DriftMMDOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        P-value used for significance of the permutation test
+    p_val : float
+        P-value obtained from the permutation test
+    distance : float
+        MMD^2 between the reference and test set
+    distance_threshold : float
+        MMD^2 threshold above which drift is flagged
+    """
+
+    # is_drift: bool
+    # threshold: float
+    p_val: float
+    distance: float
+    distance_threshold: float
+
+
 class DriftMMD(BaseDrift):
     """
     Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
@@ -74,7 +101,7 @@ class DriftMMD(BaseDrift):
         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
 
         self.infer_sigma = configure_kernel_from_x_ref
-        if configure_kernel_from_x_ref and isinstance(sigma, ArrayLike):
+        if configure_kernel_from_x_ref and sigma is not None:
             self.infer_sigma = False
 
         self.n_permutations = n_permutations # nb of iterations through permutation test
@@ -83,7 +110,7 @@ class DriftMMD(BaseDrift):
         self.device = get_device(device)
 
         # initialize kernel
-        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if isinstance(sigma, ArrayLike) else None
+        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if sigma is not None else None
         self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel
 
         # compute kernel matrix for the reference data
@@ -128,19 +155,17 @@ class DriftMMD(BaseDrift):
         mmd2_permuted = torch.Tensor(
             [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
         )
-        mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
+        mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
         # compute distance threshold
         idx_threshold = int(self.p_val * len(mmd2_permuted))
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
 
+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
-    def predict(
-        self,
-        x: ArrayLike,
-    ) -> Dict[str, Union[int, float]]:
+    def predict(self, x: ArrayLike) -> DriftMMDOutput:
         """
         Predict whether a batch of data has drifted from the reference data and then
         updates reference data using specified strategy.
@@ -152,17 +177,12 @@ class DriftMMD(BaseDrift):
 
         Returns
         -------
-        Dictionary containing the drift prediction, p-value, threshold and MMD metric.
+        DriftMMDOutput
+            Output class containing the drift prediction, p-value, threshold and MMD metric.
         """
         # compute drift scores
         p_val, dist, distance_threshold = self.score(x)
-        drift_pred = int(p_val < self.p_val)
+        drift_pred = bool(p_val < self.p_val)
 
         # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_val,
-            "threshold": self.p_val,
-            "distance": dist,
-            "distance_threshold": distance_threshold,
-        }
+        return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
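
For intuition, the permutation test behind `score` and `predict` can be sketched without the torch kernel machinery: pool both samples, recompute the statistic on random splits, and take the fraction of permuted statistics that reach the observed value as the p-value. All names below, and the stand-in statistic, are illustrative:

    import numpy as np

    rng = np.random.default_rng(0)

    def permutation_p_value(stat_fn, x_ref, x, n_permutations=100):
        observed = stat_fn(x_ref, x)
        pooled = np.concatenate([x_ref, x])
        n = len(x_ref)
        permuted = []
        for _ in range(n_permutations):
            rng.shuffle(pooled)
            permuted.append(stat_fn(pooled[:n], pooled[n:]))
        # fraction of permuted statistics at least as large as the observed one
        return float(np.mean(observed <= np.array(permuted)))

    stat = lambda a, b: abs(a.mean() - b.mean())  # stand-in for MMD^2
    p = permutation_p_value(stat, rng.normal(0, 1, 200), rng.normal(0.5, 1, 200))
    is_drift = p < 0.05  # mirrors the p_val comparison in DriftMMD.predict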
dataeval/_internal/detectors/drift/torch.py CHANGED
@@ -12,6 +12,7 @@ from typing import Callable, Optional, Type, Union
 import numpy as np
 import torch
 import torch.nn as nn
+from numpy.typing import NDArray
 
 
 def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
@@ -74,13 +75,13 @@ def mmd2_from_kernel_matrix(
 
 
 def predict_batch(
-    x: Union[np.ndarray, torch.Tensor],
+    x: Union[NDArray, torch.Tensor],
     model: Union[Callable, nn.Module, nn.Sequential],
     device: Optional[torch.device] = None,
     batch_size: int = int(1e10),
     preprocess_fn: Optional[Callable] = None,
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[np.ndarray, torch.Tensor, tuple]:
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Make batch predictions on a model.
 
@@ -138,7 +139,7 @@ def predict_batch(
     else:
         raise TypeError(
             f"Model output type {type(preds_tmp)} not supported. The model \
-            output type needs to be one of list, tuple, np.ndarray or \
+            output type needs to be one of list, tuple, NDArray or \
             torch.Tensor."
         )
     concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
@@ -149,13 +150,13 @@ def predict_batch(
 
 
 def preprocess_drift(
-    x: np.ndarray,
+    x: NDArray,
     model: nn.Module,
     device: Optional[torch.device] = None,
     preprocess_batch_fn: Optional[Callable] = None,
     batch_size: int = int(1e10),
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[np.ndarray, torch.Tensor, tuple]:
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Prediction function used for preprocessing step of drift detector.
 
dataeval/_internal/detectors/drift/uncertainty.py CHANGED
@@ -7,24 +7,23 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from functools import partial
-from typing import Callable, Dict, Literal, Optional, Union
+from typing import Callable, Literal, Optional
 
 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.special import softmax
 from scipy.stats import entropy
 
-from dataeval._internal.interop import ArrayLike
-
-from .base import UpdateStrategy
+from .base import DriftUnivariateOutput, UpdateStrategy
 from .ks import DriftKS
 from .torch import get_device, preprocess_drift
 
 
 def classifier_uncertainty(
-    x: np.ndarray,
+    x: NDArray,
     model_fn: Callable,
     preds_type: Literal["probs", "logits"] = "probs",
-) -> np.ndarray:
+) -> NDArray:
     """
     Evaluate model_fn on x and transform predictions to prediction uncertainties.
 
@@ -112,7 +111,7 @@ class DriftUncertainty:
         preprocess_batch_fn: Optional[Callable] = None,
         device: Optional[str] = None,
     ) -> None:
-        def model_fn(x: np.ndarray) -> np.ndarray:
+        def model_fn(x: NDArray) -> NDArray:
             return preprocess_drift(
                 x,
                 model, # type: ignore
@@ -135,7 +134,7 @@ class DriftUncertainty:
             preprocess_fn=preprocess_fn, # type: ignore
         )
 
-    def predict(self, x: ArrayLike) -> Dict[str, Union[int, float, np.ndarray]]:
+    def predict(self, x: ArrayLike) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data.