dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/data/__init__.py +19 -0
- dataeval/data/_embeddings.py +345 -0
- dataeval/{utils/data → data}/_images.py +2 -2
- dataeval/{utils/data → data}/_metadata.py +8 -7
- dataeval/{utils/data → data}/_selection.py +22 -9
- dataeval/{utils/data → data}/_split.py +1 -1
- dataeval/data/selections/__init__.py +19 -0
- dataeval/data/selections/_classbalance.py +37 -0
- dataeval/data/selections/_classfilter.py +109 -0
- dataeval/{utils/data → data}/selections/_indices.py +1 -1
- dataeval/{utils/data → data}/selections/_limit.py +1 -1
- dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
- dataeval/{utils/data → data}/selections/_reverse.py +1 -1
- dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +55 -203
- dataeval/detectors/drift/_cvm.py +19 -30
- dataeval/detectors/drift/_ks.py +18 -30
- dataeval/detectors/drift/_mmd.py +189 -53
- dataeval/detectors/drift/_uncertainty.py +52 -56
- dataeval/detectors/drift/updates.py +13 -12
- dataeval/detectors/linters/duplicates.py +6 -4
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/metadata/_distance.py +1 -1
- dataeval/metadata/_ood.py +4 -4
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/bias/_diversity.py +1 -1
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/stats/_base.py +7 -7
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +2 -2
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/_bias.py +1 -1
- dataeval/typing.py +53 -19
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +18 -7
- dataeval/utils/data/__init__.py +5 -20
- dataeval/utils/data/_dataset.py +6 -4
- dataeval/utils/data/collate.py +2 -0
- dataeval/utils/datasets/__init__.py +17 -0
- dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
- dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
- dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
- dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
- dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
- dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
- dataeval/utils/torch/_internal.py +12 -35
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
- dataeval-1.0.0.dist-info/RECORD +107 -0
- dataeval/detectors/drift/_torch.py +0 -222
- dataeval/utils/data/_embeddings.py +0 -186
- dataeval/utils/data/datasets/__init__.py +0 -17
- dataeval/utils/data/selections/__init__.py +0 -17
- dataeval/utils/data/selections/_classfilter.py +0 -59
- dataeval-0.84.0.dist-info/RECORD +0 -106
- /dataeval/{utils/data → data}/_targets.py +0 -0
- /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
- /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
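Taken together, the renames above amount to a package re-organization: `dataeval.utils.data` is promoted to a top-level `dataeval.data` package, and the bundled datasets move from `dataeval.utils.data.datasets` to `dataeval.utils.datasets`. A minimal migration sketch follows; the module paths come from the rename list above, but the exact exported names (`Embeddings` appears in the docstrings below, `MNIST` is inferred from `_mnist.py` and should be treated as an assumption):

# Hypothetical import migration for 1.0.0; exported names are assumptions.

# 0.84.0
# from dataeval.utils.data import Embeddings
# from dataeval.utils.data.datasets import MNIST

# 1.0.0
from dataeval.data import Embeddings        # utils/data promoted to top-level data
from dataeval.utils.datasets import MNIST   # datasets moved out of utils/data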
dataeval/detectors/drift/_ks.py
CHANGED
@@ -10,14 +10,15 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import
+from typing import Literal
 
 import numpy as np
 from numpy.typing import NDArray
 from scipy.stats import ks_2samp
 
+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
-from dataeval.typing import
+from dataeval.typing import Array
 
 
 class DriftKS(BaseDriftUnivariate):
@@ -31,43 +32,34 @@ class DriftKS(BaseDriftUnivariate):
 
     Parameters
     ----------
-
+    data : Embeddings or Array
         Data used as reference distribution.
-    p_val : float
+    p_val : float or None, default 0.05
         :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
         If the FDR correction method is used, this corresponds to the acceptable
         q-value.
-
-        Whether the given reference data ``x_ref`` has been preprocessed yet.
-        If ``True``, only the test data ``x`` will be preprocessed at prediction time.
-        If ``False``, the reference data will also be preprocessed.
-    update_x_ref : UpdateStrategy | None, default None
+    update_strategy : UpdateStrategy or None, default None
         Reference data can optionally be updated using an UpdateStrategy class. Update
         using the last n instances seen by the detector with LastSeenUpdateStrategy
         or via reservoir sampling with ReservoirSamplingUpdateStrategy.
-
-        Function to preprocess the data before computing the data :term:`drift<Drift>` metrics.
-        Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
-    correction : "bonferroni" | "fdr", default "bonferroni"
+    correction : "bonferroni" or "fdr", default "bonferroni"
         Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
         Discovery Rate).
-    alternative : "two-sided"
+    alternative : "two-sided", "less" or "greater", default "two-sided"
         Defines the alternative hypothesis. Options are 'two-sided', 'less' or
         'greater'.
     n_features : int | None, default None
-        Number of features used in the
-
-        be inferred automatically but could be more expensive to compute.
+        Number of features used in the univariate drift tests. If not provided, it will
+        be inferred from the data.
 
     Example
     -------
-    >>> from
-    >>> from dataeval.detectors.drift import preprocess_drift
+    >>> from dataeval.data import Embeddings
 
-    Use
+    Use Embeddings to encode images before testing for drift
 
-    >>>
-    >>> drift = DriftKS(
+    >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
+    >>> drift = DriftKS(train_emb)
 
     Test incoming images for drift
 
@@ -77,21 +69,17 @@ class DriftKS(BaseDriftUnivariate):
 
     def __init__(
         self,
-
+        data: Embeddings | Array,
         p_val: float = 0.05,
-
-        update_x_ref: UpdateStrategy | None = None,
-        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
+        update_strategy: UpdateStrategy | None = None,
         correction: Literal["bonferroni", "fdr"] = "bonferroni",
         alternative: Literal["two-sided", "less", "greater"] = "two-sided",
         n_features: int | None = None,
     ) -> None:
         super().__init__(
-
+            data=data,
             p_val=p_val,
-
-            update_x_ref=update_x_ref,
-            preprocess_fn=preprocess_fn,
+            update_strategy=update_strategy,
             correction=correction,
             n_features=n_features,
         )
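Note the API shift visible in this diff: the 0.84.0 `preprocess_fn`/`update_x_ref` arguments are gone, and the detector now consumes pre-computed `Embeddings` (or a raw `Array`) directly. The per-feature testing and Bonferroni correction named in the docstring work as sketched below; this mirrors the documented behavior with plain scipy and is not dataeval's internal code:

# Sketch of feature-wise KS testing with Bonferroni correction, as described
# in the DriftKS docstring (bonferroni divides alpha across n_features).
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
x_ref = rng.normal(size=(500, 8))            # reference embeddings [N, features]
x_test = rng.normal(0.5, 1.0, (200, 8))      # mean-shifted test embeddings

p_vals = np.array([ks_2samp(x_ref[:, i], x_test[:, i]).pvalue for i in range(x_ref.shape[1])])
alpha = 0.05
drifted = bool((p_vals < alpha / p_vals.size).any())  # Bonferroni-corrected decision
print(drifted)  # True for this shifted sample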
dataeval/detectors/drift/_mmd.py
CHANGED
@@ -10,16 +10,16 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Callable
+from typing import Any, Callable
 
 import torch
 
 from dataeval.config import DeviceLike, get_device
-from dataeval.
-from dataeval.detectors.drift.
+from dataeval.data._embeddings import Embeddings
+from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, update_strategy
 from dataeval.outputs import DriftMMDOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import
+from dataeval.typing import Array
 
 
 class DriftMMD(BaseDrift):
@@ -29,29 +29,20 @@ class DriftMMD(BaseDrift):
 
     Parameters
     ----------
-
+    data : Embeddings or Array
         Data used as reference distribution.
     p_val : float or None, default 0.05
         :term:`P-value` used for significance of the statistical test for each feature.
         If the FDR correction method is used, this corresponds to the acceptable
         q-value.
-
-        Whether the given reference data ``x_ref`` has been preprocessed yet.
-        If ``True``, only the test data ``x`` will be preprocessed at prediction time.
-        If ``False``, the reference data will also be preprocessed.
-    update_x_ref : UpdateStrategy or None, default None
+    update_strategy : UpdateStrategy or None, default None
         Reference data can optionally be updated using an UpdateStrategy class. Update
         using the last n instances seen by the detector with LastSeenUpdateStrategy
         or via reservoir sampling with ReservoirSamplingUpdateStrategy.
-
-        Function to preprocess the data before computing the data drift metrics.
-        Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
-    sigma : ArrayLike or None, default None
+    sigma : Array or None, default None
         Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
         bandwidth values as an array. The kernel evaluation is then averaged over
         those bandwidths.
-    configure_kernel_from_x_ref : bool, default True
-        Whether to already configure the kernel bandwidth from the reference data.
     n_permutations : int, default 100
         Number of permutations used in the permutation test.
     device : DeviceLike or None, default None
@@ -60,13 +51,12 @@ class DriftMMD(BaseDrift):
 
     Example
     -------
-    >>> from
-    >>> from dataeval.detectors.drift import preprocess_drift
+    >>> from dataeval.data import Embeddings
 
-    Use
+    Use Embeddings to encode images before testing for drift
 
-    >>>
-    >>> drift = DriftMMD(
+    >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
+    >>> drift = DriftMMD(train_emb)
 
     Test incoming images for drift
 
@@ -76,21 +66,14 @@ class DriftMMD(BaseDrift):
 
     def __init__(
         self,
-
+        data: Embeddings | Array,
         p_val: float = 0.05,
-
-
-        preprocess_fn: Callable[..., ArrayLike] | None = None,
-        sigma: ArrayLike | None = None,
-        configure_kernel_from_x_ref: bool = True,
+        update_strategy: UpdateStrategy | None = None,
+        sigma: Array | None = None,
         n_permutations: int = 100,
         device: DeviceLike | None = None,
     ) -> None:
-        super().__init__(
-
-        self._infer_sigma = configure_kernel_from_x_ref
-        if configure_kernel_from_x_ref and sigma is not None:
-            self._infer_sigma = False
+        super().__init__(data, p_val, update_strategy)
 
         self.n_permutations = n_permutations  # nb of iterations through permutation test
 
@@ -102,23 +85,20 @@ class DriftMMD(BaseDrift):
         self._kernel = GaussianRBF(sigma_tensor).to(self.device)
 
         # compute kernel matrix for the reference data
-        if
-
-            self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
-            self._infer_sigma = False
+        if isinstance(sigma_tensor, torch.Tensor):
+            self._k_xx = self._kernel(self.x_ref, self.x_ref)
         else:
-            self._k_xx
+            self._k_xx = None
 
-    def _kernel_matrix(self, x:
+    def _kernel_matrix(self, x: Array, y: Array) -> torch.Tensor:
         """Compute and return full kernel matrix between arrays x and y."""
-        k_xy = self._kernel(x, y
-        k_xx = self._k_xx if self._k_xx is not None and self.
+        k_xy = self._kernel(x, y)
+        k_xx = self._k_xx if self._k_xx is not None and self.update_strategy is None else self._kernel(x, x)
         k_yy = self._kernel(y, y)
         kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
         return kernel_mat
 
-
-    def score(self, x: ArrayLike) -> tuple[float, float, float]:
+    def score(self, data: Embeddings | Array) -> tuple[float, float, float]:
         """
         Compute the :term:`p-value<P-Value>` resulting from a permutation test using the maximum mean
         discrepancy as a distance measure between the reference data and the data to
@@ -126,8 +106,8 @@ class DriftMMD(BaseDrift):
 
         Parameters
         ----------
-
-            Batch of instances.
+        data : Embeddings or Array
+            Batch of instances to score.
 
         Returns
        -------
@@ -135,10 +115,9 @@ class DriftMMD(BaseDrift):
            p-value obtained from the permutation test, MMD^2 between the reference and test set,
            and MMD^2 threshold above which :term:`drift<Drift>` is flagged
         """
-
-        x_test = torch.as_tensor(x, device=self.device)
+        x_test = self._encode(data)
         n = x_test.shape[0]
-        kernel_mat = self._kernel_matrix(x_ref, x_test)
+        kernel_mat = self._kernel_matrix(self.x_ref, x_test)
         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
         mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
         mmd2_permuted = torch.tensor(
@@ -152,17 +131,16 @@ class DriftMMD(BaseDrift):
         return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())
 
     @set_metadata
-    @
-
-    def predict(self, x: ArrayLike) -> DriftMMDOutput:
+    @update_strategy
+    def predict(self, data: Embeddings | Array) -> DriftMMDOutput:
         """
         Predict whether a batch of data has drifted from the reference data and then
         updates reference data using specified strategy.
 
         Parameters
         ----------
-
-            Batch of instances.
+        data : Embeddings or Array
+            Batch of instances to predict drift on.
 
         Returns
         -------
@@ -171,8 +149,166 @@ class DriftMMD(BaseDrift):
         threshold and MMD metric.
         """
         # compute drift scores
-        p_val, dist, distance_threshold = self.score(
+        p_val, dist, distance_threshold = self.score(data)
         drift_pred = bool(p_val < self.p_val)
 
         # populate drift dict
         return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
+
+
+@torch.jit.script
+def _squared_pairwise_distance(
+    x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
+) -> torch.Tensor:  # pragma: no cover - torch.jit.script code is compiled and copied
+    """
+    PyTorch pairwise squared Euclidean distance between samples x and y.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Batch of instances of shape [Nx, features].
+    y : torch.Tensor
+        Batch of instances of shape [Ny, features].
+    a_min : float
+        Lower bound to clip distance values.
+
+    Returns
+    -------
+    torch.Tensor
+        Pairwise squared Euclidean distance [Nx, Ny].
+    """
+    x2 = x.pow(2).sum(dim=-1, keepdim=True)
+    y2 = y.pow(2).sum(dim=-1, keepdim=True)
+    dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2)
+    return dist.clamp_min_(a_min)
+
+
+def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
+    """
+    Bandwidth estimation using the median heuristic `Gretton2012`
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Tensor of instances with dimension [Nx, features].
+    y : torch.Tensor
+        Tensor of instances with dimension [Ny, features].
+    dist : torch.Tensor
+        Tensor with dimensions [Nx, Ny], containing the pairwise distances
+        between `x` and `y`.
+
+    Returns
+    -------
+    torch.Tensor
+        The computed bandwidth, `sigma`.
+    """
+    n = min(x.shape[0], y.shape[0])
+    n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
+    n_median = n + (torch.prod(torch.as_tensor(dist.shape)) - n) // 2 - 1
+    sigma = (0.5 * dist.flatten().sort().values[int(n_median)].unsqueeze(dim=-1)) ** 0.5
+    return sigma
+
+
+class GaussianRBF(torch.nn.Module):
+    """
+    Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2).
+
+    A forward pass takes a batch of instances x [Nx, features] and
+    y [Ny, features] and returns the kernel matrix [Nx, Ny].
+
+    Parameters
+    ----------
+    sigma : torch.Tensor | None, default None
+        Bandwidth used for the kernel. Needn't be specified if being inferred or
+        trained. Can pass multiple values to eval kernel with and then average.
+    init_sigma_fn : Callable | None, default None
+        Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
+        inferred. The function's signature should take in the tensors ``x``, ``y`` and
+        ``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
+    trainable : bool, default False
+        Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
+    """
+
+    def __init__(
+        self,
+        sigma: torch.Tensor | None = None,
+        init_sigma_fn: Callable | None = None,
+        trainable: bool = False,
+    ) -> None:
+        super().__init__()
+        init_sigma_fn = sigma_median if init_sigma_fn is None else init_sigma_fn
+        self.config: dict[str, Any] = {
+            "sigma": sigma,
+            "trainable": trainable,
+            "init_sigma_fn": init_sigma_fn,
+        }
+        if sigma is None:
+            self.log_sigma: torch.nn.Parameter = torch.nn.Parameter(torch.empty(1), requires_grad=trainable)
+            self.init_required: bool = True
+        else:
+            sigma = sigma.reshape(-1)  # [Ns,]
+            self.log_sigma: torch.nn.Parameter = torch.nn.Parameter(sigma.log(), requires_grad=trainable)
+            self.init_required: bool = False
+        self.init_sigma_fn = init_sigma_fn
+        self.trainable = trainable
+
+    @property
+    def sigma(self) -> torch.Tensor:
+        return self.log_sigma.exp()
+
+    def forward(
+        self,
+        x: Array,
+        y: Array,
+        infer_sigma: bool = False,
+    ) -> torch.Tensor:
+        x, y = torch.as_tensor(x), torch.as_tensor(y)
+        dist = _squared_pairwise_distance(x.flatten(1), y.flatten(1))  # [Nx, Ny]
+
+        if infer_sigma or self.init_required:
+            if self.trainable and infer_sigma:
+                raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
+            sigma = self.init_sigma_fn(x, y, dist)
+            with torch.no_grad():
+                self.log_sigma.copy_(sigma.log().clone())
+            self.init_required: bool = False
+
+        gamma = 1.0 / (2.0 * self.sigma**2)  # [Ns,]
+        # TODO: do matrix multiplication after all?
+        kernel_mat = torch.exp(-torch.cat([(g * dist)[None, :, :] for g in gamma], dim=0))  # [Ns, Nx, Ny]
+        return kernel_mat.mean(dim=0)  # [Nx, Ny]
+
+
+def mmd2_from_kernel_matrix(
+    kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
+) -> torch.Tensor:
+    """
+    Compute maximum mean discrepancy (MMD^2) between 2 samples x and y from the
+    full kernel matrix between the samples.
+
+    Parameters
+    ----------
+    kernel_mat : torch.Tensor
+        Kernel matrix between samples x and y.
+    m : int
+        Number of instances in y.
+    permute : bool, default False
+        Whether to permute the row indices. Used for permutation tests.
+    zero_diag : bool, default True
+        Whether to zero out the diagonal of the kernel matrix.
+
+    Returns
+    -------
+    torch.Tensor
+        MMD^2 between the samples from the kernel matrix.
+    """
+    n = kernel_mat.shape[0] - m
+    if zero_diag:
+        kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())
+    if permute:
+        idx = torch.randperm(kernel_mat.shape[0])
+        kernel_mat = kernel_mat[idx][:, idx]
+    k_xx, k_yy, k_xy = kernel_mat[:-m, :-m], kernel_mat[-m:, -m:], kernel_mat[-m:, :-m]
+    c_xx, c_yy = 1 / (n * (n - 1)), 1 / (m * (m - 1))
+    mmd2 = c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()
+    return mmd2
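For reference, the estimator implemented by the added `mmd2_from_kernel_matrix` (with the kernel-matrix diagonal zeroed beforehand in `score`) is the standard unbiased MMD^2 statistic; with n reference instances x and m test instances y:

$$\widehat{\mathrm{MMD}}^2 = \frac{1}{n(n-1)}\sum_{i \neq j} k(x_i, x_j) + \frac{1}{m(m-1)}\sum_{i \neq j} k(y_i, y_j) - \frac{2}{nm}\sum_{i,j} k(x_i, y_j)$$

This is exactly `c_xx * k_xx.sum() + c_yy * k_yy.sum() - 2.0 * k_xy.mean()` in the new helper. `score` builds the null distribution by reshuffling the joint kernel matrix (`permute=True`), and `predict` flags drift when the resulting p-value falls below `p_val`.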
dataeval/detectors/drift/_uncertainty.py
CHANGED
@@ -10,33 +10,32 @@ from __future__ import annotations
 
 __all__ = []
 
-from
-from typing import Callable, Literal
+from typing import Literal, Sequence, cast
 
 import numpy as np
-
+import torch
 from scipy.special import softmax
 from scipy.stats import entropy
 
-from dataeval.config import get_device
-from dataeval.detectors.drift._base import UpdateStrategy
+from dataeval.config import DeviceLike, get_device
+from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy
 from dataeval.detectors.drift._ks import DriftKS
-from dataeval.detectors.drift._torch import preprocess_drift
 from dataeval.outputs import DriftOutput
-from dataeval.typing import
+from dataeval.typing import Array, Transform
+from dataeval.utils._array import as_numpy
+from dataeval.utils.torch._internal import predict_batch
 
 
 def classifier_uncertainty(
-
-    model_fn: Callable,
+    preds: Array,
     preds_type: Literal["probs", "logits"] = "probs",
-) ->
+) -> torch.Tensor:
     """
     Evaluate model_fn on x and transform predictions to prediction uncertainties.
 
     Parameters
     ----------
-    x :
+    x : Array
         Batch of instances.
     model_fn : Callable
         Function that evaluates a :term:`classification<Classification>` model on x in a single call (contains
@@ -50,23 +49,21 @@ def classifier_uncertainty(
     NDArray
         A scalar indication of uncertainty of the model on each instance in x.
     """
-
-    preds = model_fn(x)
-
+    preds_np = as_numpy(preds)
     if preds_type == "probs":
-        if np.abs(1 - np.sum(
+        if np.abs(1 - np.sum(preds_np, axis=-1)).mean() > 1e-6:
             raise ValueError("Probabilities across labels should sum to 1")
-        probs =
+        probs = preds_np
     elif preds_type == "logits":
-        probs = softmax(
+        probs = softmax(preds_np, axis=-1)
     else:
         raise NotImplementedError("Only prediction types 'probs' and 'logits' supported.")
 
-    uncertainties = entropy(probs, axis=-1)
-    return uncertainties[:, None]
+    uncertainties = cast(np.ndarray, entropy(probs, axis=-1))
+    return torch.as_tensor(uncertainties[:, None])
 
 
-class DriftUncertainty:
+class DriftUncertainty(BaseDrift):
     """
     Test for a change in the number of instances falling into regions on which \
     the model is uncertain.
@@ -75,29 +72,27 @@ class DriftUncertainty:
 
     Parameters
     ----------
-
+    data : Array
        Data used as reference distribution.
     model : Callable
        :term:`Classification` model outputting class probabilities (or logits)
     p_val : float, default 0.05
        :term:`P-Value` used for the significance of the test.
-
-        Whether the given reference data ``x_ref`` has been preprocessed yet.
-        If ``True``, only the test data ``x`` will be preprocessed at prediction time.
-        If ``False``, the reference data will also be preprocessed.
-    update_x_ref : UpdateStrategy or None, default None
+    update_strategy : UpdateStrategy or None, default None
        Reference data can optionally be updated using an UpdateStrategy class. Update
        using the last n instances seen by the detector with LastSeenUpdateStrategy
        or via reservoir sampling with ReservoirSamplingUpdateStrategy.
+    correction : "bonferroni" or "fdr", default "bonferroni"
+        Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
+        Discovery Rate).
     preds_type : "probs" or "logits", default "probs"
        Type of prediction output by the model. Options are 'probs' (in [0,1]) or
        'logits' (in [-inf,inf]).
     batch_size : int, default 32
        Batch size used to evaluate model. Only relevant when backend has been
        specified for batch prediction.
-
-
-        objects to a batch which can be processed by the model.
+    transforms : Transform, Sequence[Transform] or None, default None
+        Transform(s) to apply to the data.
     device : DeviceLike or None, default None
        Device type used. The default None tries to use the GPU and falls back on
        CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
@@ -120,46 +115,47 @@ class DriftUncertainty:
 
     def __init__(
         self,
-
-        model:
+        data: Array,
+        model: torch.nn.Module,
         p_val: float = 0.05,
-
-
+        update_strategy: UpdateStrategy | None = None,
+        correction: Literal["bonferroni", "fdr"] = "bonferroni",
         preds_type: Literal["probs", "logits"] = "probs",
         batch_size: int = 32,
-
-        device:
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
+        device: DeviceLike | None = None,
     ) -> None:
-
-
-
-
-                batch_size=batch_size,
-                preprocess_batch_fn=preprocess_batch_fn,
-                device=get_device(device),
-            )
-
-        preprocess_fn = partial(
-            classifier_uncertainty,
-            model_fn=model_fn,
-            preds_type=preds_type,
-        )
+        self.model: torch.nn.Module = model
+        self.device: torch.device = get_device(device)
+        self.batch_size: int = batch_size
+        self.preds_type: Literal["probs", "logits"] = preds_type
 
+        self._transforms = (
+            [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
+        )
         self._detector = DriftKS(
-
+            data=self._preprocess(data).cpu().numpy(),
             p_val=p_val,
-
-
-            preprocess_fn=preprocess_fn,  # type: ignore
+            update_strategy=update_strategy,
+            correction=correction,
         )
 
-    def
+    def _transform(self, x: torch.Tensor) -> torch.Tensor:
+        for transform in self._transforms:
+            x = transform(x)
+        return x
+
+    def _preprocess(self, x: Array) -> torch.Tensor:
+        preds = predict_batch(x, self.model, self.device, self.batch_size, self._transform)
+        return classifier_uncertainty(preds, self.preds_type)
+
+    def predict(self, x: Array) -> DriftOutput:
         """
         Predict whether a batch of data has drifted from the reference data.
 
         Parameters
         ----------
-        x :
+        x : Array
            Batch of instances.
 
        Returns
@@ -168,4 +164,4 @@ class DriftUncertainty:
            Dictionary containing the drift prediction, :term:`p-value<P-Value>`, and threshold
            statistics.
        """
-        return self._detector.predict(x)
+        return self._detector.predict(self._preprocess(x).cpu().numpy())
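The refactored `DriftUncertainty` pipeline is now explicit in this diff: `predict_batch` runs the classifier, `classifier_uncertainty` converts its predictions into per-instance entropies, and those entropies feed a univariate `DriftKS`. The uncertainty transform can be reproduced standalone with just numpy/scipy/torch (values below are illustrative):

# The uncertainty transform from classifier_uncertainty, reproduced standalone:
# softmax over logits -> per-instance entropy -> [N, 1] tensor for the KS test.
import numpy as np
import torch
from scipy.special import softmax
from scipy.stats import entropy

logits = np.array([[4.0, 0.1, 0.2],    # confident prediction -> low entropy
                   [0.8, 0.7, 0.9]])   # uncertain prediction -> high entropy
probs = softmax(logits, axis=-1)
uncertainties = entropy(probs, axis=-1)
print(torch.as_tensor(uncertainties[:, None]).shape)  # torch.Size([2, 1])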