dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -7,12 +7,12 @@ shifts that impact performance of deployed models.
|
|
7
7
|
|
8
8
|
from __future__ import annotations
|
9
9
|
|
10
|
-
__all__ = ["detectors", "log", "metrics", "utils", "workflows"]
|
11
|
-
__version__ = "0.76.1"
|
10
|
+
__all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
|
11
|
+
__version__ = "0.82.0"
|
12
12
|
|
13
13
|
import logging
|
14
14
|
|
15
|
-
from dataeval import detectors, metrics, utils, workflows
|
15
|
+
from dataeval import config, detectors, metrics, typing, utils, workflows
|
16
16
|
|
17
17
|
logging.getLogger(__name__).addHandler(logging.NullHandler())
|
18
18
|
|
dataeval/config.py
ADDED
@@ -0,0 +1,77 @@
|
|
1
|
+
"""
|
2
|
+
Global configuration settings for DataEval.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
|
8
|
+
|
9
|
+
import torch
|
10
|
+
from torch import device
|
11
|
+
|
12
|
+
_device: device | None = None
|
13
|
+
_processes: int | None = None
|
14
|
+
|
15
|
+
|
16
|
+
def set_device(device: str | device | int) -> None:
|
17
|
+
"""
|
18
|
+
Sets the default device to use when executing against a PyTorch backend.
|
19
|
+
|
20
|
+
Parameters
|
21
|
+
----------
|
22
|
+
device : str or int or `torch.device`
|
23
|
+
The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
|
24
|
+
documentation for more information.
|
25
|
+
"""
|
26
|
+
global _device
|
27
|
+
_device = torch.device(device)
|
28
|
+
|
29
|
+
|
30
|
+
def get_device(override: str | device | int | None = None) -> torch.device:
|
31
|
+
"""
|
32
|
+
Returns the PyTorch device to use.
|
33
|
+
|
34
|
+
Parameters
|
35
|
+
----------
|
36
|
+
override : str or int or `torch.device` or None, default None
|
37
|
+
The user specified override if provided, otherwise returns the default device.
|
38
|
+
|
39
|
+
Returns
|
40
|
+
-------
|
41
|
+
`torch.device`
|
42
|
+
"""
|
43
|
+
if override is None:
|
44
|
+
global _device
|
45
|
+
return torch.get_default_device() if _device is None else _device
|
46
|
+
else:
|
47
|
+
return torch.device(override)
|
48
|
+
|
49
|
+
|
50
|
+
def set_max_processes(processes: int | None) -> None:
|
51
|
+
"""
|
52
|
+
Sets the maximum number of worker processes to use when running tasks that support parallel processing.
|
53
|
+
|
54
|
+
Parameters
|
55
|
+
----------
|
56
|
+
processes : int or None
|
57
|
+
The maximum number of worker processes to use, or None to use
|
58
|
+
`os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
|
59
|
+
to determine the number of worker processes.
|
60
|
+
"""
|
61
|
+
global _processes
|
62
|
+
_processes = processes
|
63
|
+
|
64
|
+
|
65
|
+
def get_max_processes() -> int | None:
|
66
|
+
"""
|
67
|
+
Returns the maximum number of worker processes to use when running tasks that support parallel processing.
|
68
|
+
|
69
|
+
Returns
|
70
|
+
-------
|
71
|
+
int or None
|
72
|
+
The maximum number of worker processes to use, or None to use
|
73
|
+
`os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
|
74
|
+
to determine the number of worker processes.
|
75
|
+
"""
|
76
|
+
global _processes
|
77
|
+
return _processes
|
dataeval/detectors/__init__.py
CHANGED
@@ -14,9 +14,9 @@ __all__ = [
|
|
14
14
|
]
|
15
15
|
|
16
16
|
from dataeval.detectors.drift import updates
|
17
|
-
from dataeval.detectors.drift.
|
18
|
-
from dataeval.detectors.drift.
|
19
|
-
from dataeval.detectors.drift.
|
20
|
-
from dataeval.detectors.drift.
|
21
|
-
from dataeval.detectors.drift.
|
22
|
-
from dataeval.
|
17
|
+
from dataeval.detectors.drift._cvm import DriftCVM
|
18
|
+
from dataeval.detectors.drift._ks import DriftKS
|
19
|
+
from dataeval.detectors.drift._mmd import DriftMMD
|
20
|
+
from dataeval.detectors.drift._torch import preprocess_drift
|
21
|
+
from dataeval.detectors.drift._uncertainty import DriftUncertainty
|
22
|
+
from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
|
@@ -10,86 +10,29 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
|
-
|
14
|
-
from
|
13
|
+
import math
|
14
|
+
from abc import abstractmethod
|
15
15
|
from functools import wraps
|
16
|
-
from typing import Any, Callable, Literal, TypeVar
|
16
|
+
from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
|
17
17
|
|
18
18
|
import numpy as np
|
19
|
-
from numpy.typing import
|
19
|
+
from numpy.typing import NDArray
|
20
20
|
|
21
|
-
from dataeval.
|
22
|
-
from dataeval.
|
21
|
+
from dataeval.outputs import DriftOutput
|
22
|
+
from dataeval.outputs._base import set_metadata
|
23
|
+
from dataeval.typing import Array, ArrayLike
|
24
|
+
from dataeval.utils._array import as_numpy, to_numpy
|
23
25
|
|
24
26
|
R = TypeVar("R")
|
25
27
|
|
26
28
|
|
27
|
-
|
29
|
+
@runtime_checkable
|
30
|
+
class UpdateStrategy(Protocol):
|
28
31
|
"""
|
29
|
-
|
30
|
-
|
31
|
-
Parameters
|
32
|
-
----------
|
33
|
-
n : int
|
34
|
-
Update with last n instances seen by the detector.
|
35
|
-
"""
|
36
|
-
|
37
|
-
def __init__(self, n: int) -> None:
|
38
|
-
self.n = n
|
39
|
-
|
40
|
-
@abstractmethod
|
41
|
-
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
42
|
-
"""Abstract implementation of update strategy"""
|
43
|
-
|
44
|
-
|
45
|
-
@dataclass(frozen=True)
|
46
|
-
class DriftBaseOutput(Output):
|
47
|
-
"""
|
48
|
-
Base output class for Drift Detector classes
|
49
|
-
|
50
|
-
Attributes
|
51
|
-
----------
|
52
|
-
is_drift : bool
|
53
|
-
Drift prediction for the images
|
54
|
-
threshold : float
|
55
|
-
Threshold after multivariate correction if needed
|
32
|
+
Protocol for reference dataset update strategy for drift detectors
|
56
33
|
"""
|
57
34
|
|
58
|
-
|
59
|
-
threshold: float
|
60
|
-
p_val: float
|
61
|
-
distance: float
|
62
|
-
|
63
|
-
|
64
|
-
@dataclass(frozen=True)
|
65
|
-
class DriftOutput(DriftBaseOutput):
|
66
|
-
"""
|
67
|
-
Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
|
68
|
-
|
69
|
-
Attributes
|
70
|
-
----------
|
71
|
-
is_drift : bool
|
72
|
-
:term:`Drift` prediction for the images
|
73
|
-
threshold : float
|
74
|
-
Threshold after multivariate correction if needed
|
75
|
-
feature_drift : NDArray
|
76
|
-
Feature-level array of images detected to have drifted
|
77
|
-
feature_threshold : float
|
78
|
-
Feature-level threshold to determine drift
|
79
|
-
p_vals : NDArray
|
80
|
-
Feature-level p-values
|
81
|
-
distances : NDArray
|
82
|
-
Feature-level distances
|
83
|
-
"""
|
84
|
-
|
85
|
-
# is_drift: bool
|
86
|
-
# threshold: float
|
87
|
-
# p_val: float
|
88
|
-
# distance: float
|
89
|
-
feature_drift: NDArray[np.bool_]
|
90
|
-
feature_threshold: float
|
91
|
-
p_vals: NDArray[np.float32]
|
92
|
-
distances: NDArray[np.float32]
|
35
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...
|
93
36
|
|
94
37
|
|
95
38
|
def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
|
@@ -196,7 +139,7 @@ class BaseDrift:
|
|
196
139
|
if correction not in ["bonferroni", "fdr"]:
|
197
140
|
raise ValueError("`correction` must be `bonferroni` or `fdr`.")
|
198
141
|
|
199
|
-
self._x_ref =
|
142
|
+
self._x_ref = x_ref
|
200
143
|
self.x_ref_preprocessed: bool = x_ref_preprocessed
|
201
144
|
|
202
145
|
# Other attributes
|
@@ -204,25 +147,25 @@ class BaseDrift:
|
|
204
147
|
self.update_x_ref = update_x_ref
|
205
148
|
self.preprocess_fn = preprocess_fn
|
206
149
|
self.correction = correction
|
207
|
-
self.n: int = len(
|
150
|
+
self.n: int = len(x_ref)
|
208
151
|
|
209
152
|
# Ref counter for preprocessed x
|
210
153
|
self._x_refcount = 0
|
211
154
|
|
212
155
|
@property
|
213
|
-
def x_ref(self) ->
|
156
|
+
def x_ref(self) -> ArrayLike:
|
214
157
|
"""
|
215
158
|
Retrieve the reference data, applying preprocessing if not already done.
|
216
159
|
|
217
160
|
Returns
|
218
161
|
-------
|
219
|
-
|
162
|
+
ArrayLike
|
220
163
|
The reference dataset (`x_ref`), preprocessed if needed.
|
221
164
|
"""
|
222
165
|
if not self.x_ref_preprocessed:
|
223
166
|
self.x_ref_preprocessed = True
|
224
167
|
if self.preprocess_fn is not None:
|
225
|
-
self._x_ref =
|
168
|
+
self._x_ref = self.preprocess_fn(self._x_ref)
|
226
169
|
|
227
170
|
return self._x_ref
|
228
171
|
|
@@ -323,32 +266,44 @@ class BaseDriftUnivariate(BaseDrift):
|
|
323
266
|
# lazy process n_features as needed
|
324
267
|
if not isinstance(self._n_features, int):
|
325
268
|
# compute number of features for the univariate tests
|
326
|
-
|
327
|
-
|
328
|
-
self.
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
269
|
+
x_ref = (
|
270
|
+
self.x_ref
|
271
|
+
if self.preprocess_fn is None or self.x_ref_preprocessed
|
272
|
+
else self.preprocess_fn(self._x_ref[0:1])
|
273
|
+
)
|
274
|
+
# infer features from preprocessed reference data
|
275
|
+
shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
|
276
|
+
self._n_features = int(math.prod(shape[1:])) # Multiplies all channel sizes after first
|
333
277
|
|
334
278
|
return self._n_features
|
335
279
|
|
336
280
|
@preprocess_x
|
337
|
-
@abstractmethod
|
338
281
|
def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
339
282
|
"""
|
340
|
-
|
283
|
+
Calculates p-values and test statistics per feature.
|
341
284
|
|
342
285
|
Parameters
|
343
286
|
----------
|
344
287
|
x : ArrayLike
|
345
|
-
|
288
|
+
Batch of instances
|
346
289
|
|
347
290
|
Returns
|
348
291
|
-------
|
349
292
|
tuple[NDArray, NDArray]
|
350
|
-
|
293
|
+
Feature level p-values and test statistics
|
351
294
|
"""
|
295
|
+
x_np = to_numpy(x)
|
296
|
+
x_np = x_np.reshape(x_np.shape[0], -1)
|
297
|
+
x_ref_np = as_numpy(self.x_ref)
|
298
|
+
x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
|
299
|
+
p_val = np.zeros(self.n_features, dtype=np.float32)
|
300
|
+
dist = np.zeros_like(p_val)
|
301
|
+
for f in range(self.n_features):
|
302
|
+
dist[f], p_val[f] = self._score_fn(x_ref_np[:, f], x_np[:, f])
|
303
|
+
return p_val, dist
|
304
|
+
|
305
|
+
@abstractmethod
|
306
|
+
def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
|
352
307
|
|
353
308
|
def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
|
354
309
|
"""
|
@@ -13,11 +13,11 @@ __all__ = []
|
|
13
13
|
from typing import Callable, Literal
|
14
14
|
|
15
15
|
import numpy as np
|
16
|
-
from numpy.typing import
|
16
|
+
from numpy.typing import NDArray
|
17
17
|
from scipy.stats import cramervonmises_2samp
|
18
18
|
|
19
|
-
from dataeval.detectors.drift.
|
20
|
-
from dataeval.
|
19
|
+
from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
|
20
|
+
from dataeval.typing import ArrayLike
|
21
21
|
|
22
22
|
|
23
23
|
class DriftCVM(BaseDriftUnivariate):
|
@@ -55,6 +55,21 @@ class DriftCVM(BaseDriftUnivariate):
|
|
55
55
|
Number of features used in the statistical test. No need to pass it if no
|
56
56
|
preprocessing takes place. In case of a preprocessing step, this can also
|
57
57
|
be inferred automatically but could be more expensive to compute.
|
58
|
+
|
59
|
+
Example
|
60
|
+
-------
|
61
|
+
>>> from functools import partial
|
62
|
+
>>> from dataeval.detectors.drift import preprocess_drift
|
63
|
+
|
64
|
+
Use a preprocess function to encode images before testing for drift
|
65
|
+
|
66
|
+
>>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
|
67
|
+
>>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
|
68
|
+
|
69
|
+
Test incoming images for drift
|
70
|
+
|
71
|
+
>>> drift.predict(test_images).drifted
|
72
|
+
True
|
58
73
|
"""
|
59
74
|
|
60
75
|
def __init__(
|
@@ -77,28 +92,6 @@ class DriftCVM(BaseDriftUnivariate):
|
|
77
92
|
n_features=n_features,
|
78
93
|
)
|
79
94
|
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
Performs the two-sample Cramér-von Mises test(s), computing the :term:`p-value<P-value>` and
|
84
|
-
test statistic per feature.
|
85
|
-
|
86
|
-
Parameters
|
87
|
-
----------
|
88
|
-
x : ArrayLike
|
89
|
-
Batch of instances.
|
90
|
-
|
91
|
-
Returns
|
92
|
-
-------
|
93
|
-
tuple[NDArray, NDArray]
|
94
|
-
Feature level p-values and CVM statistic
|
95
|
-
"""
|
96
|
-
x_np = to_numpy(x)
|
97
|
-
x_np = x_np.reshape(x_np.shape[0], -1)
|
98
|
-
x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
|
99
|
-
p_val = np.zeros(self.n_features, dtype=np.float32)
|
100
|
-
dist = np.zeros_like(p_val)
|
101
|
-
for f in range(self.n_features):
|
102
|
-
result = cramervonmises_2samp(x_ref[:, f], x_np[:, f], method="auto")
|
103
|
-
p_val[f], dist[f] = result.pvalue, result.statistic
|
104
|
-
return p_val, dist
|
95
|
+
def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
|
96
|
+
result = cramervonmises_2samp(x, y, method="auto")
|
97
|
+
return np.float32(result.statistic), np.float32(result.pvalue)
|
@@ -13,11 +13,11 @@ __all__ = []
|
|
13
13
|
from typing import Callable, Literal
|
14
14
|
|
15
15
|
import numpy as np
|
16
|
-
from numpy.typing import
|
16
|
+
from numpy.typing import NDArray
|
17
17
|
from scipy.stats import ks_2samp
|
18
18
|
|
19
|
-
from dataeval.detectors.drift.
|
20
|
-
from dataeval.
|
19
|
+
from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
|
20
|
+
from dataeval.typing import ArrayLike
|
21
21
|
|
22
22
|
|
23
23
|
class DriftKS(BaseDriftUnivariate):
|
@@ -58,6 +58,21 @@ class DriftKS(BaseDriftUnivariate):
|
|
58
58
|
Number of features used in the statistical test. No need to pass it if no
|
59
59
|
preprocessing takes place. In case of a preprocessing step, this can also
|
60
60
|
be inferred automatically but could be more expensive to compute.
|
61
|
+
|
62
|
+
Example
|
63
|
+
-------
|
64
|
+
>>> from functools import partial
|
65
|
+
>>> from dataeval.detectors.drift import preprocess_drift
|
66
|
+
|
67
|
+
Use a preprocess function to encode images before testing for drift
|
68
|
+
|
69
|
+
>>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
|
70
|
+
>>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
|
71
|
+
|
72
|
+
Test incoming images for drift
|
73
|
+
|
74
|
+
>>> drift.predict(test_images).drifted
|
75
|
+
True
|
61
76
|
"""
|
62
77
|
|
63
78
|
def __init__(
|
@@ -84,26 +99,5 @@ class DriftKS(BaseDriftUnivariate):
|
|
84
99
|
# Other attributes
|
85
100
|
self.alternative = alternative
|
86
101
|
|
87
|
-
|
88
|
-
|
89
|
-
"""
|
90
|
-
Compute KS scores and :term:Statistics` per feature.
|
91
|
-
|
92
|
-
Parameters
|
93
|
-
----------
|
94
|
-
x : ArrayLike
|
95
|
-
Batch of instances.
|
96
|
-
|
97
|
-
Returns
|
98
|
-
-------
|
99
|
-
tuple[NDArray, NDArray]
|
100
|
-
Feature level :term:p-values and KS statistic
|
101
|
-
"""
|
102
|
-
x = to_numpy(x)
|
103
|
-
x = x.reshape(x.shape[0], -1)
|
104
|
-
x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
|
105
|
-
p_val = np.zeros(self.n_features, dtype=np.float32)
|
106
|
-
dist = np.zeros_like(p_val)
|
107
|
-
for f in range(self.n_features):
|
108
|
-
dist[f], p_val[f] = ks_2samp(x_ref[:, f], x[:, f], alternative=self.alternative, method="exact")
|
109
|
-
return p_val, dist
|
102
|
+
def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
|
103
|
+
return ks_2samp(x, y, alternative=self.alternative, method="exact")
|
@@ -10,43 +10,16 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
|
-
from dataclasses import dataclass
|
14
13
|
from typing import Callable
|
15
14
|
|
16
15
|
import torch
|
17
|
-
from numpy.typing import ArrayLike
|
18
16
|
|
19
|
-
from dataeval.
|
20
|
-
from dataeval.detectors.drift.
|
21
|
-
from dataeval.
|
22
|
-
from dataeval.
|
23
|
-
from dataeval.
|
24
|
-
|
25
|
-
|
26
|
-
@dataclass(frozen=True)
|
27
|
-
class DriftMMDOutput(DriftBaseOutput):
|
28
|
-
"""
|
29
|
-
Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
|
30
|
-
|
31
|
-
Attributes
|
32
|
-
----------
|
33
|
-
is_drift : bool
|
34
|
-
Drift prediction for the images
|
35
|
-
threshold : float
|
36
|
-
:term:`P-Value` used for significance of the permutation test
|
37
|
-
p_val : float
|
38
|
-
P-value obtained from the permutation test
|
39
|
-
distance : float
|
40
|
-
MMD^2 between the reference and test set
|
41
|
-
distance_threshold : float
|
42
|
-
MMD^2 threshold above which drift is flagged
|
43
|
-
"""
|
44
|
-
|
45
|
-
# is_drift: bool
|
46
|
-
# threshold: float
|
47
|
-
# p_val: float
|
48
|
-
# distance: float
|
49
|
-
distance_threshold: float
|
17
|
+
from dataeval.config import get_device
|
18
|
+
from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
|
19
|
+
from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
|
20
|
+
from dataeval.outputs import DriftMMDOutput
|
21
|
+
from dataeval.outputs._base import set_metadata
|
22
|
+
from dataeval.typing import ArrayLike
|
50
23
|
|
51
24
|
|
52
25
|
class DriftMMD(BaseDrift):
|
@@ -84,6 +57,21 @@ class DriftMMD(BaseDrift):
|
|
84
57
|
device : str | None, default None
|
85
58
|
Device type used. The default None uses the GPU and falls back on CPU.
|
86
59
|
Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
|
60
|
+
|
61
|
+
Example
|
62
|
+
-------
|
63
|
+
>>> from functools import partial
|
64
|
+
>>> from dataeval.detectors.drift import preprocess_drift
|
65
|
+
|
66
|
+
Use a preprocess function to encode images before testing for drift
|
67
|
+
|
68
|
+
>>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
|
69
|
+
>>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
|
70
|
+
|
71
|
+
Test incoming images for drift
|
72
|
+
|
73
|
+
>>> drift.predict(test_images).drifted
|
74
|
+
True
|
87
75
|
"""
|
88
76
|
|
89
77
|
def __init__(
|
@@ -110,12 +98,12 @@ class DriftMMD(BaseDrift):
|
|
110
98
|
self.device: torch.device = get_device(device)
|
111
99
|
|
112
100
|
# initialize kernel
|
113
|
-
sigma_tensor = torch.
|
101
|
+
sigma_tensor = torch.as_tensor(sigma, device=self.device) if sigma is not None else None
|
114
102
|
self._kernel = GaussianRBF(sigma_tensor).to(self.device)
|
115
103
|
|
116
104
|
# compute kernel matrix for the reference data
|
117
105
|
if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
|
118
|
-
x = torch.
|
106
|
+
x = torch.as_tensor(self.x_ref, device=self.device)
|
119
107
|
self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
|
120
108
|
self._infer_sigma = False
|
121
109
|
else:
|
@@ -147,21 +135,21 @@ class DriftMMD(BaseDrift):
|
|
147
135
|
p-value obtained from the permutation test, MMD^2 between the reference and test set,
|
148
136
|
and MMD^2 threshold above which :term:`drift<Drift>` is flagged
|
149
137
|
"""
|
150
|
-
|
151
|
-
|
152
|
-
n =
|
153
|
-
kernel_mat = self._kernel_matrix(x_ref,
|
138
|
+
x_ref = torch.as_tensor(self.x_ref, device=self.device)
|
139
|
+
x_test = torch.as_tensor(x, device=self.device)
|
140
|
+
n = x_test.shape[0]
|
141
|
+
kernel_mat = self._kernel_matrix(x_ref, x_test)
|
154
142
|
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
|
155
143
|
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
|
156
|
-
mmd2_permuted = torch.
|
157
|
-
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)
|
144
|
+
mmd2_permuted = torch.tensor(
|
145
|
+
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)] * self.n_permutations,
|
146
|
+
device=self.device,
|
158
147
|
)
|
159
|
-
mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
|
160
148
|
p_val = (mmd2 <= mmd2_permuted).float().mean()
|
161
149
|
# compute distance threshold
|
162
150
|
idx_threshold = int(self.p_val * len(mmd2_permuted))
|
163
151
|
distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
|
164
|
-
return p_val.
|
152
|
+
return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())
|
165
153
|
|
166
154
|
@set_metadata
|
167
155
|
@preprocess_x
|
@@ -17,7 +17,8 @@ import torch
|
|
17
17
|
import torch.nn as nn
|
18
18
|
from numpy.typing import NDArray
|
19
19
|
|
20
|
-
from dataeval.
|
20
|
+
from dataeval.config import get_device
|
21
|
+
from dataeval.utils.torch._internal import predict_batch
|
21
22
|
|
22
23
|
|
23
24
|
def mmd2_from_kernel_matrix(
|
@@ -14,14 +14,16 @@ from functools import partial
|
|
14
14
|
from typing import Callable, Literal
|
15
15
|
|
16
16
|
import numpy as np
|
17
|
-
from numpy.typing import
|
17
|
+
from numpy.typing import NDArray
|
18
18
|
from scipy.special import softmax
|
19
19
|
from scipy.stats import entropy
|
20
20
|
|
21
|
-
from dataeval.
|
22
|
-
from dataeval.detectors.drift.
|
23
|
-
from dataeval.detectors.drift.
|
24
|
-
from dataeval.
|
21
|
+
from dataeval.config import get_device
|
22
|
+
from dataeval.detectors.drift._base import UpdateStrategy
|
23
|
+
from dataeval.detectors.drift._ks import DriftKS
|
24
|
+
from dataeval.detectors.drift._torch import preprocess_drift
|
25
|
+
from dataeval.outputs import DriftOutput
|
26
|
+
from dataeval.typing import ArrayLike
|
25
27
|
|
26
28
|
|
27
29
|
def classifier_uncertainty(
|
@@ -87,7 +89,7 @@ class DriftUncertainty:
|
|
87
89
|
Reference data can optionally be updated using an UpdateStrategy class. Update
|
88
90
|
using the last n instances seen by the detector with LastSeenUpdateStrategy
|
89
91
|
or via reservoir sampling with ReservoirSamplingUpdateStrategy.
|
90
|
-
preds_type : "probs" | "logits", default "
|
92
|
+
preds_type : "probs" | "logits", default "probs"
|
91
93
|
Type of prediction output by the model. Options are 'probs' (in [0,1]) or
|
92
94
|
'logits' (in [-inf,inf]).
|
93
95
|
batch_size : int, default 32
|
@@ -98,7 +100,22 @@ class DriftUncertainty:
|
|
98
100
|
objects to a batch which can be processed by the model.
|
99
101
|
device : str | None, default None
|
100
102
|
Device type used. The default None tries to use the GPU and falls back on
|
101
|
-
CPU if needed. Can be specified by passing either 'cuda'
|
103
|
+
CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
|
104
|
+
|
105
|
+
Example
|
106
|
+
-------
|
107
|
+
>>> model = ClassificationModel()
|
108
|
+
>>> drift = DriftUncertainty(x_ref, model=model, batch_size=20)
|
109
|
+
|
110
|
+
Verify reference images have not drifted
|
111
|
+
|
112
|
+
>>> drift.predict(x_ref.copy()).drifted
|
113
|
+
False
|
114
|
+
|
115
|
+
Test incoming images for drift
|
116
|
+
|
117
|
+
>>> drift.predict(x_test).drifted
|
118
|
+
True
|
102
119
|
"""
|
103
120
|
|
104
121
|
def __init__(
|
@@ -7,15 +7,32 @@ from __future__ import annotations
|
|
7
7
|
|
8
8
|
__all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
|
9
9
|
|
10
|
+
from abc import ABC, abstractmethod
|
10
11
|
from typing import Any
|
11
12
|
|
12
13
|
import numpy as np
|
13
14
|
from numpy.typing import NDArray
|
14
15
|
|
15
|
-
|
16
|
+
|
17
|
+
class BaseUpdateStrategy(ABC):
|
18
|
+
"""
|
19
|
+
Updates reference dataset for drift detector
|
20
|
+
|
21
|
+
Parameters
|
22
|
+
----------
|
23
|
+
n : int
|
24
|
+
Update with last n instances seen by the detector.
|
25
|
+
"""
|
26
|
+
|
27
|
+
def __init__(self, n: int) -> None:
|
28
|
+
self.n = n
|
29
|
+
|
30
|
+
@abstractmethod
|
31
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
32
|
+
"""Abstract implementation of update strategy"""
|
16
33
|
|
17
34
|
|
18
|
-
class LastSeenUpdate(
|
35
|
+
class LastSeenUpdate(BaseUpdateStrategy):
|
19
36
|
"""
|
20
37
|
Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
|
21
38
|
|
@@ -30,7 +47,7 @@ class LastSeenUpdate(UpdateStrategy):
|
|
30
47
|
return x_updated[-self.n :]
|
31
48
|
|
32
49
|
|
33
|
-
class ReservoirSamplingUpdate(
|
50
|
+
class ReservoirSamplingUpdate(BaseUpdateStrategy):
|
34
51
|
"""
|
35
52
|
Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
|
36
53
|
|