dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
- dataeval/detectors/ood/aegmm.py +66 -0
- dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
- dataeval/detectors/ood/vaegmm.py +75 -0
- dataeval/interop.py +56 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
- dataeval/metrics/bias/metadata.py +358 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/lazy.py +26 -0
- dataeval/utils/metadata.py +258 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
- dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
- dataeval-0.73.0.dist-info/RECORD +73 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/detectors/ood/aegmm.py +0 -78
- dataeval/_internal/detectors/ood/vaegmm.py +0 -89
- dataeval/_internal/interop.py +0 -49
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.73.0"
|
2
2
|
|
3
3
|
from importlib.util import find_spec
|
4
4
|
|
@@ -8,16 +8,16 @@ _IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("te
|
|
8
8
|
|
9
9
|
del find_spec
|
10
10
|
|
11
|
-
from
|
11
|
+
from dataeval import detectors, metrics # noqa: E402
|
12
12
|
|
13
13
|
__all__ = ["detectors", "metrics"]
|
14
14
|
|
15
15
|
if _IS_TORCH_AVAILABLE: # pragma: no cover
|
16
|
-
from
|
16
|
+
from dataeval import workflows
|
17
17
|
|
18
18
|
__all__ += ["workflows"]
|
19
19
|
|
20
20
|
if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE: # pragma: no cover
|
21
|
-
from
|
21
|
+
from dataeval import utils
|
22
22
|
|
23
23
|
__all__ += ["utils"]
|
dataeval/detectors/__init__.py
CHANGED
@@ -3,12 +3,13 @@ Detectors can determine if a dataset or individual images in a dataset are indic
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
from dataeval import _IS_TENSORFLOW_AVAILABLE
|
6
|
-
|
7
|
-
from . import drift, linters
|
6
|
+
from dataeval.detectors import drift, linters
|
8
7
|
|
9
8
|
__all__ = ["drift", "linters"]
|
10
9
|
|
11
10
|
if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
|
12
|
-
from . import ood
|
11
|
+
from dataeval.detectors import ood
|
13
12
|
|
14
13
|
__all__ += ["ood"]
|
14
|
+
|
15
|
+
del _IS_TENSORFLOW_AVAILABLE
|
@@ -3,19 +3,18 @@
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
from dataeval import _IS_TORCH_AVAILABLE
|
6
|
-
from dataeval.
|
7
|
-
from dataeval.
|
8
|
-
from dataeval.
|
9
|
-
|
10
|
-
from . import updates
|
6
|
+
from dataeval.detectors.drift import updates
|
7
|
+
from dataeval.detectors.drift.base import DriftOutput
|
8
|
+
from dataeval.detectors.drift.cvm import DriftCVM
|
9
|
+
from dataeval.detectors.drift.ks import DriftKS
|
11
10
|
|
12
11
|
__all__ = ["DriftCVM", "DriftKS", "DriftOutput", "updates"]
|
13
12
|
|
14
13
|
if _IS_TORCH_AVAILABLE: # pragma: no cover
|
15
|
-
from dataeval.
|
16
|
-
from dataeval.
|
17
|
-
from dataeval.
|
14
|
+
from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
|
15
|
+
from dataeval.detectors.drift.torch import preprocess_drift
|
16
|
+
from dataeval.detectors.drift.uncertainty import DriftUncertainty
|
18
17
|
|
19
|
-
|
18
|
+
__all__ += ["DriftMMD", "DriftMMDOutput", "DriftUncertainty", "preprocess_drift"]
|
20
19
|
|
21
|
-
|
20
|
+
del _IS_TORCH_AVAILABLE
|
@@ -8,16 +8,38 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
__all__ = ["DriftOutput"]
|
12
|
+
|
11
13
|
from abc import ABC, abstractmethod
|
12
14
|
from dataclasses import dataclass
|
13
15
|
from functools import wraps
|
14
|
-
from typing import Callable, Literal
|
16
|
+
from typing import Any, Callable, Literal, TypeVar
|
15
17
|
|
16
18
|
import numpy as np
|
17
19
|
from numpy.typing import ArrayLike, NDArray
|
18
20
|
|
19
|
-
from dataeval.
|
20
|
-
from dataeval.
|
21
|
+
from dataeval.interop import as_numpy, to_numpy
|
22
|
+
from dataeval.output import OutputMetadata, set_metadata
|
23
|
+
|
24
|
+
R = TypeVar("R")
|
25
|
+
|
26
|
+
|
27
|
+
class UpdateStrategy(ABC):
|
28
|
+
"""
|
29
|
+
Updates reference dataset for drift detector
|
30
|
+
|
31
|
+
Parameters
|
32
|
+
----------
|
33
|
+
n : int
|
34
|
+
Update with last n instances seen by the detector.
|
35
|
+
"""
|
36
|
+
|
37
|
+
def __init__(self, n: int) -> None:
|
38
|
+
self.n = n
|
39
|
+
|
40
|
+
@abstractmethod
|
41
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
42
|
+
"""Abstract implementation of update strategy"""
|
21
43
|
|
22
44
|
|
23
45
|
@dataclass(frozen=True)
|
@@ -70,9 +92,11 @@ class DriftOutput(DriftBaseOutput):
|
|
70
92
|
distances: NDArray[np.float32]
|
71
93
|
|
72
94
|
|
73
|
-
def update_x_ref(fn):
|
95
|
+
def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
|
96
|
+
"""Decorator to update x_ref with x using selected update methodology"""
|
97
|
+
|
74
98
|
@wraps(fn)
|
75
|
-
def _(self, x, *args, **kwargs):
|
99
|
+
def _(self, x, *args, **kwargs) -> R:
|
76
100
|
output = fn(self, x, *args, **kwargs)
|
77
101
|
|
78
102
|
# update reference dataset
|
@@ -86,9 +110,11 @@ def update_x_ref(fn):
|
|
86
110
|
return _
|
87
111
|
|
88
112
|
|
89
|
-
def preprocess_x(fn):
|
113
|
+
def preprocess_x(fn: Callable[..., R]) -> Callable[..., R]:
|
114
|
+
"""Decorator to run preprocess_fn on x before calling wrapped function"""
|
115
|
+
|
90
116
|
@wraps(fn)
|
91
|
-
def _(self, x, *args, **kwargs):
|
117
|
+
def _(self, x, *args, **kwargs) -> R:
|
92
118
|
if self._x_refcount == 0:
|
93
119
|
self._x = self._preprocess(x)
|
94
120
|
self._x_refcount += 1
|
@@ -101,70 +127,6 @@ def preprocess_x(fn):
|
|
101
127
|
return _
|
102
128
|
|
103
129
|
|
104
|
-
class UpdateStrategy(ABC):
|
105
|
-
"""
|
106
|
-
Updates reference dataset for :term:`drift<Drift>` detector
|
107
|
-
|
108
|
-
Parameters
|
109
|
-
----------
|
110
|
-
n : int
|
111
|
-
Update with last n instances seen by the detector.
|
112
|
-
"""
|
113
|
-
|
114
|
-
def __init__(self, n: int):
|
115
|
-
self.n = n
|
116
|
-
|
117
|
-
@abstractmethod
|
118
|
-
def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
|
119
|
-
"""Abstract implementation of update strategy"""
|
120
|
-
|
121
|
-
|
122
|
-
class LastSeenUpdate(UpdateStrategy):
|
123
|
-
"""
|
124
|
-
Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
|
125
|
-
|
126
|
-
Parameters
|
127
|
-
----------
|
128
|
-
n : int
|
129
|
-
Update with last n instances seen by the detector.
|
130
|
-
"""
|
131
|
-
|
132
|
-
def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
|
133
|
-
x_updated = np.concatenate([x_ref, x], axis=0)
|
134
|
-
return x_updated[-self.n :]
|
135
|
-
|
136
|
-
|
137
|
-
class ReservoirSamplingUpdate(UpdateStrategy):
|
138
|
-
"""
|
139
|
-
Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
|
140
|
-
|
141
|
-
Parameters
|
142
|
-
----------
|
143
|
-
n : int
|
144
|
-
Update with last n instances seen by the detector.
|
145
|
-
"""
|
146
|
-
|
147
|
-
def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
|
148
|
-
if x.shape[0] + count <= self.n:
|
149
|
-
return np.concatenate([x_ref, x], axis=0)
|
150
|
-
|
151
|
-
n_ref = x_ref.shape[0]
|
152
|
-
output_size = min(self.n, n_ref + x.shape[0])
|
153
|
-
shape = (output_size,) + x.shape[1:]
|
154
|
-
x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
|
155
|
-
x_reservoir[:n_ref] = x_ref
|
156
|
-
for item in x:
|
157
|
-
count += 1
|
158
|
-
if n_ref < self.n:
|
159
|
-
x_reservoir[n_ref, :] = item
|
160
|
-
n_ref += 1
|
161
|
-
else:
|
162
|
-
r = np.random.randint(0, count)
|
163
|
-
if r < self.n:
|
164
|
-
x_reservoir[r, :] = item
|
165
|
-
return x_reservoir
|
166
|
-
|
167
|
-
|
168
130
|
class BaseDrift:
|
169
131
|
"""
|
170
132
|
A generic :term:`drift<Drift>` detection component for preprocessing data and applying statistical correction.
|
@@ -223,7 +185,7 @@ class BaseDrift:
|
|
223
185
|
p_val: float = 0.05,
|
224
186
|
x_ref_preprocessed: bool = False,
|
225
187
|
update_x_ref: UpdateStrategy | None = None,
|
226
|
-
preprocess_fn: Callable[
|
188
|
+
preprocess_fn: Callable[..., ArrayLike] | None = None,
|
227
189
|
correction: Literal["bonferroni", "fdr"] = "bonferroni",
|
228
190
|
) -> None:
|
229
191
|
# Type checking
|
@@ -235,20 +197,20 @@ class BaseDrift:
|
|
235
197
|
raise ValueError("`correction` must be `bonferroni` or `fdr`.")
|
236
198
|
|
237
199
|
self._x_ref = to_numpy(x_ref)
|
238
|
-
self.x_ref_preprocessed = x_ref_preprocessed
|
200
|
+
self.x_ref_preprocessed: bool = x_ref_preprocessed
|
239
201
|
|
240
202
|
# Other attributes
|
241
203
|
self.p_val = p_val
|
242
204
|
self.update_x_ref = update_x_ref
|
243
205
|
self.preprocess_fn = preprocess_fn
|
244
206
|
self.correction = correction
|
245
|
-
self.n = len(self._x_ref)
|
207
|
+
self.n: int = len(self._x_ref)
|
246
208
|
|
247
209
|
# Ref counter for preprocessed x
|
248
210
|
self._x_refcount = 0
|
249
211
|
|
250
212
|
@property
|
251
|
-
def x_ref(self) -> NDArray:
|
213
|
+
def x_ref(self) -> NDArray[Any]:
|
252
214
|
"""
|
253
215
|
Retrieve the reference data, applying preprocessing if not already done.
|
254
216
|
|
@@ -313,9 +275,6 @@ class BaseDriftUnivariate(BaseDrift):
|
|
313
275
|
|
314
276
|
Attributes
|
315
277
|
----------
|
316
|
-
_n_features : int | None
|
317
|
-
Number of features in the data. If not provided, it is lazily inferred from the
|
318
|
-
input data and any preprocessing function.
|
319
278
|
p_val : float
|
320
279
|
The significance level for drift detection.
|
321
280
|
correction : str
|
@@ -324,17 +283,6 @@ class BaseDriftUnivariate(BaseDrift):
|
|
324
283
|
Strategy for updating the reference data if applicable.
|
325
284
|
preprocess_fn : Callable | None
|
326
285
|
Function used for preprocessing input data before drift detection.
|
327
|
-
|
328
|
-
Methods
|
329
|
-
-------
|
330
|
-
n_features:
|
331
|
-
Property that returns the number of features, inferring it if necessary.
|
332
|
-
score(x):
|
333
|
-
Abstract method to compute univariate feature scores after preprocessing.
|
334
|
-
_apply_correction(p_vals):
|
335
|
-
Apply a statistical correction to p-values to account for multiple testing.
|
336
|
-
predict(x):
|
337
|
-
Predict whether drift has occurred on a batch of data, applying multivariate correction if needed.
|
338
286
|
"""
|
339
287
|
|
340
288
|
def __init__(
|
@@ -427,7 +375,7 @@ class BaseDriftUnivariate(BaseDrift):
|
|
427
375
|
return drift_pred, threshold
|
428
376
|
elif self.correction == "fdr":
|
429
377
|
n = p_vals.shape[0]
|
430
|
-
i = np.arange(n) + 1
|
378
|
+
i = np.arange(n) + np.int_(1)
|
431
379
|
p_sorted = np.sort(p_vals)
|
432
380
|
q_threshold = self.p_val * i / n
|
433
381
|
below_threshold = p_sorted < q_threshold
|
@@ -439,7 +387,7 @@ class BaseDriftUnivariate(BaseDrift):
|
|
439
387
|
else:
|
440
388
|
raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
|
441
389
|
|
442
|
-
@set_metadata(
|
390
|
+
@set_metadata()
|
443
391
|
@preprocess_x
|
444
392
|
@update_x_ref
|
445
393
|
def predict(
|
@@ -8,15 +8,16 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
__all__ = ["DriftCVM"]
|
12
|
+
|
11
13
|
from typing import Callable, Literal
|
12
14
|
|
13
15
|
import numpy as np
|
14
16
|
from numpy.typing import ArrayLike, NDArray
|
15
17
|
from scipy.stats import cramervonmises_2samp
|
16
18
|
|
17
|
-
from dataeval.
|
18
|
-
|
19
|
-
from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
|
19
|
+
from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
|
20
|
+
from dataeval.interop import to_numpy
|
20
21
|
|
21
22
|
|
22
23
|
class DriftCVM(BaseDriftUnivariate):
|
@@ -8,15 +8,16 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
__all__ = ["DriftKS"]
|
12
|
+
|
11
13
|
from typing import Callable, Literal
|
12
14
|
|
13
15
|
import numpy as np
|
14
16
|
from numpy.typing import ArrayLike, NDArray
|
15
17
|
from scipy.stats import ks_2samp
|
16
18
|
|
17
|
-
from dataeval.
|
18
|
-
|
19
|
-
from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
|
19
|
+
from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
|
20
|
+
from dataeval.interop import to_numpy
|
20
21
|
|
21
22
|
|
22
23
|
class DriftKS(BaseDriftUnivariate):
|
@@ -8,17 +8,18 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
__all__ = ["DriftMMD", "DriftMMDOutput"]
|
12
|
+
|
11
13
|
from dataclasses import dataclass
|
12
14
|
from typing import Callable
|
13
15
|
|
14
16
|
import torch
|
15
17
|
from numpy.typing import ArrayLike
|
16
18
|
|
17
|
-
from dataeval.
|
18
|
-
from dataeval.
|
19
|
-
|
20
|
-
from .
|
21
|
-
from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
|
19
|
+
from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
|
20
|
+
from dataeval.detectors.drift.torch import _GaussianRBF, _mmd2_from_kernel_matrix, get_device
|
21
|
+
from dataeval.interop import as_numpy
|
22
|
+
from dataeval.output import set_metadata
|
22
23
|
|
23
24
|
|
24
25
|
@dataclass(frozen=True)
|
@@ -70,10 +71,8 @@ class DriftMMD(BaseDrift):
|
|
70
71
|
preprocess_fn : Callable | None, default None
|
71
72
|
Function to preprocess the data before computing the data drift metrics.
|
72
73
|
Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
|
73
|
-
kernel : Callable, default GaussianRBF
|
74
|
-
Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
|
75
74
|
sigma : ArrayLike | None, default None
|
76
|
-
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple
|
75
|
+
Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
|
77
76
|
bandwidth values as an array. The kernel evaluation is then averaged over
|
78
77
|
those bandwidths.
|
79
78
|
configure_kernel_from_x_ref : bool, default True
|
@@ -91,41 +90,40 @@ class DriftMMD(BaseDrift):
|
|
91
90
|
p_val: float = 0.05,
|
92
91
|
x_ref_preprocessed: bool = False,
|
93
92
|
update_x_ref: UpdateStrategy | None = None,
|
94
|
-
preprocess_fn: Callable[
|
95
|
-
kernel: Callable = GaussianRBF,
|
93
|
+
preprocess_fn: Callable[..., ArrayLike] | None = None,
|
96
94
|
sigma: ArrayLike | None = None,
|
97
95
|
configure_kernel_from_x_ref: bool = True,
|
98
96
|
n_permutations: int = 100,
|
99
|
-
device: str | None = None,
|
97
|
+
device: str | torch.device | None = None,
|
100
98
|
) -> None:
|
101
99
|
super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
|
102
100
|
|
103
|
-
self.
|
101
|
+
self._infer_sigma = configure_kernel_from_x_ref
|
104
102
|
if configure_kernel_from_x_ref and sigma is not None:
|
105
|
-
self.
|
103
|
+
self._infer_sigma = False
|
106
104
|
|
107
105
|
self.n_permutations = n_permutations # nb of iterations through permutation test
|
108
106
|
|
109
107
|
# set device
|
110
|
-
self.device = get_device(device)
|
108
|
+
self.device: torch.device = get_device(device)
|
111
109
|
|
112
110
|
# initialize kernel
|
113
111
|
sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
|
114
|
-
self.
|
112
|
+
self._kernel = _GaussianRBF(sigma_tensor).to(self.device)
|
115
113
|
|
116
114
|
# compute kernel matrix for the reference data
|
117
|
-
if self.
|
115
|
+
if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
|
118
116
|
x = torch.from_numpy(self.x_ref).to(self.device)
|
119
|
-
self.
|
120
|
-
self.
|
117
|
+
self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
|
118
|
+
self._infer_sigma = False
|
121
119
|
else:
|
122
|
-
self.
|
120
|
+
self._k_xx, self._infer_sigma = None, True
|
123
121
|
|
124
122
|
def _kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
125
123
|
"""Compute and return full kernel matrix between arrays x and y."""
|
126
|
-
k_xy = self.
|
127
|
-
k_xx = self.
|
128
|
-
k_yy = self.
|
124
|
+
k_xy = self._kernel(x, y, self._infer_sigma)
|
125
|
+
k_xx = self._k_xx if self._k_xx is not None and self.update_x_ref is None else self._kernel(x, x)
|
126
|
+
k_yy = self._kernel(y, y)
|
129
127
|
kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
|
130
128
|
return kernel_mat
|
131
129
|
|
@@ -152,9 +150,9 @@ class DriftMMD(BaseDrift):
|
|
152
150
|
n = x.shape[0]
|
153
151
|
kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
|
154
152
|
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
|
155
|
-
mmd2 =
|
153
|
+
mmd2 = _mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
|
156
154
|
mmd2_permuted = torch.Tensor(
|
157
|
-
[
|
155
|
+
[_mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
|
158
156
|
)
|
159
157
|
mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
|
160
158
|
p_val = (mmd2 <= mmd2_permuted).float().mean()
|
@@ -163,7 +161,7 @@ class DriftMMD(BaseDrift):
|
|
163
161
|
distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
|
164
162
|
return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
|
165
163
|
|
166
|
-
@set_metadata(
|
164
|
+
@set_metadata()
|
167
165
|
@preprocess_x
|
168
166
|
@update_x_ref
|
169
167
|
def predict(self, x: ArrayLike) -> DriftMMDOutput:
|
@@ -8,8 +8,10 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
__all__ = []
|
12
|
+
|
11
13
|
from functools import partial
|
12
|
-
from typing import Callable
|
14
|
+
from typing import Any, Callable
|
13
15
|
|
14
16
|
import numpy as np
|
15
17
|
import torch
|
@@ -42,7 +44,7 @@ def get_device(device: str | torch.device | None = None) -> torch.device:
|
|
42
44
|
return torch_device
|
43
45
|
|
44
46
|
|
45
|
-
def
|
47
|
+
def _mmd2_from_kernel_matrix(
|
46
48
|
kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
|
47
49
|
) -> torch.Tensor:
|
48
50
|
"""
|
@@ -78,13 +80,13 @@ def mmd2_from_kernel_matrix(
|
|
78
80
|
|
79
81
|
|
80
82
|
def predict_batch(
|
81
|
-
x: NDArray | torch.Tensor,
|
83
|
+
x: NDArray[Any] | torch.Tensor,
|
82
84
|
model: Callable | nn.Module | nn.Sequential,
|
83
85
|
device: torch.device | None = None,
|
84
86
|
batch_size: int = int(1e10),
|
85
87
|
preprocess_fn: Callable | None = None,
|
86
88
|
dtype: type[np.generic] | torch.dtype = np.float32,
|
87
|
-
) -> NDArray | torch.Tensor | tuple:
|
89
|
+
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
|
88
90
|
"""
|
89
91
|
Make batch predictions on a model.
|
90
92
|
|
@@ -154,13 +156,13 @@ def predict_batch(
|
|
154
156
|
|
155
157
|
|
156
158
|
def preprocess_drift(
|
157
|
-
x: NDArray,
|
159
|
+
x: NDArray[Any],
|
158
160
|
model: nn.Module,
|
159
|
-
device: torch.device | None = None,
|
161
|
+
device: str | torch.device | None = None,
|
160
162
|
preprocess_batch_fn: Callable | None = None,
|
161
163
|
batch_size: int = int(1e10),
|
162
164
|
dtype: type[np.generic] | torch.dtype = np.float32,
|
163
|
-
) -> NDArray | torch.Tensor | tuple:
|
165
|
+
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
|
164
166
|
"""
|
165
167
|
Prediction function used for preprocessing step of drift detector.
|
166
168
|
|
@@ -189,7 +191,7 @@ def preprocess_drift(
|
|
189
191
|
return predict_batch(
|
190
192
|
x,
|
191
193
|
model,
|
192
|
-
device=device,
|
194
|
+
device=get_device(device),
|
193
195
|
batch_size=batch_size,
|
194
196
|
preprocess_fn=preprocess_batch_fn,
|
195
197
|
dtype=dtype,
|
@@ -197,7 +199,7 @@ def preprocess_drift(
|
|
197
199
|
|
198
200
|
|
199
201
|
@torch.jit.script
|
200
|
-
def
|
202
|
+
def _squared_pairwise_distance(
|
201
203
|
x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
|
202
204
|
) -> torch.Tensor: # pragma: no cover - torch.jit.script code is compiled and copied
|
203
205
|
"""
|
@@ -249,7 +251,7 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
|
|
249
251
|
return sigma
|
250
252
|
|
251
253
|
|
252
|
-
class
|
254
|
+
class _GaussianRBF(nn.Module):
|
253
255
|
"""
|
254
256
|
Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
|
255
257
|
|
@@ -303,7 +305,7 @@ class GaussianRBF(nn.Module):
|
|
303
305
|
infer_sigma: bool = False,
|
304
306
|
) -> torch.Tensor:
|
305
307
|
x, y = torch.as_tensor(x), torch.as_tensor(y)
|
306
|
-
dist =
|
308
|
+
dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1)) # [Nx, Ny]
|
307
309
|
|
308
310
|
if infer_sigma or self.init_required:
|
309
311
|
if self.trainable and infer_sigma:
|
@@ -8,6 +8,8 @@ Licensed under Apache Software License (Apache 2.0)
|
|
8
8
|
|
9
9
|
from __future__ import annotations
|
10
10
|
|
11
|
+
__all__ = ["DriftUncertainty"]
|
12
|
+
|
11
13
|
from functools import partial
|
12
14
|
from typing import Callable, Literal
|
13
15
|
|
@@ -16,16 +18,16 @@ from numpy.typing import ArrayLike, NDArray
|
|
16
18
|
from scipy.special import softmax
|
17
19
|
from scipy.stats import entropy
|
18
20
|
|
19
|
-
from .base import DriftOutput, UpdateStrategy
|
20
|
-
from .ks import DriftKS
|
21
|
-
from .torch import get_device, preprocess_drift
|
21
|
+
from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
|
22
|
+
from dataeval.detectors.drift.ks import DriftKS
|
23
|
+
from dataeval.detectors.drift.torch import get_device, preprocess_drift
|
22
24
|
|
23
25
|
|
24
26
|
def classifier_uncertainty(
|
25
|
-
x: NDArray,
|
27
|
+
x: NDArray[np.float64],
|
26
28
|
model_fn: Callable,
|
27
29
|
preds_type: Literal["probs", "logits"] = "probs",
|
28
|
-
) -> NDArray:
|
30
|
+
) -> NDArray[np.float64]:
|
29
31
|
"""
|
30
32
|
Evaluate model_fn on x and transform predictions to prediction uncertainties.
|
31
33
|
|
@@ -0,0 +1,61 @@
|
|
1
|
+
"""
|
2
|
+
Update strategies inform how the :term:`drift<Drift>` detector classes update the reference data when monitoring
|
3
|
+
for drift.
|
4
|
+
"""
|
5
|
+
|
6
|
+
from __future__ import annotations
|
7
|
+
|
8
|
+
__all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
|
9
|
+
|
10
|
+
from typing import Any
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from numpy.typing import NDArray
|
14
|
+
|
15
|
+
from dataeval.detectors.drift.base import UpdateStrategy
|
16
|
+
|
17
|
+
|
18
|
+
class LastSeenUpdate(UpdateStrategy):
|
19
|
+
"""
|
20
|
+
Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
|
21
|
+
|
22
|
+
Parameters
|
23
|
+
----------
|
24
|
+
n : int
|
25
|
+
Update with last n instances seen by the detector.
|
26
|
+
"""
|
27
|
+
|
28
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
29
|
+
x_updated = np.concatenate([x_ref, x], axis=0)
|
30
|
+
return x_updated[-self.n :]
|
31
|
+
|
32
|
+
|
33
|
+
class ReservoirSamplingUpdate(UpdateStrategy):
|
34
|
+
"""
|
35
|
+
Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
|
36
|
+
|
37
|
+
Parameters
|
38
|
+
----------
|
39
|
+
n : int
|
40
|
+
Update with last n instances seen by the detector.
|
41
|
+
"""
|
42
|
+
|
43
|
+
def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
|
44
|
+
if x.shape[0] + count <= self.n:
|
45
|
+
return np.concatenate([x_ref, x], axis=0)
|
46
|
+
|
47
|
+
n_ref = x_ref.shape[0]
|
48
|
+
output_size = min(self.n, n_ref + x.shape[0])
|
49
|
+
shape = (output_size,) + x.shape[1:]
|
50
|
+
x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
|
51
|
+
x_reservoir[:n_ref] = x_ref
|
52
|
+
for item in x:
|
53
|
+
count += 1
|
54
|
+
if n_ref < self.n:
|
55
|
+
x_reservoir[n_ref, :] = item
|
56
|
+
n_ref += 1
|
57
|
+
else:
|
58
|
+
r = np.random.randint(0, count)
|
59
|
+
if r < self.n:
|
60
|
+
x_reservoir[r, :] = item
|
61
|
+
return x_reservoir
|
@@ -2,9 +2,9 @@
|
|
2
2
|
Linters help identify potential issues in training and test data and are an important aspect of data cleaning.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from dataeval.
|
6
|
-
from dataeval.
|
7
|
-
from dataeval.
|
5
|
+
from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
|
6
|
+
from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
|
7
|
+
from dataeval.detectors.linters.outliers import Outliers, OutliersOutput
|
8
8
|
|
9
9
|
__all__ = [
|
10
10
|
"Clusterer",
|