dataeval 0.63.0__py3-none-any.whl → 0.65.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/_internal/detectors/clusterer.py +47 -34
- dataeval/_internal/detectors/drift/base.py +53 -35
- dataeval/_internal/detectors/drift/cvm.py +5 -4
- dataeval/_internal/detectors/drift/ks.py +7 -6
- dataeval/_internal/detectors/drift/mmd.py +39 -19
- dataeval/_internal/detectors/drift/torch.py +6 -5
- dataeval/_internal/detectors/drift/uncertainty.py +7 -8
- dataeval/_internal/detectors/duplicates.py +57 -30
- dataeval/_internal/detectors/linter.py +40 -24
- dataeval/_internal/detectors/ood/ae.py +2 -1
- dataeval/_internal/detectors/ood/aegmm.py +2 -1
- dataeval/_internal/detectors/ood/base.py +37 -15
- dataeval/_internal/detectors/ood/llr.py +9 -8
- dataeval/_internal/detectors/ood/vae.py +2 -1
- dataeval/_internal/detectors/ood/vaegmm.py +2 -1
- dataeval/_internal/flags.py +42 -21
- dataeval/_internal/interop.py +3 -12
- dataeval/_internal/metrics/balance.py +188 -0
- dataeval/_internal/metrics/ber.py +123 -48
- dataeval/_internal/metrics/coverage.py +90 -74
- dataeval/_internal/metrics/divergence.py +101 -67
- dataeval/_internal/metrics/diversity.py +211 -0
- dataeval/_internal/metrics/parity.py +287 -155
- dataeval/_internal/metrics/stats.py +198 -317
- dataeval/_internal/metrics/uap.py +40 -29
- dataeval/_internal/metrics/utils.py +430 -0
- dataeval/_internal/models/tensorflow/losses.py +3 -3
- dataeval/_internal/models/tensorflow/trainer.py +3 -2
- dataeval/_internal/models/tensorflow/utils.py +4 -3
- dataeval/_internal/output.py +82 -0
- dataeval/_internal/utils.py +64 -0
- dataeval/_internal/workflows/sufficiency.py +96 -107
- dataeval/flags/__init__.py +2 -2
- dataeval/metrics/__init__.py +26 -7
- dataeval/utils/__init__.py +9 -0
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/METADATA +1 -1
- dataeval-0.65.0.dist-info/RECORD +60 -0
- dataeval/_internal/functional/__init__.py +0 -0
- dataeval/_internal/functional/ber.py +0 -63
- dataeval/_internal/functional/coverage.py +0 -75
- dataeval/_internal/functional/divergence.py +0 -16
- dataeval/_internal/functional/hash.py +0 -79
- dataeval/_internal/functional/metadata.py +0 -136
- dataeval/_internal/functional/metadataparity.py +0 -190
- dataeval/_internal/functional/uap.py +0 -6
- dataeval/_internal/functional/utils.py +0 -158
- dataeval/_internal/maite/__init__.py +0 -0
- dataeval/_internal/maite/utils.py +0 -30
- dataeval/_internal/metrics/base.py +0 -92
- dataeval/_internal/metrics/metadata.py +0 -610
- dataeval/_internal/metrics/metadataparity.py +0 -67
- dataeval-0.63.0.dist-info/RECORD +0 -68
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.63.0.dist-info → dataeval-0.65.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -1,15 +1,15 @@
+__version__ = "0.65.0"
+
 from importlib.util import find_spec

 from . import detectors, flags, metrics

-__version__ = "0.63.0"
-
 __all__ = ["detectors", "flags", "metrics"]

 if find_spec("torch") is not None:  # pragma: no cover
-    from . import models, workflows
+    from . import models, utils, workflows

-    __all__ += ["models", "workflows"]
+    __all__ += ["models", "utils", "workflows"]
 elif find_spec("tensorflow") is not None:  # pragma: no cover
     from . import models

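With 0.65.0 the version string moves to the top of the module and the torch-gated exports grow a `utils` module. A minimal consumer-side sketch; the fallback handling is illustrative, not part of the package:

import dataeval

print(dataeval.__version__)  # "0.65.0" for this release

# models/utils/workflows are only exported when torch is installed;
# only models is exported when just tensorflow is installed.
if hasattr(dataeval, "utils"):
    from dataeval import utils, workflows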
dataeval/_internal/detectors/clusterer.py
CHANGED
@@ -1,25 +1,50 @@
+from dataclasses import dataclass
 from typing import Dict, Iterable, List, NamedTuple, Tuple, Union, cast

 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.cluster.hierarchy import linkage
 from scipy.spatial.distance import pdist, squareform

-from dataeval._internal.interop import
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.metrics.utils import flatten
+from dataeval._internal.output import OutputMetadata, set_metadata


-def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
+@dataclass(frozen=True)
+class ClustererOutput(OutputMetadata):
+    """
+    Attributes
+    ----------
+    outliers : List[int]
+        Indices that do not fall within a cluster
+    potential_outliers : List[int]
+        Indices which are near the border between belonging in the cluster and being an outlier
+    duplicates : List[List[int]]
+        Groups of indices that are exact duplicates
+    potential_duplicates : List[List[int]]
+        Groups of indices which are not exact but closely related data points
+    """
+
+    outliers: List[int]
+    potential_outliers: List[int]
+    duplicates: List[List[int]]
+    potential_duplicates: List[List[int]]
+
+
+def extend_linkage(link_arr: NDArray) -> NDArray:
     """
     Adds a column to the linkage matrix link_arr that tracks the new id assigned
     to each row

     Parameters
     ----------
-    link_arr :
+    link_arr : NDArray
         linkage matrix

     Returns
     -------
-
+    NDArray
         linkage matrix with adjusted shape, new shape (link_arr.shape[0], link_arr.shape[1]+1)
     """
     # Adjusting linkage matrix to accommodate renumbering
@@ -34,7 +59,7 @@ def extend_linkage(link_arr: np.ndarray) -> np.ndarray:
 class Cluster:
     __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"

-    def __init__(self, merged: int, samples:
+    def __init__(self, merged: int, samples: NDArray, sample_dist: Union[float, NDArray], is_copy: bool = False):
         self.merged = merged
         self.samples = np.array(samples, dtype=np.int32)
         self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -129,13 +154,13 @@ class Clusterer:
         self._on_init(dataset)

     def _on_init(self, dataset: ArrayLike):
-        self._data:
+        self._data: NDArray = flatten(to_numpy(dataset))
         self._validate_data(self._data)
         self._num_samples = len(self._data)

-        self._darr:
-        self._sqdmat:
-        self._larr:
+        self._darr: NDArray = pdist(self._data, metric="euclidean")
+        self._sqdmat: NDArray = squareform(self._darr)
+        self._larr: NDArray = extend_linkage(linkage(self._darr))
         self._max_clusters: int = np.count_nonzero(self._larr[:, 3] == 2)

         min_num = int(self._num_samples * 0.05)
@@ -145,7 +170,7 @@ class Clusterer:
         self._last_good_merge_levels = None

     @property
-    def data(self) ->
+    def data(self) -> NDArray:
         return self._data

     @data.setter
@@ -165,10 +190,10 @@ class Clusterer:
         return self._last_good_merge_levels

     @classmethod
-    def _validate_data(cls, x:
+    def _validate_data(cls, x: NDArray):
         """Checks that the data has the correct size, shape, and format"""
         if not isinstance(x, np.ndarray):
-            raise TypeError(f"Data should be of type
+            raise TypeError(f"Data should be of type NDArray; got {type(x)}")

         if x.ndim != 2:
             raise ValueError(
@@ -239,7 +264,7 @@ class Clusterer:
             clusters[level_id].setdefault(cid, cluster)
         return clusters

-    def _get_cluster_distances(self) ->
+    def _get_cluster_distances(self) -> NDArray:
         """Calculates the minimum distances between clusters are each level"""
         # Cluster distance matrix
         max_level = self.clusters.max_level
@@ -260,7 +285,7 @@ class Clusterer:

         return cluster_matrix

-    def _calc_merge_indices(self, merge_mean: List[
+    def _calc_merge_indices(self, merge_mean: List[NDArray], intra_max: List[float]) -> NDArray:
         """
         Determine what clusters should be merged and return their indices
         """
@@ -283,7 +308,7 @@ class Clusterer:
         mask2 = mask2_vals < one_std_check
         return np.logical_or(desired_merge, mask2)

-    def _generate_merge_list(self, cluster_matrix:
+    def _generate_merge_list(self, cluster_matrix: NDArray) -> List[ClusterMergeEntry]:
         """
         Runs through the clusters dictionary determining when clusters merge,
         and how close are those clusters when they merge.
@@ -463,35 +488,23 @@ class Clusterer:

         return exact_dupes, near_dupes

-
+    # TODO: Move data input to evaluate from class
+    @set_metadata("dataeval.detectors", ["data"])
+    def evaluate(self) -> ClustererOutput:
         """Finds and flags indices of the data for outliers and duplicates

         Returns
         -------
-
-        outliers
-            List of indices that do not fall within a cluster
-        potential_outliers :
-            List of indices which are near the border between belonging in the cluster and being an outlier
-        duplicates :
-            List of groups of indices that are exact duplicates
-        potential_duplicates :
-            List of groups of indices which are not exact but closely related data points
+        ClustererOutput
+            The outliers and duplicate indices found in the data

         Example
         -------
         >>> cluster.evaluate()
-
+        ClustererOutput(outliers=[18, 21, 34, 35, 45], potential_outliers=[13, 15, 42], duplicates=[[9, 24], [23, 48]], potential_duplicates=[[1, 11]])
         """  # noqa: E501

         outliers, potential_outliers = self.find_outliers(self.last_good_merge_levels)
         duplicates, potential_duplicates = self.find_duplicates(self.last_good_merge_levels)

-
-            "outliers": outliers,
-            "potential_outliers": potential_outliers,
-            "duplicates": duplicates,
-            "potential_duplicates": potential_duplicates,
-        }
-
-        return ret
+        return ClustererOutput(outliers, potential_outliers, duplicates, potential_duplicates)
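`evaluate` now returns the frozen `ClustererOutput` dataclass instead of a dict, so callers switch from key lookups to attribute access. A minimal usage sketch, assuming `Clusterer` is re-exported from `dataeval.detectors` and using made-up 2-D embeddings:

import numpy as np
from dataeval.detectors import Clusterer  # re-export path assumed

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(50, 2))  # hypothetical data; Clusterer expects a 2-D array

results = Clusterer(embeddings).evaluate()

# Attribute access replaces the old "outliers"/"duplicates" dict keys.
print(results.outliers, results.potential_outliers)
print(results.duplicates, results.potential_duplicates)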
dataeval/_internal/detectors/drift/base.py
CHANGED
@@ -7,12 +7,48 @@ Licensed under Apache Software License (Apache 2.0)
 """

 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from functools import wraps
-from typing import Callable,
+from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike, NDArray

-from dataeval._internal.interop import
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class DriftOutput(OutputMetadata):
+    is_drift: bool
+    threshold: float
+
+
+@dataclass(frozen=True)
+class DriftUnivariateOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    feature_drift : NDArray[np.bool_]
+        Feature-level array of images detected to have drifted
+    feature_threshold : float
+        Feature-level threshold to determine drift
+    p_vals : NDArray[np.float32]
+        Feature-level p-values
+    distances : NDArray[np.float32]
+        Feature-level distances
+    """
+
+    # is_drift: bool
+    # threshold: float
+    feature_drift: NDArray[np.bool_]
+    feature_threshold: float
+    p_vals: NDArray[np.float32]
+    distances: NDArray[np.float32]


 def update_x_ref(fn):
@@ -51,7 +87,7 @@ class UpdateStrategy(ABC):
         self.n = n

     @abstractmethod
-    def __call__(self, x_ref:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         """Abstract implementation of update strategy"""


@@ -65,7 +101,7 @@ class LastSeenUpdate(UpdateStrategy):
     Update with last n instances seen by the detector.
     """

-    def __call__(self, x_ref:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         x_updated = np.concatenate([x_ref, x], axis=0)
         return x_updated[-self.n :]

@@ -80,7 +116,7 @@ class ReservoirSamplingUpdate(UpdateStrategy):
     Update with reservoir sampling of size n.
     """

-    def __call__(self, x_ref:
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
         if x.shape[0] + count <= self.n:
             return np.concatenate([x_ref, x], axis=0)

@@ -135,7 +171,7 @@ class BaseDrift:
         self._x_refcount = 0

     @property
-    def x_ref(self) ->
+    def x_ref(self) -> NDArray:
         if not self.x_ref_preprocessed:
             self.x_ref_preprocessed = True
             if self.preprocess_fn is not None:
@@ -151,7 +187,7 @@ class BaseDrift:
         return x


-class BaseUnivariateDrift(BaseDrift):
+class BaseDriftUnivariate(BaseDrift):
     """
     Generic drift detector component which serves as a base class for methods using
     univariate tests. If n_features > 1, a multivariate correction is applied such
@@ -197,13 +233,13 @@ class BaseUnivariateDrift(BaseDrift):

     @preprocess_x
     @abstractmethod
-    def score(self, x: ArrayLike) -> Tuple[np.
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """Abstract method to calculate feature score after preprocessing"""

-    def _apply_correction(self, p_vals:
+    def _apply_correction(self, p_vals: NDArray) -> Tuple[bool, float]:
         if self.correction == "bonferroni":
             threshold = self.p_val / self.n_features
-            drift_pred =
+            drift_pred = bool((p_vals < threshold).any())
             return drift_pred, threshold
         elif self.correction == "fdr":
             n = p_vals.shape[0]
@@ -214,18 +250,18 @@ class BaseUnivariateDrift(BaseDrift):
             try:
                 idx_threshold = int(np.where(below_threshold)[0].max())
             except ValueError:  # sorted p-values not below thresholds
-                return
-            return
+                return bool(below_threshold.any()), q_threshold.min()
+            return bool(below_threshold.any()), q_threshold[idx_threshold]
         else:
             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")

+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
     def predict(
         self,
         x: ArrayLike,
-
-    ) -> Dict[str, Union[int, float, np.ndarray]]:
+    ) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data and update
         reference data using specified update strategy.
@@ -234,10 +270,6 @@ class BaseUnivariateDrift(BaseDrift):
         ----------
         x : ArrayLike
             Batch of instances.
-        drift_type : Literal["batch", "feature"], default "batch"
-            Predict drift at the 'feature' or 'batch' level. For 'batch', the test
-            statistics for each feature are aggregated using the Bonferroni or False
-            Discovery Rate correction (if n_features>1).

         Returns
         -------
@@ -248,20 +280,6 @@ class BaseUnivariateDrift(BaseDrift):
         # compute drift scores
         p_vals, dist = self.score(x)

-
-
-        if drift_type == "feature":
-            drift_pred = (p_vals < self.p_val).astype(int)
-            threshold = self.p_val
-        elif drift_type == "batch":
-            drift_pred, threshold = self._apply_correction(p_vals)
-        else:
-            raise ValueError("`drift_type` needs to be either `feature` or `batch`.")
-
-        # populate drift dict
-        return {
-            "is_drift": drift_pred,
-            "p_val": p_vals,
-            "threshold": threshold,
-            "distance": dist,
-        }
+        feature_drift = (p_vals < self.p_val).astype(np.bool_)
+        drift_pred, threshold = self._apply_correction(p_vals)
+        return DriftUnivariateOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
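`predict` no longer takes a `drift_type` argument: it always computes the per-feature decisions (`feature_drift`) and the Bonferroni- or FDR-corrected batch decision in one pass. A standalone sketch of the two corrections, assuming the FDR branch follows the usual Benjamini-Hochberg form that the surviving `try`/`except` lines above imply:

from typing import Tuple

import numpy as np

def apply_correction(p_vals: np.ndarray, p_val: float, correction: str) -> Tuple[bool, float]:
    # Mirrors BaseDriftUnivariate._apply_correction from the diff above.
    if correction == "bonferroni":
        threshold = p_val / len(p_vals)  # split alpha evenly across features
        return bool((p_vals < threshold).any()), threshold
    elif correction == "fdr":
        n = len(p_vals)
        q_threshold = p_val * np.arange(1, n + 1) / n  # BH line: alpha * i / n
        below_threshold = np.sort(p_vals) < q_threshold
        try:
            idx_threshold = int(np.where(below_threshold)[0].max())
        except ValueError:  # sorted p-values not below thresholds
            return bool(below_threshold.any()), q_threshold.min()
        return bool(below_threshold.any()), q_threshold[idx_threshold]
    raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")

print(apply_correction(np.array([0.001, 0.2, 0.9]), 0.05, "bonferroni"))
# (True, 0.0166...): one feature beats 0.05 / 3, so the batch is flagged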
dataeval/_internal/detectors/drift/cvm.py
CHANGED
@@ -9,14 +9,15 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import cramervonmises_2samp

-from dataeval._internal.interop import
+from dataeval._internal.interop import to_numpy

-from .base import
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x


-class DriftCVM(BaseUnivariateDrift):
+class DriftCVM(BaseDriftUnivariate):
     """
     Cramér-von Mises (CVM) data drift detector, which tests for any change in the
     distribution of continuous univariate data. For multivariate data, a separate
@@ -75,7 +76,7 @@ class DriftCVM(BaseUnivariateDrift):
         )

     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Performs the two-sample Cramér-von Mises test(s), computing the p-value and
         test statistic per feature.
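`score` now advertises typed per-feature arrays. A hedged sketch with synthetic data; the `dataeval.detectors` import path is assumed:

import numpy as np
from dataeval.detectors import DriftCVM  # re-export path assumed

rng = np.random.default_rng(1)
x_ref = rng.normal(size=(200, 3)).astype(np.float32)             # reference set
x_test = rng.normal(loc=0.5, size=(100, 3)).astype(np.float32)   # shifted batch

p_vals, distances = DriftCVM(x_ref).score(x_test)
print(p_vals.shape, distances.shape)  # (3,) (3,): one entry per feature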
dataeval/_internal/detectors/drift/ks.py
CHANGED
@@ -9,21 +9,22 @@ Licensed under Apache Software License (Apache 2.0)
 from typing import Callable, Literal, Optional, Tuple

 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.stats import ks_2samp

-from dataeval._internal.interop import
+from dataeval._internal.interop import to_numpy

-from .base import
+from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x


-class DriftKS(BaseUnivariateDrift):
+class DriftKS(BaseDriftUnivariate):
     """
     Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
     Rate (FDR) correction for multivariate data.

     Parameters
     ----------
-    x_ref :
+    x_ref : NDArray
         Data used as reference distribution.
     p_val : float, default 0.05
         p-value used for significance of the statistical test for each feature.
@@ -40,7 +41,7 @@ class DriftKS(BaseUnivariateDrift):
         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
         or via reservoir sampling with
         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
-    preprocess_fn : Optional[Callable[[
+    preprocess_fn : Optional[Callable[[NDArray], NDArray]], default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a dimensionality reduction technique.
     correction : Literal["bonferroni", "fdr"], default "bonferroni"
@@ -80,7 +81,7 @@ class DriftKS(BaseUnivariateDrift):
         self.alternative = alternative

     @preprocess_x
-    def score(self, x: ArrayLike) -> Tuple[np.
+    def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
         """
         Compute K-S scores and statistics per feature.

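Since `DriftKS.predict` now returns `DriftUnivariateOutput`, the old `"is_drift"`/`"p_val"` dict keys become attributes. A usage sketch with synthetic data; constructor keywords are taken from the parameter list above, the import path is assumed:

import numpy as np
from dataeval.detectors import DriftKS  # re-export path assumed

rng = np.random.default_rng(2)
x_ref = rng.normal(size=(500, 4)).astype(np.float32)
batch = rng.normal(loc=1.0, size=(100, 4)).astype(np.float32)

out = DriftKS(x_ref, p_val=0.05, correction="bonferroni").predict(batch)

print(out.is_drift)       # batch-level decision after correction
print(out.threshold)      # 0.05 / 4 under Bonferroni
print(out.feature_drift)  # per-feature booleans at the uncorrected p_val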
dataeval/_internal/detectors/drift/mmd.py
CHANGED
@@ -6,16 +6,43 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
 Licensed under Apache Software License (Apache 2.0)
 """

-from
+from dataclasses import dataclass
+from typing import Callable, Optional, Tuple

 import torch
+from numpy.typing import ArrayLike

-from dataeval._internal.interop import
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import set_metadata

-from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+from .base import BaseDrift, DriftOutput, UpdateStrategy, preprocess_x, update_x_ref
 from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix


+@dataclass(frozen=True)
+class DriftMMDOutput(DriftOutput):
+    """
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        P-value used for significance of the permutation test
+    p_val : float
+        P-value obtained from the permutation test
+    distance : float
+        MMD^2 between the reference and test set
+    distance_threshold : float
+        MMD^2 threshold above which drift is flagged
+    """
+
+    # is_drift: bool
+    # threshold: float
+    p_val: float
+    distance: float
+    distance_threshold: float
+
+
 class DriftMMD(BaseDrift):
     """
     Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
@@ -74,7 +101,7 @@ class DriftMMD(BaseDrift):
         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)

         self.infer_sigma = configure_kernel_from_x_ref
-        if configure_kernel_from_x_ref and
+        if configure_kernel_from_x_ref and sigma is not None:
             self.infer_sigma = False

         self.n_permutations = n_permutations  # nb of iterations through permutation test
@@ -83,7 +110,7 @@ class DriftMMD(BaseDrift):
         self.device = get_device(device)

         # initialize kernel
-        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if
+        sigma_tensor = torch.from_numpy(to_numpy(sigma)).to(self.device) if sigma is not None else None
         self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel

         # compute kernel matrix for the reference data
@@ -128,19 +155,17 @@ class DriftMMD(BaseDrift):
         mmd2_permuted = torch.Tensor(
             [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
         )
-        mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
+        mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
         p_val = (mmd2 <= mmd2_permuted).float().mean()
         # compute distance threshold
         idx_threshold = int(self.p_val * len(mmd2_permuted))
         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()

+    @set_metadata("dataeval.detectors")
     @preprocess_x
     @update_x_ref
-    def predict(
-        self,
-        x: ArrayLike,
-    ) -> Dict[str, Union[int, float]]:
+    def predict(self, x: ArrayLike) -> DriftMMDOutput:
         """
         Predict whether a batch of data has drifted from the reference data and then
         updates reference data using specified strategy.
@@ -152,17 +177,12 @@ class DriftMMD(BaseDrift):

         Returns
         -------
-
+        DriftMMDOutput
+            Output class containing the drift prediction, p-value, threshold and MMD metric.
         """
         # compute drift scores
         p_val, dist, distance_threshold = self.score(x)
-        drift_pred =
+        drift_pred = bool(p_val < self.p_val)

         # populate drift dict
-        return
-            "is_drift": drift_pred,
-            "p_val": p_val,
-            "threshold": self.p_val,
-            "distance": dist,
-            "distance_threshold": distance_threshold,
-        }
+        return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
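`DriftMMD.predict` likewise swaps its dict for `DriftMMDOutput`, where `threshold` is the configured p-value and `distance_threshold` is the permutation quantile of MMD^2. A hedged sketch with synthetic data; the import path is assumed:

import numpy as np
from dataeval.detectors import DriftMMD  # re-export path assumed

rng = np.random.default_rng(3)
x_ref = rng.normal(size=(200, 8)).astype(np.float32)
batch = rng.normal(loc=0.3, size=(100, 8)).astype(np.float32)

out = DriftMMD(x_ref, n_permutations=100).predict(batch)
print(out.is_drift, out.p_val, out.distance, out.distance_threshold)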
dataeval/_internal/detectors/drift/torch.py
CHANGED
@@ -12,6 +12,7 @@ from typing import Callable, Optional, Type, Union
 import numpy as np
 import torch
 import torch.nn as nn
+from numpy.typing import NDArray


 def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
@@ -74,13 +75,13 @@ def mmd2_from_kernel_matrix(


 def predict_batch(
-    x: Union[
+    x: Union[NDArray, torch.Tensor],
     model: Union[Callable, nn.Module, nn.Sequential],
     device: Optional[torch.device] = None,
     batch_size: int = int(1e10),
     preprocess_fn: Optional[Callable] = None,
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Make batch predictions on a model.

@@ -138,7 +139,7 @@ def predict_batch(
     else:
         raise TypeError(
             f"Model output type {type(preds_tmp)} not supported. The model \
-                output type needs to be one of list, tuple,
+                output type needs to be one of list, tuple, NDArray or \
                 torch.Tensor."
         )
     concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
@@ -149,13 +150,13 @@ def predict_batch(


 def preprocess_drift(
-    x:
+    x: NDArray,
     model: nn.Module,
     device: Optional[torch.device] = None,
     preprocess_batch_fn: Optional[Callable] = None,
     batch_size: int = int(1e10),
     dtype: Union[Type[np.generic], torch.dtype] = np.float32,
-) -> Union[
+) -> Union[NDArray, torch.Tensor, tuple]:
     """
     Prediction function used for preprocessing step of drift detector.

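`preprocess_drift` is typically bound with `functools.partial` and handed to a detector as its `preprocess_fn`, mirroring the alibi-detect code this module is adapted from. A sketch with a hypothetical encoder; the layer sizes and internal import path are assumptions:

from functools import partial

import numpy as np
import torch.nn as nn
from dataeval._internal.detectors.drift.torch import preprocess_drift  # internal path, assumed stable

encoder = nn.Sequential(nn.Flatten(), nn.Linear(32, 4))  # hypothetical dimensionality reducer

preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
x = np.random.default_rng(4).normal(size=(16, 32)).astype(np.float32)
print(preprocess_fn(x).shape)  # (16, 4) as an NDArray, since dtype defaults to np.float32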
dataeval/_internal/detectors/drift/uncertainty.py
CHANGED
@@ -7,24 +7,23 @@ Licensed under Apache Software License (Apache 2.0)
 """

 from functools import partial
-from typing import Callable,
+from typing import Callable, Literal, Optional

 import numpy as np
+from numpy.typing import ArrayLike, NDArray
 from scipy.special import softmax
 from scipy.stats import entropy

-from
-
-from .base import UpdateStrategy
+from .base import DriftUnivariateOutput, UpdateStrategy
 from .ks import DriftKS
 from .torch import get_device, preprocess_drift


 def classifier_uncertainty(
-    x:
+    x: NDArray,
     model_fn: Callable,
     preds_type: Literal["probs", "logits"] = "probs",
-) ->
+) -> NDArray:
     """
     Evaluate model_fn on x and transform predictions to prediction uncertainties.

@@ -112,7 +111,7 @@ class DriftUncertainty:
         preprocess_batch_fn: Optional[Callable] = None,
         device: Optional[str] = None,
     ) -> None:
-        def model_fn(x:
+        def model_fn(x: NDArray) -> NDArray:
             return preprocess_drift(
                 x,
                 model,  # type: ignore
@@ -135,7 +134,7 @@ class DriftUncertainty:
             preprocess_fn=preprocess_fn,  # type: ignore
         )

-    def predict(self, x: ArrayLike) ->
+    def predict(self, x: ArrayLike) -> DriftUnivariateOutput:
         """
         Predict whether a batch of data has drifted from the reference data.

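`DriftUncertainty.predict` now returns the same `DriftUnivariateOutput` as the wrapped `DriftKS` detector. A hedged sketch with a hypothetical classifier; the import path and the `preds_type` constructor keyword are assumptions based on the signatures above (raw linear outputs, hence "logits"):

import numpy as np
import torch.nn as nn
from dataeval.detectors import DriftUncertainty  # re-export path assumed

model = nn.Sequential(nn.Flatten(), nn.Linear(16, 3))  # hypothetical 3-class classifier

rng = np.random.default_rng(5)
x_ref = rng.normal(size=(200, 16)).astype(np.float32)
batch = rng.normal(loc=0.8, size=(50, 16)).astype(np.float32)

out = DriftUncertainty(x_ref, model=model, preds_type="logits").predict(batch)
print(out.is_drift, out.p_vals)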