dataeval 0.64.0__py3-none-any.whl → 0.65.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +2 -2
- dataeval/_internal/detectors/clusterer.py +46 -34
- dataeval/_internal/detectors/drift/base.py +52 -35
- dataeval/_internal/detectors/drift/cvm.py +4 -4
- dataeval/_internal/detectors/drift/ks.py +6 -6
- dataeval/_internal/detectors/drift/mmd.py +35 -16
- dataeval/_internal/detectors/drift/torch.py +6 -5
- dataeval/_internal/detectors/drift/uncertainty.py +7 -7
- dataeval/_internal/detectors/duplicates.py +55 -29
- dataeval/_internal/detectors/linter.py +40 -24
- dataeval/_internal/detectors/ood/base.py +36 -15
- dataeval/_internal/detectors/ood/llr.py +7 -7
- dataeval/_internal/flags.py +42 -21
- dataeval/_internal/interop.py +2 -2
- dataeval/_internal/metrics/balance.py +10 -2
- dataeval/_internal/metrics/ber.py +6 -5
- dataeval/_internal/metrics/coverage.py +15 -8
- dataeval/_internal/metrics/divergence.py +41 -7
- dataeval/_internal/metrics/diversity.py +17 -12
- dataeval/_internal/metrics/parity.py +30 -43
- dataeval/_internal/metrics/stats.py +196 -317
- dataeval/_internal/metrics/uap.py +5 -2
- dataeval/_internal/metrics/utils.py +70 -33
- dataeval/_internal/models/tensorflow/losses.py +3 -3
- dataeval/_internal/models/tensorflow/trainer.py +3 -2
- dataeval/_internal/models/tensorflow/utils.py +4 -3
- dataeval/_internal/output.py +82 -0
- dataeval/_internal/workflows/sufficiency.py +96 -107
- dataeval/flags/__init__.py +2 -2
- dataeval/metrics/__init__.py +3 -3
- {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
- dataeval-0.65.0.dist-info/RECORD +60 -0
- dataeval/_internal/metrics/base.py +0 -10
- dataeval-0.64.0.dist-info/RECORD +0 -60
- {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.64.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
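Most of the changes below follow from the new dataeval/_internal/output.py (+82 lines): typed, frozen dataclass outputs built on an OutputMetadata base class and a set_metadata decorator replace the plain dicts that evaluate()/predict() returned in 0.64.0. output.py itself is not shown in this diff, so the following Python sketch is only an inference from how the two symbols are used in the hunks below; the decorator body and signature are assumptions, not the actual implementation:

    from dataclasses import dataclass
    from functools import wraps
    from typing import List, Optional


    @dataclass(frozen=True)
    class OutputMetadata:
        # Assumed base class: concrete outputs such as ClustererOutput and
        # DriftOutput subclass it and declare their own typed fields.
        pass


    def set_metadata(module: str, state_attrs: Optional[List[str]] = None):
        # Assumed decorator: tags the returned output with the producing module
        # (e.g. "dataeval.detectors") and, optionally, the names of instance
        # attributes that fed the computation (e.g. ["data"]).
        def decorator(fn):
            @wraps(fn)
            def wrapper(self, *args, **kwargs):
                output = fn(self, *args, **kwargs)
                # the real implementation presumably records metadata here
                return output

            return wrapper

        return decorator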
dataeval/_internal/detectors/clusterer.py
CHANGED
@@ -1,26 +1,50 @@
+from dataclasses import dataclass
 from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import flatten
+from dataeval._internal.output import OutputMetadata, set_metadata
 
 
-def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
+@dataclass(frozen=True)
+class ClustererOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    outliers : List[int]
+        Indices that do not fall within a cluster
+    potential_outliers : List[int]
+        Indices which are near the border between belonging in the cluster and being an outlier
+    duplicates : List[List[int]]
+        Groups of indices that are exact duplicates
+    potential_duplicates : List[List[int]]
+        Groups of indices which are not exact but closely related data points
+    """
+
+    outliers: List[int]
+    potential_outliers: List[int]
+    duplicates: List[List[int]]
+    potential_duplicates: List[List[int]]
+
+
+def extend_linkage(link_arr: NDArray) -> NDArray:
     """
     Adds a column to the linkage matrix link_arr that tracks the new id assigned
     to each row
 
     Parameters
     ----------
-    link_arr :
+    link_arr : NDArray
         linkage matrix
 
     Returns
     -------
-
+    NDArray
         linkage matrix with adjusted shape, new shape (link_arr.shape[0], link_arr.shape[1]+1)
     """
     # Adjusting linkage matrix to accommodate renumbering
@@ -35,7 +59,7 @@ def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
 class Cluster:
     __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"
 
-    def __init__(self, merged: int, samples: np.ndarray, sample_dist: Union[float, np.ndarray], is_copy: bool = False):
+    def __init__(self, merged: int, samples: NDArray, sample_dist: Union[float, NDArray], is_copy: bool = False):
         self.merged = merged
         self.samples = np.array(samples, dtype=np.int32)
         self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -130,13 +154,13 @@ class Clusterer:
         self._on_init(dataset)
 
     def _on_init(self, dataset: ArrayLike):
-        self._data: np.ndarray = to_numpy(dataset)
+        self._data: NDArray = flatten(to_numpy(dataset))
         self._validate_data(self._data)
         self._num_samples = len(self._data)
 
-        self._darr: np.ndarray = pdist(self._data, metric="euclidean")
-        self._sqdmat: np.ndarray = squareform(self._darr)
-        self._larr: np.ndarray = extend_linkage(linkage(self._darr))
+        self._darr: NDArray = pdist(self._data, metric="euclidean")
+        self._sqdmat: NDArray = squareform(self._darr)
+        self._larr: NDArray = extend_linkage(linkage(self._darr))
         self._max_clusters: int = np.count_nonzero(self._larr[:, 3] == 2)
 
         min_num = int(self._num_samples * 0.05)
@@ -146,7 +170,7 @@ class Clusterer:
         self._last_good_merge_levels = None
 
     @property
-    def data(self) -> np.ndarray:
+    def data(self) -> NDArray:
         return self._data
 
     @data.setter
@@ -166,10 +190,10 @@ class Clusterer:
         return self._last_good_merge_levels
 
     @classmethod
-    def _validate_data(cls, x: np.ndarray):
+    def _validate_data(cls, x: NDArray):
         """Checks that the data has the correct size, shape, and format"""
         if not isinstance(x, np.ndarray):
-            raise TypeError(f"Data should be of type np.ndarray; got {type(x)}")
+            raise TypeError(f"Data should be of type NDArray; got {type(x)}")
 
         if x.ndim != 2:
             raise ValueError(
@@ -240,7 +264,7 @@ class Clusterer:
             clusters[level_id].setdefault(cid, cluster)
         return clusters
 
-    def _get_cluster_distances(self) -> np.ndarray:
+    def _get_cluster_distances(self) -> NDArray:
         """Calculates the minimum distances between clusters are each level"""
         # Cluster distance matrix
         max_level = self.clusters.max_level
@@ -261,7 +285,7 @@ class Clusterer:
 
         return cluster_matrix
 
-    def _calc_merge_indices(self, merge_mean: List[np.ndarray], intra_max: List[float]) -> np.ndarray:
+    def _calc_merge_indices(self, merge_mean: List[NDArray], intra_max: List[float]) -> NDArray:
         """
         Determine what clusters should be merged and return their indices
         """
@@ -284,7 +308,7 @@ class Clusterer:
         mask2 = mask2_vals < one_std_check
         return np.logical_or(desired_merge, mask2)
 
-    def _generate_merge_list(self, cluster_matrix: np.ndarray) -> List[ClusterMergeEntry]:
+    def _generate_merge_list(self, cluster_matrix: NDArray) -> List[ClusterMergeEntry]:
         """
         Runs through the clusters dictionary determining when clusters merge,
         and how close are those clusters when they merge.
@@ -464,35 +488,23 @@ class Clusterer:
 
         return exact_dupes, near_dupes
 
-    def evaluate(self):
+    # TODO: Move data input to evaluate from class
+    @set_metadata("dataeval.detectors", ["data"])
+    def evaluate(self) -> ClustererOutput:
         """Finds and flags indices of the data for outliers and duplicates
 
         Returns
         -------
-
-        outliers :
-            List of indices that do not fall within a cluster
-        potential_outliers :
-            List of indices which are near the border between belonging in the cluster and being an outlier
-        duplicates :
-            List of groups of indices that are exact duplicates
-        potential_duplicates :
-            List of groups of indices which are not exact but closely related data points
+        ClustererOutput
+            The outliers and duplicate indices found in the data
 
         Example
         -------
         >>> cluster.evaluate()
-        {'outliers': [18, 21, 34, 35, 45], 'potential_outliers': [13, 15, 42], 'duplicates': [[9, 24], [23, 48]], 'potential_duplicates': [[1, 11]]}
+        ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
         """  # noqa: E501
 
         outliers, potential_outliers = self.find_outliers(self.last_good_merge_levels)
         duplicates, potential_duplicates = self.find_duplicates(self.last_good_merge_levels)
 
-        ret = {
-            "outliers": outliers,
-            "potential_outliers": potential_outliers,
-            "duplicates": duplicates,
-            "potential_duplicates": potential_duplicates,
-        }
-
-        return ret
+        return ClustererOutput(outliers, potential_outliers, duplicates, potential_duplicates)
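For callers, the visible effect of the clusterer refactor is that Clusterer.evaluate() now returns a frozen ClustererOutput dataclass instead of a dict, so key lookups become attribute access. A minimal migration sketch (the embedding array is illustrative):

    import numpy as np
    from dataeval._internal.detectors.clusterer import Clusterer

    embeddings = np.random.default_rng(0).random((50, 16))  # any 2D float array
    results = Clusterer(embeddings).evaluate()

    # 0.64.0: results["outliers"], results["duplicates"]
    print(results.outliers)     # indices outside any cluster
    print(results.duplicates)   # groups of exact-duplicate indices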
dataeval/_internal/detectors/drift/base.py
CHANGED
@@ -7,13 +7,48 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from functools import wraps
-from typing import Callable, Dict, Literal, Optional, Tuple, Union
+from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class DriftOutput(OutputMetadata):
+    is_drift: bool
+    threshold: float
+
+
+@dataclass(frozen=True)
+class DriftUnivariateOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    feature_drift : NDArray[np.bool_]
+        Feature-level array of images detected to have drifted
+    feature_threshold : float
+        Feature-level threshold to determine drift
+    p_vals : NDArray[np.float32]
+        Feature-level p-values
+    distances : NDArray[np.float32]
+        Feature-level distances
+    """
+
+    # is_drift: bool
+    # threshold: float
+    feature_drift: NDArray[np.bool_]
+    feature_threshold: float
+    p_vals: NDArray[np.float32]
+    distances: NDArray[np.float32]
 
 
 def update_x_ref(fn):
@@ -52,7 +87,7 @@ class UpdateStrategy(ABC):
         self.n = n
 
     @abstractmethod
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         """Abstract implementation of update strategy"""
 
 
@@ -66,7 +101,7 @@ class LastSeenUpdate(UpdateStrategy):
     Update with last n instances seen by the detector.
     """
 
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         x_updated = np.concatenate([x_ref, x], axis=0)
         return x_updated[-self.n :]
 
@@ -81,7 +116,7 @@ class ReservoirSamplingUpdate(UpdateStrategy):
     Update with reservoir sampling of size n.
     """
 
-    def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         if x.shape[0] + count <= self.n:
             return np.concatenate([x_ref, x], axis=0)
 
@@ -136,7 +171,7 @@ class BaseDrift:
         self._x_refcount = 0
 
     @property
-    def x_ref(self) -> np.ndarray:
+    def x_ref(self) -> NDArray:
         if not self.x_ref_preprocessed:
             self.x_ref_preprocessed = True
             if self.preprocess_fn is not None:
@@ -152,7 +187,7 @@ class BaseDrift:
         return x
 
 
-class BaseUnivariateDrift(BaseDrift):
+class BaseDriftUnivariate(BaseDrift):
     """
     Generic drift detector component which serves as a base class for methods using
     univariate tests. If n_features > 1, a multivariate correction is applied such
@@ -198,13 +233,13 @@ class BaseUnivariateDrift(BaseDrift):
 
     @preprocess_x
    @abstractmethod
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """Abstract method to calculate feature score after preprocessing"""
 
-    def _apply_correction(self, p_vals: np.ndarray) -> Tuple[int, float]:
+    def _apply_correction(self, p_vals: NDArray) -> Tuple[bool, float]:
         if self.correction == "bonferroni":
             threshold = self.p_val / self.n_features
-            drift_pred = int((p_vals < threshold).any())
+            drift_pred = bool((p_vals < threshold).any())
             return drift_pred, threshold
         elif self.correction == "fdr":
             n = p_vals.shape[0]
@@ -215,18 +250,18 @@ class BaseUnivariateDrift(BaseDrift):
             try:
                 idx_threshold = int(np.where(below_threshold)[0].max())
             except ValueError:  # sorted p-values not below thresholds
-                return int(below_threshold.any()), q_threshold.min()
-            return int(below_threshold.any()), q_threshold[idx_threshold]
+                return bool(below_threshold.any()), q_threshold.min()
+            return bool(below_threshold.any()), q_threshold[idx_threshold]
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
 
+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
     def predict(
         self,
         x: ArrayLike,
-        drift_type: Literal["batch", "feature"] = "batch",
-    ) -> Dict[str, Union[int, float, np.ndarray]]:
+    ) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data and update
         reference data using specified update strategy.
@@ -235,10 +270,6 @@ class BaseUnivariateDrift(BaseDrift):
         ----------
         x : ArrayLike
             Batch of instances.
-        drift_type : Literal["batch", "feature"], default "batch"
-            Predict drift at the 'feature' or 'batch' level. For 'batch', the test
-            statistics for each feature are aggregated using the Bonferroni or False
-            Discovery Rate correction (if n_features>1).
 
         Returns
         -------
@@ -249,20 +280,6 @@ class BaseUnivariateDrift(BaseDrift):
         # compute drift scores
         p_vals, dist = self.score(x)
 
-        # TODO: return both feature and batch level drift predictions
-        # values below p-value threshold are drift
-        if drift_type == "feature":
-            drift_pred = (p_vals < self.p_val).astype(int)
-            threshold = self.p_val
-        elif drift_type == "batch":
-            drift_pred, threshold = self._apply_correction(p_vals)
-        else:
-            raise ValueError("`drift_type` needs to be either `feature` or `batch`.")
-
-        # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_vals,
-            "threshold": threshold,
-            "distance": dist,
-        }
+        feature_drift = (p_vals < self.p_val).astype(np.bool_)
+        drift_pred, threshold = self._apply_correction(p_vals)
+        return DriftUnivariateOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
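With the drift_type argument gone, predict() always applies the multivariate correction for the batch-level decision and returns the per-feature results alongside it. The two corrections kept by _apply_correction are standard; here is a self-contained sketch with made-up p-values (plain NumPy, mirroring the hunk's variable names, not the dataeval API):

    import numpy as np

    p_val = 0.05  # significance level
    p_vals = np.sort(np.array([0.001, 0.03, 0.2], dtype=np.float32))  # per-feature p-values
    n_features = p_vals.shape[0]

    # Bonferroni: divide the significance level by the number of features
    bonferroni_drift = bool((p_vals < p_val / n_features).any())

    # FDR (Benjamini-Hochberg): compare sorted p-values to a rising threshold
    q_threshold = p_val * (np.arange(n_features) + 1) / n_features
    fdr_drift = bool((p_vals < q_threshold).any())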
dataeval/_internal/detectors/drift/cvm.py
CHANGED
@@ -9,15 +9,15 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import cramervonmises_2samp
 
 from dataeval._internal.interop import to_numpy
 
-from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 
-class DriftCVM(BaseUnivariateDrift):
+class DriftCVM(BaseDriftUnivariate):
     """
     Cramér-von Mises (CVM) data drift detector, which tests for any change in the
     distribution of continuous univariate data. For multivariate data, a separate
@@ -76,7 +76,7 @@ class DriftCVM(BaseUnivariateDrift):
         )
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Performs the two-sample Cramér-von Mises test(s), computing the p-value and
         test statistic per feature.

dataeval/_internal/detectors/drift/ks.py
CHANGED
@@ -9,22 +9,22 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import ks_2samp
 
 from dataeval._internal.interop import to_numpy
 
-from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
 
 
-class DriftKS(BaseUnivariateDrift):
+class DriftKS(BaseDriftUnivariate):
     """
     Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
     Rate (FDR) correction for multivariate data.
 
     Parameters
     ----------
-    x_ref : np.ndarray
+    x_ref : NDArray
         Data used as reference distribution.
     p_val : float, default 0.05
         p-value used for significance of the statistical test for each feature.
@@ -41,7 +41,7 @@ class DriftKS(BaseUnivariateDrift):
         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
         or via reservoir sampling with
         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
-    preprocess_fn : Optional[Callable[[np.ndarray], np.ndarray]], default None
+    preprocess_fn : Optional[Callable[[NDArray], NDArray]], default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a dimensionality reduction technique.
     correction : Literal["bonferroni", "fdr"], default "bonferroni"
@@ -81,7 +81,7 @@ class DriftKS(BaseUnivariateDrift):
         self.alternative = alternative
 
     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.ndarray, np.ndarray]:
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Compute K-S scores and statistics per feature.
 
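Since DriftKS and DriftCVM only re-point their base class, the behavioral change callers see is the return type of predict(). A hypothetical before/after sketch (the reference and test batches are illustrative):

    import numpy as np
    from dataeval._internal.detectors.drift.ks import DriftKS

    rng = np.random.default_rng(0)
    x_ref = rng.normal(size=(200, 16)).astype(np.float32)
    x_test = rng.normal(loc=0.5, size=(100, 16)).astype(np.float32)

    output = DriftKS(x_ref).predict(x_test)

    # 0.64.0: predict(x, drift_type="batch")["is_drift"]
    print(output.is_drift, output.threshold)    # batch-level decision after correction
    # 0.64.0: predict(x, drift_type="feature")["is_drift"]
    print(output.feature_drift, output.p_vals)  # per-feature results, now always included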
dataeval/_internal/detectors/drift/mmd.py
CHANGED
@@ -6,17 +6,43 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """
 
-from typing import Callable, Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple
 
 import torch
 from numpy.typing import ArrayLike
 
 from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import set_metadata
 
-from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+from .base import BaseDrift, DriftOutput, UpdateStrategy, preprocess_x, update_x_ref
 from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
 
 
+@dataclass(frozen=True)
+class DriftMMDOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        P-value used for significance of the permutation test
+    p_val : float
+        P-value obtained from the permutation test
+    distance : float
+        MMD^2 between the reference and test set
+    distance_threshold : float
+        MMD^2 threshold above which drift is flagged
+    """
+
+    # is_drift: bool
+    # threshold: float
+    p_val: float
+    distance: float
+    distance_threshold: float
+
+
 class DriftMMD(BaseDrift):
     """
     Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
@@ -129,19 +155,17 @@ class DriftMMD(BaseDrift):
         mmd2_permuted = torch.Tensor(
             [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
         )
-        mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
+        mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
         # compute distance threshold
         idx_threshold = int(self.p_val * len(mmd2_permuted))
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
 
+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
-    def predict(
-        self,
-        x: ArrayLike,
-    ) -> Dict[str, Union[int, float]]:
+    def predict(self, x: ArrayLike) -> DriftMMDOutput:
         """
         Predict whether a batch of data has drifted from the reference data and then
         updates reference data using specified strategy.
@@ -153,17 +177,12 @@ class DriftMMD(BaseDrift):
 
         Returns
         -------
-        Dictionary containing the drift prediction, p-value, threshold and MMD metric.
+        DriftMMDOutput
+            Output class containing the drift prediction, p-value, threshold and MMD metric.
         """
         # compute drift scores
         p_val, dist, distance_threshold = self.score(x)
-        drift_pred = int(p_val < self.p_val)
+        drift_pred = bool(p_val < self.p_val)
 
         # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_val,
-            "threshold": self.p_val,
-            "distance": dist,
-            "distance_threshold": distance_threshold,
-        }
+        return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
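DriftMMD.predict() follows the same pattern with its own DriftMMDOutput. An illustrative sketch (data and shapes are hypothetical; MMD runs on torch, so the runtime is device-dependent):

    import numpy as np
    from dataeval._internal.detectors.drift.mmd import DriftMMD

    rng = np.random.default_rng(0)
    x_ref = rng.normal(size=(128, 8)).astype(np.float32)
    x_test = rng.normal(loc=1.0, size=(64, 8)).astype(np.float32)

    output = DriftMMD(x_ref).predict(x_test)

    # 0.64.0: output["is_drift"], output["distance_threshold"]
    if output.is_drift:
        print(f"MMD^2 {output.distance:.4f} over threshold {output.distance_threshold:.4f}")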
dataeval/_internal/detectors/drift/torch.py
CHANGED
@@ -12,6 +12,7 @@ from typing import Callable, Optional, Type, Union
 import numpy as np
 import torch
 import torch.nn as nn
+from numpy.typing import NDArray
 
 
 def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
@@ -74,13 +75,13 @@ def mmd2_from_kernel_matrix(
 
 
 def predict_batch(
-    x: Union[np.ndarray, torch.Tensor],
+    x: Union[NDArray, torch.Tensor],
     model: Union[Callable, nn.Module, nn.Sequential],
     device: Optional[torch.device] = None,
     batch_size: int = int(1e10),
     preprocess_fn: Optional[Callable] = None,
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[np.ndarray, torch.Tensor, tuple]:
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Make batch predictions on a model.
 
@@ -138,7 +139,7 @@ def predict_batch(
     else:
         raise TypeError(
             f"Model output type {type(preds_tmp)} not supported. The model \
-            output type needs to be one of list, tuple, np.ndarray or \
+            output type needs to be one of list, tuple, NDArray or \
             torch.Tensor."
         )
     concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
@@ -149,13 +150,13 @@ def predict_batch(
 
 
 def preprocess_drift(
-    x: np.ndarray,
+    x: NDArray,
     model: nn.Module,
     device: Optional[torch.device] = None,
     preprocess_batch_fn: Optional[Callable] = None,
     batch_size: int = int(1e10),
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[np.ndarray, torch.Tensor, tuple]:
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Prediction function used for preprocessing step of drift detector.
 
dataeval/_internal/detectors/drift/uncertainty.py
CHANGED
@@ -7,23 +7,23 @@ Licensed under Apache Software License (Apache 2.0)
 """
 
 from functools import partial
-from typing import Callable, Dict, Literal, Optional, Union
+from typing import Callable, Literal, Optional
 
 import numpy as np
-from numpy.typing import ArrayLike
+from numpy.typing import ArrayLike, NDArray
 from scipy.special import softmax
 from scipy.stats import entropy
 
-from .base import UpdateStrategy
+from .base import DriftUnivariateOutput, UpdateStrategy
 from .ks import DriftKS
 from .torch import get_device, preprocess_drift
 
 
 def classifier_uncertainty(
-    x: np.ndarray,
+    x: NDArray,
     model_fn: Callable,
     preds_type: Literal["probs", "logits"] = "probs",
-) -> np.ndarray:
+) -> NDArray:
     """
     Evaluate model_fn on x and transform predictions to prediction uncertainties.
 
@@ -111,7 +111,7 @@ class DriftUncertainty:
         preprocess_batch_fn: Optional[Callable] = None,
         device: Optional[str] = None,
     ) -> None:
-        def model_fn(x: np.ndarray) -> np.ndarray:
+        def model_fn(x: NDArray) -> NDArray:
             return preprocess_drift(
                 x,
                 model,  # type: ignore
@@ -134,7 +134,7 @@ class DriftUncertainty:
             preprocess_fn=preprocess_fn,  # type: ignore
         )
 
-    def predict(self, x: ArrayLike) -> Dict[str, Union[int, float, np.ndarray]]:
+    def predict(self, x: ArrayLike) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data.
 