dataeval 0.82.0__py3-none-any.whl → 0.83.0__py3-none-any.whl
This diff compares the contents of publicly released package versions as they appear in their respective supported registries; it is provided for informational purposes only.
- dataeval/__init__.py +7 -2
- dataeval/config.py +78 -11
- dataeval/detectors/drift/_mmd.py +9 -9
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +4 -4
- dataeval/detectors/linters/duplicates.py +3 -3
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +5 -4
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/mixin.py +1 -1
- dataeval/detectors/ood/vae.py +2 -1
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_distance.py +11 -44
- dataeval/metadata/_ood.py +152 -33
- dataeval/metrics/bias/_balance.py +9 -5
- dataeval/metrics/bias/_diversity.py +3 -0
- dataeval/metrics/bias/_parity.py +2 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +20 -21
- dataeval/metrics/stats/_boxratiostats.py +1 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +8 -8
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +5 -0
- dataeval/outputs/_base.py +50 -21
- dataeval/outputs/_bias.py +1 -1
- dataeval/outputs/_linters.py +4 -2
- dataeval/outputs/_metadata.py +61 -0
- dataeval/outputs/_stats.py +12 -6
- dataeval/typing.py +40 -9
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_embeddings.py +23 -19
- dataeval/utils/data/_metadata.py +16 -7
- dataeval/utils/data/_selection.py +22 -15
- dataeval/utils/data/_split.py +3 -2
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +5 -3
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/torch/_gmm.py +3 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
- dataeval-0.83.0.dist-info/RECORD +105 -0
- dataeval/detectors/ood/metadata_ood_mi.py +0 -93
- dataeval-0.82.0.dist-info/RECORD +0 -104
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/utils/data/selections/_prioritize.py
ADDED
@@ -0,0 +1,296 @@
+from __future__ import annotations
+
+__all__ = []
+
+import logging
+import warnings
+from abc import ABC, abstractmethod
+from typing import Any, Literal, overload
+
+import numpy as np
+import torch
+from numpy.typing import NDArray
+from sklearn.cluster import KMeans
+from sklearn.metrics import pairwise_distances
+
+from dataeval.config import EPSILON, DeviceLike, get_seed
+from dataeval.utils.data import Embeddings, Select
+from dataeval.utils.data._selection import Selection, SelectionStage
+
+_logger = logging.getLogger(__name__)
+
+
+class _Clusters:
+    __slots__ = ["labels", "cluster_centers", "unique_labels"]
+
+    labels: NDArray[np.intp]
+    cluster_centers: NDArray[np.float64]
+    unique_labels: NDArray[np.intp]
+
+    def __init__(self, labels: NDArray[np.intp], cluster_centers: NDArray[np.float64]) -> None:
+        self.labels = labels
+        self.cluster_centers = cluster_centers
+        self.unique_labels = np.unique(labels)
+
+    def _dist2center(self, X: NDArray[np.float64]) -> NDArray[np.float64]:
+        dist = np.zeros(self.labels.shape)
+        for lab in self.unique_labels:
+            dist[self.labels == lab] = np.linalg.norm(X[self.labels == lab, :] - self.cluster_centers[lab, :], axis=1)
+        return dist
+
+    def _complexity(self, X: NDArray[np.float64]) -> NDArray[np.float64]:
+        num_clst_intra = int(np.maximum(np.minimum(int(self.unique_labels.shape[0] / 5), 20), 1))
+        d_intra = np.zeros(self.unique_labels.shape)
+        d_inter = np.zeros(self.unique_labels.shape)
+        for cdx, lab in enumerate(self.unique_labels):
+            d_intra[cdx] = np.mean(np.linalg.norm(X[self.labels == lab, :] - self.cluster_centers[cdx, :], axis=1))
+            d_inter[cdx] = np.mean(
+                np.linalg.norm(self.cluster_centers - self.cluster_centers[cdx, :], axis=1)[:num_clst_intra]
+            )
+        cj = d_intra * d_inter
+        tau = 0.1
+        exp = np.exp(cj / tau)
+        prob: NDArray[np.float64] = exp / np.sum(exp)
+        return prob
+
+    def _sort_by_weights(self, X: NDArray[np.float64]) -> NDArray[np.intp]:
+        pr = self._complexity(X)
+        d2c = self._dist2center(X)
+        inds_per_clst: list[NDArray[np.intp]] = []
+        for lab in zip(self.unique_labels):
+            inds = np.nonzero(self.labels == lab)[0]
+            # 'hardest' first
+            srt_inds = np.argsort(d2c[inds])[::-1]
+            inds_per_clst.append(inds[srt_inds])
+        glob_inds: list[NDArray[np.intp]] = []
+        while not bool(np.all([arr.size == 0 for arr in inds_per_clst])):
+            clst_ind = np.random.choice(self.unique_labels, 1, p=pr)[0]
+            if inds_per_clst[clst_ind].size > 0:
+                glob_inds.append(inds_per_clst[clst_ind][0])
+            else:
+                continue
+            inds_per_clst[clst_ind] = inds_per_clst[clst_ind][1:]
+        # sorted hardest first; reverse for consistency
+        return np.array(glob_inds[::-1])
+
+
+class _Sorter(ABC):
+    @abstractmethod
+    def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]: ...
+
+
+class _KNNSorter(_Sorter):
+    def __init__(self, samples: int, k: int | None) -> None:
+        if k is None or k <= 0:
+            k = int(np.sqrt(samples))
+            _logger._log(logging.INFO, f"Setting k to default value of {k}", {"k": k, "samples": samples})
+        elif k >= samples:
+            raise ValueError(f"k={k} should be less than dataset size ({samples})")
+        elif k >= samples / 10 and k > np.sqrt(samples):
+            warnings.warn(
+                f"Variable k={k} is large with respect to dataset size but valid; "
+                + f"a nominal recommendation is k={int(np.sqrt(samples))}"
+            )
+        self._k = k
+
+    def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
+        if reference is None:
+            dists = pairwise_distances(embeddings, embeddings)
+            np.fill_diagonal(dists, np.inf)
+        else:
+            dists = pairwise_distances(embeddings, reference)
+        inds = np.argsort(np.sort(dists, axis=1)[:, self._k])
+        return inds
+
+
+class _KMeansSorter(_Sorter):
+    def __init__(self, samples: int, c: int | None, n_init: int | Literal["auto", "warn"] = "auto") -> None:
+        if c is None or c <= 0:
+            c = int(np.sqrt(samples))
+            _logger._log(logging.INFO, f"Setting the value of num_clusters to a default value of {c}", {})
+        if c >= samples:
+            raise ValueError(f"c={c} should be less than dataset size ({samples})")
+        self._c = c
+        self._n_init = n_init
+
+    def _get_clusters(self, embeddings: NDArray[Any]) -> _Clusters:
+        clst = KMeans(n_clusters=self._c, init="k-means++", n_init=self._n_init, random_state=get_seed())  # type: ignore - n_init allows int but is typed as str
+        clst.fit(embeddings)
+        if clst.labels_ is None or clst.cluster_centers_ is None:
+            raise ValueError("Clustering failed to produce labels or cluster centers")
+        return _Clusters(clst.labels_, clst.cluster_centers_)
+
+
+class _KMeansDistanceSorter(_KMeansSorter):
+    def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
+        clst = self._get_clusters(embeddings if reference is None else reference)
+        inds = np.argsort(clst._dist2center(embeddings))
+        return inds
+
+
+class _KMeansComplexitySorter(_KMeansSorter):
+    def _sort(self, embeddings: NDArray[Any], reference: NDArray[Any] | None = None) -> NDArray[np.intp]:
+        clst = self._get_clusters(embeddings if reference is None else reference)
+        inds = clst._sort_by_weights(embeddings)
+        return inds
+
+
+class Prioritize(Selection[Any]):
+    """
+    Prioritizes the dataset by sort order in the embedding space.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        Model to use for encoding images
+    batch_size : int
+        Batch size to use when encoding images
+    device : DeviceLike or None
+        Device to use for encoding images
+    method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
+        Method to use for prioritization
+    k : int | None, default None
+        Number of nearest neighbors to use for prioritization (knn only)
+    c : int | None, default None
+        Number of clusters to use for prioritization (kmeans only)
+    """
+
+    stage = SelectionStage.ORDER
+
+    @overload
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["knn"],
+        *,
+        k: int | None = None,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["kmeans_distance", "kmeans_complexity"],
+        *,
+        c: int | None = None,
+    ) -> None: ...
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        batch_size: int,
+        device: DeviceLike | None,
+        method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        *,
+        k: int | None = None,
+        c: int | None = None,
+    ) -> None:
+        if method not in ("knn", "kmeans_distance", "kmeans_complexity"):
+            raise ValueError(f"Invalid prioritization method: {method}")
+        self._model = model
+        self._batch_size = batch_size
+        self._device = device
+        self._method = method
+        self._embeddings: Embeddings | None = None
+        self._reference: Embeddings | None = None
+        self._k = k
+        self._c = c
+
+    @overload
+    @classmethod
+    def using(
+        cls,
+        method: Literal["knn"],
+        *,
+        k: int | None = None,
+        embeddings: Embeddings | None = None,
+        reference: Embeddings | None = None,
+    ) -> Prioritize: ...
+
+    @overload
+    @classmethod
+    def using(
+        cls,
+        method: Literal["kmeans_distance", "kmeans_complexity"],
+        *,
+        c: int | None = None,
+        embeddings: Embeddings | None = None,
+        reference: Embeddings | None = None,
+    ) -> Prioritize: ...
+
+    @classmethod
+    def using(
+        cls,
+        method: Literal["knn", "kmeans_distance", "kmeans_complexity"],
+        *,
+        k: int | None = None,
+        c: int | None = None,
+        embeddings: Embeddings | None = None,
+        reference: Embeddings | None = None,
+    ) -> Prioritize:
+        """
+        Prioritizes the dataset by sort order in the embedding space using existing
+        embeddings and/or reference dataset embeddings.
+
+        Parameters
+        ----------
+        method : Literal["knn", "kmeans_distance", "kmeans_complexity"]
+            Method to use for prioritization
+        embeddings : Embeddings or None, default None
+            Embeddings to use for prioritization
+        reference : Embeddings or None, default None
+            Reference embeddings to prioritize relative to
+        k : int or None, default None
+            Number of nearest neighbors to use for prioritization (knn only)
+        c : int or None, default None
+            Number of clusters to use for prioritization (kmeans, cluster only)
+
+        Notes
+        -----
+        At least one of `embeddings` or `reference` must be provided.
+        """
+        emb_params: Embeddings | None = embeddings if embeddings is not None else reference
+        if emb_params is None:
+            raise ValueError("Must provide at least embeddings or reference embeddings.")
+        prioritize = Prioritize(emb_params._model, emb_params.batch_size, emb_params.device, method)
+        prioritize._k = k
+        prioritize._c = c
+        prioritize._embeddings = embeddings
+        prioritize._reference = reference
+        return prioritize
+
+    def _get_sorter(self, samples: int) -> _Sorter:
+        if self._method == "knn":
+            return _KNNSorter(samples, self._k)
+        elif self._method == "kmeans_distance":
+            return _KMeansDistanceSorter(samples, self._c)
+        else:  # self._method == "kmeans_complexity"
+            return _KMeansComplexitySorter(samples, self._c)
+
+    def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
+        emb: NDArray[Any] = embeddings.to_tensor(selection).cpu().numpy()
+        emb /= max(np.max(np.linalg.norm(emb, axis=1)), EPSILON)
+        return emb
+
+    def __call__(self, dataset: Select[Any]) -> None:
+        # Initialize sorter
+        self._sorter = self._get_sorter(len(dataset._selection))
+        # Extract and normalize embeddings
+        embeddings = (
+            Embeddings(dataset, batch_size=self._batch_size, model=self._model, device=self._device)
+            if self._embeddings is None
+            else self._embeddings
+        )
+        if len(dataset._selection) != len(embeddings):
+            raise ValueError(
+                "Size of embeddings do not match the size of the selection: "
+                + f"embeddings={len(embeddings)}, selection={len(dataset._selection)}"
+            )
+        emb = self._to_normalized_ndarray(embeddings, dataset._selection)
+        ref = None if self._reference is None else self._to_normalized_ndarray(self._reference)
+        # Sort indices
+        dataset._selection = self._sorter._sort(emb, ref).tolist()
dataeval/utils/data/selections/_shuffle.py
CHANGED
@@ -2,10 +2,14 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any
+from typing import Any, Sequence
 
 import numpy as np
+from numpy.random import BitGenerator, Generator, SeedSequence
+from numpy.typing import NDArray
 
+from dataeval.typing import Array, ArrayLike
+from dataeval.utils._array import as_numpy
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
 
@@ -15,14 +19,19 @@ class Shuffle(Selection[Any]):
 
     Parameters
     ----------
-    seed
+    seed : int, ArrayLike, SeedSequence, BitGenerator, Generator or None, default None
        Seed for the random number generator.
+
+    See Also
+    --------
+    `NumPy Random Generator <https://numpy.org/doc/stable/reference/random/generator.html>`_
     """
 
+    seed: int | NDArray[Any] | SeedSequence | BitGenerator | Generator | None
    stage = SelectionStage.ORDER
 
-    def __init__(self, seed: int):
-        self.seed = seed
+    def __init__(self, seed: int | ArrayLike | SeedSequence | BitGenerator | Generator | None = None):
+        self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed
 
    def __call__(self, dataset: Select[Any]) -> None:
        rng = np.random.default_rng(self.seed)
dataeval/utils/torch/_gmm.py
CHANGED
@@ -16,6 +16,8 @@ from typing import TypeVar
 import numpy as np
 import torch
 
+from dataeval.config import EPSILON
+
 TGMMData = TypeVar("TGMMData")
 
 
@@ -74,8 +76,7 @@ def gmm_params(z: torch.Tensor, gamma: torch.Tensor) -> GaussianMixtureModelParams:
 
     # cholesky decomposition of covariance and determinant derivation
     D = cov.shape[1]
-
-    L = torch.linalg.cholesky(cov + torch.eye(D) * eps)  # K x D x D
+    L = torch.linalg.cholesky(cov + torch.eye(D) * EPSILON)  # K x D x D
     log_det_cov = 2.0 * torch.sum(torch.log(torch.diagonal(L, dim1=-2, dim2=-1)), 1)  # K
 
     return GaussianMixtureModelParams(phi, mu, cov, L, log_det_cov)
dataeval/utils/torch/_internal.py
CHANGED
@@ -11,13 +11,13 @@ from numpy.typing import NDArray
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 
 
 def predict_batch(
     x: NDArray[Any] | torch.Tensor,
     model: Callable | torch.nn.Module | torch.nn.Sequential,
-    device:
+    device: DeviceLike | None = None,
     batch_size: int = int(1e10),
     preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     dtype: type[np.generic] | torch.dtype = np.float32,
@@ -31,9 +31,9 @@ predict_batch(
         Batch of instances.
     model : Callable | nn.Module | nn.Sequential
         PyTorch model.
-    device :
-
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
     batch_size : int, default 1e10
         Batch size used during prediction.
     preprocess_fn : Callable | None, default None
dataeval/utils/torch/trainer.py
CHANGED
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+from dataeval.config import DeviceLike, get_device
+
 __all__ = ["AETrainer"]
 
 from typing import Any
@@ -25,9 +27,9 @@ class AETrainer:
     ----------
     model : nn.Module
         The model to be trained.
-    device :
-        The hardware device to use
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
     batch_size : int, default 8
         The number of images to process in a batch.
     """
@@ -35,13 +37,11 @@
     def __init__(
         self,
         model: nn.Module,
-        device:
+        device: DeviceLike | None = None,
         batch_size: int = 8,
     ):
-
-
-        self.device: torch.device = torch.device(device)
-        self.model: nn.Module = model.to(device)
+        self.device: torch.device = get_device(device)
+        self.model: nn.Module = model.to(self.device)
         self.batch_size = batch_size
 
     def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
{dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.82.0
+Version: 0.83.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
@@ -50,9 +50,9 @@ and reference material, please visit our documentation on
 
 <!-- start tagline -->
 
-DataEval
-reliable AI models and
-deployed models.
+DataEval analyzes datasets and models to give users the ability to train and
+test performant, unbiased, and reliable AI models and monitor data for
+impactful shifts to deployed models.
 
 <!-- end tagline -->
 
dataeval-0.83.0.dist-info/RECORD
ADDED
@@ -0,0 +1,105 @@
+dataeval/__init__.py,sha256=uL-JSd_dKVJpGx4H8f6aOiQVpli46zeTLFqjb4Pa69c,1636
+dataeval/_log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
+dataeval/config.py,sha256=oQ0XQsgIF4_z4n1j0Di6B-JCRUFzzPgJgpQUm3ZlYhs,3539
+dataeval/detectors/__init__.py,sha256=3Sg-XWlwr75zEEH3hZKA4nWMtGvaRlnfzTWvZG_Ak6U,189
+dataeval/detectors/drift/__init__.py,sha256=6is_XBtG1d-vUbhHvqXGOdnAwxJ7NA5yRfURn7pCeIw,651
+dataeval/detectors/drift/_base.py,sha256=mJdKvyROgWvz-p1VlAIJqUI6BAj9ss8riUvR5An5wIw,13459
+dataeval/detectors/drift/_cvm.py,sha256=H2w-I0eMD7yP-CSmpdodeJ0-TYznJT7w_H7JuobESow,3859
+dataeval/detectors/drift/_ks.py,sha256=-5k3RBPA3kadX7oD14Wc52rAqQf1udwFeW7Qf3Sv4Tw,4058
+dataeval/detectors/drift/_mmd.py,sha256=NEXowx9UHIvmEKS8sqssw6PMLJMh0BZPhRNX1hYlkz4,7239
+dataeval/detectors/drift/_torch.py,sha256=VrFCyTaRrUslFPy_mYZ4UL70LZ8faH4eHwLurZ9qqNE,7628
+dataeval/detectors/drift/_uncertainty.py,sha256=O5h6_bJbeQEE660SLLP8k-EHqImmKegIcxzcnUKI7X4,5714
+dataeval/detectors/drift/updates.py,sha256=Btu2iaZW7fbO59G1w5v3ykFot0YPzy2U6VjF0d440VE,2195
+dataeval/detectors/linters/__init__.py,sha256=xn2zPwUcmsuf-Jd9uw6AVI11C9z1b1Y9fYtuFnXenZ0,404
+dataeval/detectors/linters/duplicates.py,sha256=tcxniL8rRZkDdQqfuS502UmfKxS3a7iRA22Dtt_vQIk,4935
+dataeval/detectors/linters/outliers.py,sha256=Hln2dPQZjF_uV2QYptA_o6ZF3ugyCImVT-XLDB2-q3A,9042
+dataeval/detectors/ood/__init__.py,sha256=juCYBDs7CQEAtMhnEpPqF6uTrOIH9kTBSuQ_GRw6a8o,283
+dataeval/detectors/ood/ae.py,sha256=YQfhB1ShQLjM1V4uCz9Oo2tCZpOfAZ_-SBCAl4Ac67Y,2921
+dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
+dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
+dataeval/detectors/ood/vae.py,sha256=Fcq0-WbLhzYCgYOAJPBklHm7yuXmFJuEpBkhgwM5kiA,2291
+dataeval/metadata/__init__.py,sha256=XDDmJbOZBNM6pL0r6Nbu6oMRoyAh22IDkPYGndNlkZU,316
+dataeval/metadata/_distance.py,sha256=xsXMMg1pJkHcEZ-KIlqv9YOGYVID3ELjt3-fr1QVnOs,4082
+dataeval/metadata/_ood.py,sha256=HbS5MusWl62hjixUAd-xaaT0KXkYY1M-MlnUaAI_-8M,12751
+dataeval/metadata/_utils.py,sha256=r8qBJT83RblobD5W5zyTVi6vYi51Dwkqswizdbzss-M,1169
+dataeval/metrics/__init__.py,sha256=8VC8q3HuJN3o_WN51Ae2_wXznl3RMXIvA5GYVcy7vr8,225
+dataeval/metrics/bias/__init__.py,sha256=1yTLmgiu1kwT_7ZWcjOUbj8R0NJ0DjGoCuWdA0_T7kc,683
+dataeval/metrics/bias/_balance.py,sha256=UnUgbPk2ybFfS5qxv8e_uim7RxamWj0UQP71x3omGs0,6158
+dataeval/metrics/bias/_coverage.py,sha256=PeUoOiaghUEdn6Ov8z2-am7-fnBVIPcFbJK7Ty5JObA,3647
+dataeval/metrics/bias/_diversity.py,sha256=U_l4oYjH39rON2Io0BdCIwJxxob0cKTW8bZNufG0CWs,5820
+dataeval/metrics/bias/_parity.py,sha256=8JRZv4wLpxN9zTvMDlcpKgz-2nO-9eVjqccODcf2nbw,11535
+dataeval/metrics/estimators/__init__.py,sha256=Pnds8uIyAovt2fKqZjiHCIP_kVoBWlVllekYuK5UmmU,568
+dataeval/metrics/estimators/_ber.py,sha256=C30E5LiGGTAfo31zWFYDptDg0R7CTJGJ-a60YgzSkYY,5382
+dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1RxKmoKT5MOgt0,1434
+dataeval/metrics/estimators/_divergence.py,sha256=QDWl1lyAYoO9D3Ho7qOHSk6ud8Gi2MGuXEsYwO1HxvA,4043
+dataeval/metrics/estimators/_uap.py,sha256=BULEBbJ9BQ1IcTeZf0x7iI60QHAWCccBOM97FIu9VXA,1928
+dataeval/metrics/stats/__init__.py,sha256=6tA_9nbbM5ObJ6cds8Y1VBtTQiTOxrpGQSFLu_lWGGA,1098
+dataeval/metrics/stats/_base.py,sha256=rn0CrRCvVh3QLDEi_JlOFVUoQ-xtclnOoHt_o1E26J4,10656
+dataeval/metrics/stats/_boxratiostats.py,sha256=8Kd2FTZ5PLNYZfdAjU_R385gb0Z16JY0L9H_d5ZhgQs,6341
+dataeval/metrics/stats/_dimensionstats.py,sha256=h2wCLn4UuW7-GV6tM5E1SqSeGa_-4ie9oaEXpSC7EKI,2690
+dataeval/metrics/stats/_hashstats.py,sha256=yD6cXMvOo10-xtwUr7ftBRbCqMhReNfQJMInEWV_8Mk,4757
+dataeval/metrics/stats/_imagestats.py,sha256=hyjijPXAfUIJ1lwWiIyYK9VSLiq7Vg2-YhJ5Q8s1rkY,2979
+dataeval/metrics/stats/_labelstats.py,sha256=PtGyqj4RHw0cyLAWAR9FzZGqgA81AtxLGHZiuMAL2h0,4100
+dataeval/metrics/stats/_pixelstats.py,sha256=Q0-ldG-znDYBP_qTqm6S4qYm0ZV5FTTHf8MlyGHSYEc,3235
+dataeval/metrics/stats/_visualstats.py,sha256=ZxBDTerZ8ixibY2pGl7mwwcIz3DWl-k_Jb4YwBjHLNw,3686
+dataeval/outputs/__init__.py,sha256=uxTAr1Kn0QNwC7zn1U_5WBAgwZxupM3JGgD25DyO6yI,1655
+dataeval/outputs/_base.py,sha256=aZFbgybnZSQ3ws7QYRLTbDFqUfBFRVtIwX2LZfeGFUA,5703
+dataeval/outputs/_bias.py,sha256=O5RHbTUJDwkwJfz2-YoOfRb4eDl5Tg1UFVtvs025wfA,12173
+dataeval/outputs/_drift.py,sha256=gOiu2C-ERTWiRqlP0auMYxPBGdm9HecWPqWfg7I4tZg,2015
+dataeval/outputs/_estimators.py,sha256=a2oAIxxEDZ9WLGfMWH8KD-BVUS_SnULRPR-iI9hFPoQ,3047
+dataeval/outputs/_linters.py,sha256=YOdjrfm8ypdRrqYOaPM9nc6wVJI3-ita3Haj7LHDNaw,6416
+dataeval/outputs/_metadata.py,sha256=ffZgpX8KWURPHXpOWjbvJ2KRqWQkS2nWuIjKUzoHhMI,1710
+dataeval/outputs/_ood.py,sha256=suLKVXULGtXH0rq9eXHI1d3d2jhGmItJtz4QiQd47A4,1718
+dataeval/outputs/_stats.py,sha256=PhRdyWWZxewzenFx0MxK9y9ZLE2MnMA-a4-JeSJ_Bs8,13180
+dataeval/outputs/_utils.py,sha256=HHlGC7sk416m_3Bgn075Qdblz_aPup_UOafJpB0RuXY,893
+dataeval/outputs/_workflows.py,sha256=MkRD6ubI4NCBXb9v3kjXy64cUGs3G-JKkBdOpRD9XVE,10750
+dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+dataeval/typing.py,sha256=YQ1KteeK1zf2mcWwngWwQP8EC3pI4WsvAzp_x179b4g,6568
+dataeval/utils/__init__.py,sha256=T8F8zJh4ZAeu0wDzfpld92I2zJg9mWBmkGCHrDPU7gk,264
+dataeval/utils/_array.py,sha256=fc04sYShIdsRS4qtG1UCnlGGk-yVRxlOHTNAmW7NpDY,4990
+dataeval/utils/_bin.py,sha256=nylthmsC3vzLHLhlUMACvZs--h7xvAh9Pt75InaQJW8,7322
+dataeval/utils/_clusterer.py,sha256=fw5x-2QN0TIbiodDKHZxRgxKHINedpPcOklzce0Rbjg,5436
+dataeval/utils/_fast_mst.py,sha256=4_7ykVihCL5jWtxcGnrecIsDQo65kUml9SZ1JxgBZYY,7172
+dataeval/utils/_image.py,sha256=capzF_X5H0jy0PmTP3Hf52GFgLqrnfU6gS4tiwck9jo,1939
+dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
+dataeval/utils/_mst.py,sha256=f0vXytTUjlOS6AyL7c6PkXmaHuuGUK-vMLpq-5xMgxk,2183
+dataeval/utils/_plot.py,sha256=mTRQNbJsA42QMiOwZbJaH8sNYgP996QFDEGVVE9HSgY,7076
+dataeval/utils/data/__init__.py,sha256=vldQ2ZXl8gnI3s4vAGqUUVi6dc_R58F3JMSpbCOyFRI,820
+dataeval/utils/data/_dataset.py,sha256=tjZUJnxj9IY71GKqdKltrwufkn0EC0S3a6ylrW5Bc2s,7756
+dataeval/utils/data/_embeddings.py,sha256=fKGFJXhb4ajnBE3jrKxIvBAhBQ6HpcYYkpO_sAk3jTE,3669
+dataeval/utils/data/_images.py,sha256=pv_vvpH8hWxPgLvjeVC2mZiyZivZFNLARNIOXam5ceY,1984
+dataeval/utils/data/_metadata.py,sha256=VqeePp7NtoFFWzmIhH4fn-cjrnATpgzgzs-d73cnBXM,14370
+dataeval/utils/data/_selection.py,sha256=nlslafwAfoZ5d5K_v9bIIvij-UP0NcalKqH4Nw7A-S4,4553
+dataeval/utils/data/_split.py,sha256=YdsqTRjKbdSfg8w0f4XgX7j0uOSdtfzvvyObAzyqgI0,18433
+dataeval/utils/data/_targets.py,sha256=ws5d9wRiDkIuOV7GSAKNxzgSm6AWTgb0BFroQK5nAmM,3057
+dataeval/utils/data/collate.py,sha256=Z5nmBnWV_IoJzMp_tj8RCKjMJA9sSCY_zZITqISGixc,3865
+dataeval/utils/data/datasets/__init__.py,sha256=jBrswiERrvBx4pJQJZIq_B5UE-Wy8a2_SBfM2crG8R8,511
+dataeval/utils/data/datasets/_base.py,sha256=CZ-hb-yWPLdnTQ3pURJMcityQ42ZNYj_Lbb1P5Junn4,8793
+dataeval/utils/data/datasets/_cifar10.py,sha256=I6HKksE2escos1aTdiZJObtiVXChBlez5BDa0eBfJ_Y,5449
+dataeval/utils/data/datasets/_fileio.py,sha256=SixIk5nIlIwJdX9zjNXS10vHA3hL8aaYbqHsDg1xSpY,6447
+dataeval/utils/data/datasets/_milco.py,sha256=ScBe7Ux-J9Kxs33jeKffhWKeSb8GCrWznTyEUt95Vt4,6369
+dataeval/utils/data/datasets/_mixin.py,sha256=FJgZP_cpJkgAHA3j3ai_j3Wt7aFSEjIMVmt9NpvVXzg,1757
+dataeval/utils/data/datasets/_mnist.py,sha256=iWWI9mq6TbZm7eTL9btzqjCNMhgXrLHQeMKENr7USsk,7988
+dataeval/utils/data/datasets/_ships.py,sha256=p3fScYLW2f1wUEPOroCX5nOFti0vMOSjeYltj6ox53U,4777
+dataeval/utils/data/datasets/_types.py,sha256=iSKyHXRlGuomXs0FHK6md8lXLQrQQ4fxgVOwr4o81bo,1089
+dataeval/utils/data/datasets/_voc.py,sha256=4poEer_G_mUBcz6eAro0Tc29CjdgjEAlms0Eu0tLBzE,14842
+dataeval/utils/data/selections/__init__.py,sha256=k86OpqGPkjT1MrOir5fOZ3AIq5UR81Az9ek7l1-GdIM,565
+dataeval/utils/data/selections/_classfilter.py,sha256=opSF8CGv4x1hUMe-GTQOu3UwJK80DzT0nJOV0l2uaW4,2404
+dataeval/utils/data/selections/_indices.py,sha256=QdLgXN7GABCvGPYe28PV1RAc_RSP_nZOyCvEpKRBdWg,636
+dataeval/utils/data/selections/_limit.py,sha256=ECvHRsp7OF4LZw2tE4sGqqJ085kjC-hd2c7QDMfvXr8,518
+dataeval/utils/data/selections/_prioritize.py,sha256=EAA4_uFVV7MmemhhufGmP7eunnbtyTc-TzgcnvRK5OE,11333
+dataeval/utils/data/selections/_reverse.py,sha256=6SWpELC9Wgx-kPqzhDrPNn4NKU6FqDJveLrxV4D2Ypk,374
+dataeval/utils/data/selections/_shuffle.py,sha256=kY3xJvVbBArdrJu_u6mXmxk1HdNmmDE4w7MmxbevUmU,1178
+dataeval/utils/metadata.py,sha256=X8Hu4LdCzAaE9uk1hI4BflmFve_VOQCqK9lXq0sk9ow,14196
+dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
+dataeval/utils/torch/_blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
+dataeval/utils/torch/_gmm.py,sha256=XM68GNEP97EjaB1U49-ZXRb81d0CEFnPS910alrcB3g,3740
+dataeval/utils/torch/_internal.py,sha256=23DCnF7C7N3tZgZUpT2nyH7mMb8Pi4GcnQyjK0BKHpg,5735
+dataeval/utils/torch/models.py,sha256=hmroEs6C6jQ5tAoZa71RFeIvXLxfXrTJSFH_jG2LGQU,9749
+dataeval/utils/torch/trainer.py,sha256=iUotX4OdirH8-ZtjdpU8gbJavkYW9YY9qpA2mAlFy1Y,5520
+dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
+dataeval/workflows/sufficiency.py,sha256=mjKmfRrAjShLUFIARv5o8yT5fnFvDsS5Qu6ujIPUgQg,8497
+dataeval-0.83.0.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
+dataeval-0.83.0.dist-info/METADATA,sha256=lVRLNQcl2DYQDo7GHpFv_z133aD5hn-uOCkXXltGK5s,5320
+dataeval-0.83.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
+dataeval-0.83.0.dist-info/RECORD,,
dataeval/detectors/ood/metadata_ood_mi.py
DELETED
@@ -1,93 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-import numbers
-import warnings
-from typing import Any
-
-import numpy as np
-from numpy.typing import NDArray
-from sklearn.feature_selection import mutual_info_classif
-
-# NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
-# which is what many library functions return, multiply it by NATS2BITS to get it in bits.
-NATS2BITS = 1.442695
-
-
-def get_metadata_ood_mi(
-    metadata: dict[str, list[Any] | NDArray[Any]],
-    is_ood: NDArray[np.bool_],
-    discrete_features: str | bool | NDArray[np.bool_] = False,
-    random_state: int | None = None,
-) -> dict[str, float]:
-    """Computes mutual information between a set of metadata features and an out-of-distribution flag.
-
-    Given a metadata dictionary `metadata` (where each key maps to one scalar metadata feature per example), and a
-    corresponding boolean flag `is_ood` indicating whether each example falls out-of-distribution (OOD) relative to a
-    reference dataset, this function finds the strength of association between each metadata feature and `is_ood` by
-    computing their mutual information. Metadata features may be either discrete or continuous; set the
-    `discrete_features` keyword to a bool array set to True for each feature that is discrete, or pass one bool to apply
-    to all features. Returns a dict indicating the strength of association between each individual feature and the OOD
-    flag, measured in bits.
-
-    Parameters
-    ----------
-    metadata : dict[str, list[Any] | NDArray[Any]]
-        A set of arrays of values, indexed by metadata feature names, with one value per data example per feature.
-    is_ood : NDArray[np.bool_]
-        A boolean array, with one value per example, that indicates which examples are OOD.
-    discrete_features : str | bool | NDArray[np.bool_]
-        Either a boolean array or a single boolean value, indicate which features take on discrete values.
-    random_state : int, optional - default None
-        Determines random number generation for small noise added to continuous variables. Set to a value for
-        reproducible results.
-
-    Returns
-    -------
-    dict[str, float]
-        A dictionary with keys corresponding to metadata feature names, and values indicating the strength of
-        association between each named feature and the OOD flag, as mutual information measured in bits.
-
-    Examples
-    --------
-    Imagine we have 3 data examples, and that the corresponding metadata contains 2 features called time and altitude.
-
-    >>> metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
-    >>> is_ood = metadata["altitude"] > 100
-    >>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False, random_state=0)
-    {'time': 0.9359596758173668, 'altitude': 0.9407686591507002}
-    """
-    numerical_keys = [k for k, v in metadata.items() if all(isinstance(vi, numbers.Number) for vi in v)]
-    if len(numerical_keys) < len(metadata):
-        warnings.warn(
-            f"Processing {numerical_keys}, others are non-numerical and will be skipped.",
-            UserWarning,
-        )
-
-    md_lengths = {len(np.atleast_1d(v)) for v in metadata.values()}
-    if len(md_lengths) > 1:
-        raise ValueError(f"Metadata features have differing sizes: {md_lengths}")
-
-    if len(is_ood) != (mdl := md_lengths.pop()):
-        raise ValueError(
-            f"OOD flag and metadata features need to be same size, but are different sizes: {len(is_ood)} and {mdl}."
-        )
-
-    X = np.array([metadata[k] for k in numerical_keys]).T
-
-    X0, dX = np.mean(X, axis=0), np.std(X, axis=0, ddof=1)
-    Xscl = (X - X0) / dX
-
-    mutual_info_values = (
-        mutual_info_classif(
-            Xscl,
-            is_ood,
-            discrete_features=discrete_features,  # type: ignore
-            random_state=random_state,
-        )
-        * NATS2BITS
-    )
-
-    mi_dict = {k: mutual_info_values[i] for i, k in enumerate(numerical_keys)}
-    return mi_dict