dataeval 0.70.0__py3-none-any.whl → 0.70.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +6 -6
- dataeval/_internal/datasets.py +235 -131
- dataeval/_internal/detectors/clusterer.py +2 -0
- dataeval/_internal/detectors/drift/base.py +2 -2
- dataeval/_internal/detectors/drift/mmd.py +1 -1
- dataeval/_internal/detectors/duplicates.py +2 -0
- dataeval/_internal/detectors/ood/ae.py +5 -3
- dataeval/_internal/detectors/ood/aegmm.py +6 -4
- dataeval/_internal/detectors/ood/base.py +12 -7
- dataeval/_internal/detectors/ood/llr.py +6 -4
- dataeval/_internal/detectors/ood/vae.py +5 -3
- dataeval/_internal/detectors/ood/vaegmm.py +6 -4
- dataeval/_internal/detectors/outliers.py +4 -2
- dataeval/_internal/metrics/balance.py +4 -2
- dataeval/_internal/metrics/ber.py +2 -0
- dataeval/_internal/metrics/coverage.py +4 -0
- dataeval/_internal/metrics/divergence.py +6 -2
- dataeval/_internal/metrics/diversity.py +8 -6
- dataeval/_internal/metrics/parity.py +8 -6
- dataeval/_internal/metrics/stats/base.py +2 -2
- dataeval/_internal/metrics/stats/datasetstats.py +2 -0
- dataeval/_internal/metrics/stats/dimensionstats.py +2 -0
- dataeval/_internal/metrics/stats/hashstats.py +2 -0
- dataeval/_internal/metrics/stats/labelstats.py +1 -1
- dataeval/_internal/metrics/stats/pixelstats.py +4 -2
- dataeval/_internal/metrics/stats/visualstats.py +4 -2
- dataeval/_internal/metrics/uap.py +6 -2
- dataeval/_internal/metrics/utils.py +2 -2
- dataeval/_internal/models/pytorch/autoencoder.py +5 -5
- dataeval/_internal/models/tensorflow/pixelcnn.py +1 -4
- dataeval/_internal/utils.py +11 -16
- dataeval/_internal/workflows/sufficiency.py +44 -33
- dataeval/detectors/__init__.py +4 -0
- dataeval/detectors/drift/__init__.py +8 -3
- dataeval/detectors/drift/kernels/__init__.py +4 -0
- dataeval/detectors/drift/updates/__init__.py +4 -0
- dataeval/detectors/linters/__init__.py +15 -4
- dataeval/detectors/ood/__init__.py +14 -2
- dataeval/metrics/__init__.py +5 -0
- dataeval/metrics/bias/__init__.py +13 -4
- dataeval/metrics/estimators/__init__.py +8 -8
- dataeval/metrics/stats/__init__.py +17 -6
- dataeval/utils/__init__.py +16 -3
- dataeval/utils/tensorflow/__init__.py +11 -0
- dataeval/utils/torch/__init__.py +12 -0
- dataeval/utils/torch/datasets/__init__.py +7 -0
- dataeval/workflows/__init__.py +4 -0
- {dataeval-0.70.0.dist-info → dataeval-0.70.1.dist-info}/METADATA +10 -2
- dataeval-0.70.1.dist-info/RECORD +80 -0
- dataeval/tensorflow/__init__.py +0 -3
- dataeval/torch/__init__.py +0 -3
- dataeval-0.70.0.dist-info/RECORD +0 -79
- /dataeval/{tensorflow → utils/tensorflow}/loss/__init__.py +0 -0
- /dataeval/{tensorflow → utils/tensorflow}/models/__init__.py +0 -0
- /dataeval/{tensorflow → utils/tensorflow}/recon/__init__.py +0 -0
- /dataeval/{torch → utils/torch}/models/__init__.py +0 -0
- /dataeval/{torch → utils/torch}/trainer/__init__.py +0 -0
- {dataeval-0.70.0.dist-info → dataeval-0.70.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.70.0.dist-info → dataeval-0.70.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "0.70.
|
1
|
+
__version__ = "0.70.1"
|
2
2
|
|
3
3
|
from importlib.util import find_spec
|
4
4
|
|
@@ -12,11 +12,11 @@ from . import detectors, metrics # noqa: E402
|
|
12
12
|
__all__ = ["detectors", "metrics"]
|
13
13
|
|
14
14
|
if _IS_TORCH_AVAILABLE: # pragma: no cover
|
15
|
-
from . import
|
15
|
+
from . import workflows
|
16
16
|
|
17
|
-
__all__ += ["
|
17
|
+
__all__ += ["workflows"]
|
18
18
|
|
19
|
-
if _IS_TENSORFLOW_AVAILABLE: # pragma: no cover
|
20
|
-
from . import
|
19
|
+
if _IS_TENSORFLOW_AVAILABLE or _IS_TORCH_AVAILABLE: # pragma: no cover
|
20
|
+
from . import utils
|
21
21
|
|
22
|
-
__all__ += ["
|
22
|
+
__all__ += ["utils"]
|
dataeval/_internal/datasets.py
CHANGED
@@ -4,18 +4,39 @@ import hashlib
|
|
4
4
|
import os
|
5
5
|
import zipfile
|
6
6
|
from pathlib import Path
|
7
|
-
from typing import Literal
|
8
|
-
from
|
9
|
-
from urllib.request import urlretrieve
|
7
|
+
from typing import Literal, TypeVar
|
8
|
+
from warnings import warn
|
10
9
|
|
11
10
|
import numpy as np
|
11
|
+
import requests
|
12
12
|
from numpy.typing import NDArray
|
13
13
|
from torch.utils.data import Dataset
|
14
14
|
from torchvision.datasets import CIFAR10, VOCDetection # noqa: F401
|
15
15
|
|
16
|
-
|
17
|
-
|
18
|
-
|
16
|
+
ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
17
|
+
TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
|
18
|
+
CorruptionStringMap = Literal[
|
19
|
+
"identity",
|
20
|
+
"shot_noise",
|
21
|
+
"impulse_noise",
|
22
|
+
"glass_blur",
|
23
|
+
"motion_blur",
|
24
|
+
"shear",
|
25
|
+
"scale",
|
26
|
+
"rotate",
|
27
|
+
"brightness",
|
28
|
+
"translate",
|
29
|
+
"stripe",
|
30
|
+
"fog",
|
31
|
+
"spatter",
|
32
|
+
"dotted_line",
|
33
|
+
"zigzag",
|
34
|
+
"canny_edges",
|
35
|
+
]
|
36
|
+
|
37
|
+
|
38
|
+
def _validate_file(fpath, file_md5, md5=False, chunk_size=65535):
|
39
|
+
hasher = hashlib.md5() if md5 else hashlib.sha256()
|
19
40
|
with open(fpath, "rb") as fpath_file:
|
20
41
|
while chunk := fpath_file.read(chunk_size):
|
21
42
|
hasher.update(chunk)
|
@@ -26,44 +47,74 @@ def _get_file(
|
|
26
47
|
root: str | Path,
|
27
48
|
fname: str,
|
28
49
|
origin: str,
|
29
|
-
|
50
|
+
file_hash: str | None = None,
|
51
|
+
verbose: bool = True,
|
52
|
+
md5: bool = False,
|
30
53
|
):
|
31
|
-
fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
|
32
54
|
fpath = os.path.join(root, fname)
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
if
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
else:
|
41
|
-
download = True
|
55
|
+
download = True
|
56
|
+
if os.path.exists(fpath) and file_hash is not None and _validate_file(fpath, file_hash, md5):
|
57
|
+
download = False
|
58
|
+
if verbose:
|
59
|
+
print("File already downloaded and verified.")
|
60
|
+
if md5:
|
61
|
+
print("Extracting zip file...")
|
42
62
|
|
43
63
|
if download:
|
44
64
|
try:
|
45
65
|
error_msg = "URL fetch failure on {}: {} -- {}"
|
46
66
|
try:
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
67
|
+
with requests.get(origin, stream=True, timeout=60) as r:
|
68
|
+
r.raise_for_status()
|
69
|
+
with open(fpath, "wb") as f:
|
70
|
+
for chunk in r.iter_content(chunk_size=8192):
|
71
|
+
if chunk:
|
72
|
+
f.write(chunk)
|
73
|
+
except requests.exceptions.HTTPError as e:
|
74
|
+
raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
|
75
|
+
except requests.exceptions.RequestException as e:
|
76
|
+
raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
|
52
77
|
except (Exception, KeyboardInterrupt):
|
53
78
|
if os.path.exists(fpath):
|
54
79
|
os.remove(fpath)
|
55
80
|
raise
|
56
81
|
|
57
|
-
if os.path.exists(fpath) and
|
82
|
+
if os.path.exists(fpath) and file_hash is not None and not _validate_file(fpath, file_hash, md5):
|
58
83
|
raise ValueError(
|
59
84
|
"Incomplete or corrupted file detected. "
|
60
|
-
f"The
|
61
|
-
f"of {
|
85
|
+
f"The file hash does not match the provided value "
|
86
|
+
f"of {file_hash}.",
|
62
87
|
)
|
88
|
+
|
63
89
|
return fpath
|
64
90
|
|
65
91
|
|
66
|
-
def
|
92
|
+
def check_exists(
|
93
|
+
folder: str | Path,
|
94
|
+
url: str,
|
95
|
+
root: str | Path,
|
96
|
+
fname: str,
|
97
|
+
file_hash: str,
|
98
|
+
download: bool = True,
|
99
|
+
verbose: bool = True,
|
100
|
+
md5: bool = False,
|
101
|
+
):
|
102
|
+
"""Determine if the dataset has already been downloaded."""
|
103
|
+
location = str(folder)
|
104
|
+
if not os.path.exists(folder):
|
105
|
+
if download:
|
106
|
+
location = download_dataset(url, root, fname, file_hash, verbose, md5)
|
107
|
+
else:
|
108
|
+
raise RuntimeError("Dataset not found. You can use download=True to download it")
|
109
|
+
else:
|
110
|
+
if verbose:
|
111
|
+
print("Files already downloaded and verified")
|
112
|
+
return location
|
113
|
+
|
114
|
+
|
115
|
+
def download_dataset(
|
116
|
+
url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
|
117
|
+
) -> str:
|
67
118
|
"""Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
|
68
119
|
https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
|
69
120
|
"""
|
@@ -71,21 +122,24 @@ def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
|
|
71
122
|
folder = os.path.join(root, name)
|
72
123
|
os.makedirs(folder, exist_ok=True)
|
73
124
|
|
74
|
-
|
75
|
-
|
125
|
+
fpath = _get_file(
|
126
|
+
folder,
|
76
127
|
fname,
|
77
128
|
origin=url + fname,
|
78
|
-
|
129
|
+
file_hash=file_hash,
|
130
|
+
verbose=verbose,
|
131
|
+
md5=md5,
|
79
132
|
)
|
80
|
-
|
81
|
-
|
133
|
+
if md5:
|
134
|
+
folder = extract_archive(fpath, root, remove_finished=True)
|
135
|
+
return folder
|
82
136
|
|
83
137
|
|
84
138
|
def extract_archive(
|
85
139
|
from_path: str | Path,
|
86
140
|
to_path: str | Path | None = None,
|
87
141
|
remove_finished: bool = False,
|
88
|
-
):
|
142
|
+
) -> str:
|
89
143
|
"""Extract an archive.
|
90
144
|
|
91
145
|
The archive type and a possible compression is automatically detected from the file name.
|
@@ -94,8 +148,11 @@ def extract_archive(
|
|
94
148
|
if not from_path.is_absolute():
|
95
149
|
from_path = from_path.resolve()
|
96
150
|
|
97
|
-
if to_path is None:
|
151
|
+
if to_path is None or not os.path.exists(to_path):
|
98
152
|
to_path = os.path.dirname(from_path)
|
153
|
+
to_path = Path(to_path)
|
154
|
+
if not to_path.is_absolute():
|
155
|
+
to_path = to_path.resolve()
|
99
156
|
|
100
157
|
# Extracting zip
|
101
158
|
with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
|
@@ -103,6 +160,13 @@ def extract_archive(
|
|
103
160
|
|
104
161
|
if remove_finished:
|
105
162
|
os.remove(from_path)
|
163
|
+
return str(to_path)
|
164
|
+
|
165
|
+
|
166
|
+
def subselect(arr: NDArray, count: int, from_back: bool = False):
|
167
|
+
if from_back:
|
168
|
+
return arr[-count:]
|
169
|
+
return arr[:count]
|
106
170
|
|
107
171
|
|
108
172
|
class MNIST(Dataset):
|
@@ -133,40 +197,42 @@ class MNIST(Dataset):
|
|
133
197
|
'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
|
134
198
|
'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
|
135
199
|
The desired corruption style or None.
|
200
|
+
classes : Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
201
|
+
| int | list[int] | list[Literal["zero", "one", "two", "three", "four", "five", "six", "seven",
|
202
|
+
"eight", "nine"]] | None, default None
|
203
|
+
Option to select specific classes from dataset.
|
204
|
+
balance : bool, default True
|
205
|
+
If True, returns equal number of samples for each class.
|
206
|
+
randomize : bool, default False
|
207
|
+
If True, shuffles the data prior to selection - uses a set seed for reproducibility.
|
208
|
+
slice_back : bool, default False
|
209
|
+
If True and size has a value greater than 0, then grabs selection starting at the last image.
|
210
|
+
verbose : bool, default True
|
211
|
+
If True, outputs print statements.
|
136
212
|
"""
|
137
213
|
|
138
|
-
mirror =
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
classes = [
|
143
|
-
"0 - zero",
|
144
|
-
"1 - one",
|
145
|
-
"2 - two",
|
146
|
-
"3 - three",
|
147
|
-
"4 - four",
|
148
|
-
"5 - five",
|
149
|
-
"6 - six",
|
150
|
-
"7 - seven",
|
151
|
-
"8 - eight",
|
152
|
-
"9 - nine",
|
214
|
+
mirror = [
|
215
|
+
"https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
|
216
|
+
"https://zenodo.org/record/3239543/files/",
|
153
217
|
]
|
154
218
|
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
@property
|
160
|
-
def test_labels(self):
|
161
|
-
return self.targets
|
162
|
-
|
163
|
-
@property
|
164
|
-
def train_data(self):
|
165
|
-
return self.data
|
219
|
+
resources = [
|
220
|
+
("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
|
221
|
+
("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
|
222
|
+
]
|
166
223
|
|
167
|
-
|
168
|
-
|
169
|
-
|
224
|
+
class_dict = {
|
225
|
+
"zero": 0,
|
226
|
+
"one": 1,
|
227
|
+
"two": 2,
|
228
|
+
"three": 3,
|
229
|
+
"four": 4,
|
230
|
+
"five": 5,
|
231
|
+
"six": 6,
|
232
|
+
"seven": 7,
|
233
|
+
"eight": 8,
|
234
|
+
"nine": 9,
|
235
|
+
}
|
170
236
|
|
171
237
|
def __init__(
|
172
238
|
self,
|
@@ -179,25 +245,12 @@ class MNIST(Dataset):
|
|
179
245
|
channels: Literal["channels_first", "channels_last"] | None = None,
|
180
246
|
flatten: bool = False,
|
181
247
|
normalize: tuple[float, float] | None = None,
|
182
|
-
corruption:
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
"shear",
|
189
|
-
"scale",
|
190
|
-
"rotate",
|
191
|
-
"brightness",
|
192
|
-
"translate",
|
193
|
-
"stripe",
|
194
|
-
"fog",
|
195
|
-
"spatter",
|
196
|
-
"dotted_line",
|
197
|
-
"zigzag",
|
198
|
-
"canny_edges",
|
199
|
-
]
|
200
|
-
| None = None,
|
248
|
+
corruption: CorruptionStringMap | None = None,
|
249
|
+
classes: TClassMap | None = None,
|
250
|
+
balance: bool = True,
|
251
|
+
randomize: bool = False,
|
252
|
+
slice_back: bool = False,
|
253
|
+
verbose: bool = True,
|
201
254
|
) -> None:
|
202
255
|
if isinstance(root, str):
|
203
256
|
root = os.path.expanduser(root)
|
@@ -209,64 +262,113 @@ class MNIST(Dataset):
|
|
209
262
|
self.channels = channels
|
210
263
|
self.flatten = flatten
|
211
264
|
self.normalize = normalize
|
212
|
-
|
213
|
-
if corruption is None:
|
214
|
-
corruption = "identity"
|
215
|
-
elif corruption == "identity":
|
216
|
-
print("Identity is not a corrupted dataset but the original MNIST dataset")
|
217
265
|
self.corruption = corruption
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
266
|
+
self.balance = balance
|
267
|
+
self.randomize = randomize
|
268
|
+
self.from_back = slice_back
|
269
|
+
self.verbose = verbose
|
270
|
+
|
271
|
+
self.class_set = []
|
272
|
+
if classes is not None:
|
273
|
+
if not isinstance(classes, list):
|
274
|
+
classes = [classes] # type: ignore
|
275
|
+
|
276
|
+
for val in classes: # type: ignore
|
277
|
+
if isinstance(val, int) and 0 <= val < 10:
|
278
|
+
self.class_set.append(val)
|
279
|
+
elif isinstance(val, str):
|
280
|
+
self.class_set.append(self.class_dict[val])
|
281
|
+
self.class_set = set(self.class_set)
|
282
|
+
|
283
|
+
if not self.class_set:
|
284
|
+
self.class_set = set(self.class_dict.values())
|
285
|
+
|
286
|
+
self.num_classes = len(self.class_set)
|
287
|
+
|
288
|
+
if self.corruption is None:
|
289
|
+
file_resource = self.resources[0]
|
290
|
+
mirror = self.mirror[0]
|
291
|
+
md5 = False
|
223
292
|
else:
|
224
|
-
|
293
|
+
if self.corruption == "identity" and verbose:
|
294
|
+
print("Identity is not a corrupted dataset but the original MNIST dataset.")
|
295
|
+
file_resource = self.resources[1]
|
296
|
+
mirror = self.mirror[1]
|
297
|
+
md5 = True
|
298
|
+
check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
|
225
299
|
|
226
300
|
self.data, self.targets = self._load_data()
|
227
301
|
|
302
|
+
self._augmentations()
|
303
|
+
|
228
304
|
def _load_data(self):
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
305
|
+
if self.corruption is None:
|
306
|
+
image_file = self.resources[0][0]
|
307
|
+
data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
|
308
|
+
else:
|
309
|
+
image_file = f"{'train' if self.train else 'test'}_images.npy"
|
310
|
+
data = self._read_corrupt_file(os.path.join(self.mnist_folder, image_file))
|
311
|
+
data = data.squeeze()
|
312
|
+
|
313
|
+
label_file = f"{'train' if self.train else 'test'}_labels.npy"
|
314
|
+
targets = self._read_corrupt_file(os.path.join(self.mnist_folder, label_file))
|
315
|
+
|
316
|
+
return data, targets
|
317
|
+
|
318
|
+
def _augmentations(self):
|
319
|
+
if self.size > self.targets.shape[0] and self.verbose:
|
320
|
+
warn(
|
321
|
+
f"Asked for more samples, {self.size}, than the raw dataset contains, {self.targets.shape[0]}. "
|
322
|
+
"Adjusting down to raw dataset size."
|
323
|
+
)
|
324
|
+
self.size = -1
|
325
|
+
|
326
|
+
if self.randomize:
|
327
|
+
rdm_seed = np.random.default_rng(2023)
|
328
|
+
shuffled_indices = rdm_seed.permutation(self.data.shape[0])
|
329
|
+
self.data = self.data[shuffled_indices]
|
330
|
+
self.targets = self.targets[shuffled_indices]
|
331
|
+
|
332
|
+
if not self.balance and self.num_classes > self.size:
|
333
|
+
if self.size > 0:
|
334
|
+
self.data = subselect(self.data, self.size, self.from_back)
|
335
|
+
self.targets = subselect(self.targets, self.size, self.from_back)
|
336
|
+
else:
|
337
|
+
label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
|
338
|
+
min_label_count = min(len(indices) for indices in label_dict.values())
|
339
|
+
|
340
|
+
self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
|
341
|
+
|
342
|
+
if self.per_class_count > min_label_count:
|
343
|
+
self.per_class_count = min_label_count
|
344
|
+
if not self.balance and self.verbose:
|
345
|
+
warn(
|
346
|
+
f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
|
347
|
+
f"will be returned, instead of the desired {self.size}."
|
348
|
+
)
|
349
|
+
|
350
|
+
all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
|
351
|
+
for i, label in enumerate(self.class_set):
|
352
|
+
all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
|
353
|
+
self.data = np.vstack(self.data[all_indices.T]) # type: ignore
|
354
|
+
self.targets = np.hstack(self.targets[all_indices.T]) # type: ignore
|
251
355
|
|
252
356
|
if self.unit_interval:
|
253
|
-
data = data / 255
|
357
|
+
self.data = self.data / 255
|
254
358
|
|
255
359
|
if self.normalize:
|
256
|
-
data = (data - self.normalize[0]) / self.normalize[1]
|
360
|
+
self.data = (self.data - self.normalize[0]) / self.normalize[1]
|
257
361
|
|
258
362
|
if self.dtype:
|
259
|
-
data = data.astype(self.dtype)
|
363
|
+
self.data = self.data.astype(self.dtype)
|
260
364
|
|
261
365
|
if self.channels == "channels_first":
|
262
|
-
data =
|
263
|
-
elif self.channels
|
264
|
-
data = data[:, :, :,
|
366
|
+
self.data = self.data[:, np.newaxis, :, :]
|
367
|
+
elif self.channels == "channels_last":
|
368
|
+
self.data = self.data[:, :, :, np.newaxis]
|
265
369
|
|
266
370
|
if self.flatten and self.channels is None:
|
267
|
-
data = data.reshape(data.shape[0], -1)
|
268
|
-
|
269
|
-
return data, targets
|
371
|
+
self.data = self.data.reshape(self.data.shape[0], -1)
|
270
372
|
|
271
373
|
def __getitem__(self, index: int) -> tuple[NDArray, int]:
|
272
374
|
"""
|
@@ -285,16 +387,18 @@ class MNIST(Dataset):
|
|
285
387
|
|
286
388
|
@property
|
287
389
|
def mnist_folder(self) -> str:
|
390
|
+
if self.corruption is None:
|
391
|
+
return os.path.join(self.root, "mnist")
|
288
392
|
return os.path.join(self.root, "mnist_c", self.corruption)
|
289
393
|
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
return x
|
394
|
+
def _read_normal_file(self, path: str) -> tuple[NDArray, NDArray]:
|
395
|
+
with np.load(path, allow_pickle=True) as f:
|
396
|
+
if self.train:
|
397
|
+
x, y = f["x_train"], f["y_train"]
|
398
|
+
else:
|
399
|
+
x, y = f["x_test"], f["y_test"]
|
400
|
+
return x, y
|
297
401
|
|
298
|
-
def
|
402
|
+
def _read_corrupt_file(self, path: str) -> NDArray:
|
299
403
|
x = np.load(path, allow_pickle=False)
|
300
404
|
return x
|
@@ -23,7 +23,7 @@ from dataeval._internal.output import OutputMetadata, set_metadata
|
|
23
23
|
@dataclass(frozen=True)
|
24
24
|
class DriftBaseOutput(OutputMetadata):
|
25
25
|
"""
|
26
|
-
|
26
|
+
Base output class for Drift detector classes
|
27
27
|
|
28
28
|
Attributes
|
29
29
|
----------
|
@@ -42,7 +42,7 @@ class DriftBaseOutput(OutputMetadata):
|
|
42
42
|
@dataclass(frozen=True)
|
43
43
|
class DriftOutput(DriftBaseOutput):
|
44
44
|
"""
|
45
|
-
Output class for DriftCVM and
|
45
|
+
Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors
|
46
46
|
|
47
47
|
Attributes
|
48
48
|
----------
|
@@ -17,6 +17,8 @@ TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateG
|
|
17
17
|
@dataclass(frozen=True)
|
18
18
|
class DuplicatesOutput(Generic[TIndexCollection], OutputMetadata):
|
19
19
|
"""
|
20
|
+
Output class for :class:`Duplicates` lint detector
|
21
|
+
|
20
22
|
Attributes
|
21
23
|
----------
|
22
24
|
exact : list[list[int] | dict[int, list[int]]]
|
@@ -15,10 +15,11 @@ import numpy as np
|
|
15
15
|
import tensorflow as tf
|
16
16
|
from numpy.typing import ArrayLike
|
17
17
|
|
18
|
-
from dataeval._internal.detectors.ood.base import OODBase,
|
18
|
+
from dataeval._internal.detectors.ood.base import OODBase, OODScoreOutput
|
19
19
|
from dataeval._internal.interop import as_numpy
|
20
20
|
from dataeval._internal.models.tensorflow.autoencoder import AE
|
21
21
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
22
|
+
from dataeval._internal.output import set_metadata
|
22
23
|
|
23
24
|
|
24
25
|
class OOD_AE(OODBase):
|
@@ -48,7 +49,8 @@ class OOD_AE(OODBase):
|
|
48
49
|
loss_fn = keras.losses.MeanSquaredError()
|
49
50
|
super().fit(as_numpy(x_ref), threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
50
51
|
|
51
|
-
|
52
|
+
@set_metadata("dataeval.detectors")
|
53
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
52
54
|
self._validate(X := as_numpy(X))
|
53
55
|
|
54
56
|
# reconstruct instances
|
@@ -62,4 +64,4 @@ class OOD_AE(OODBase):
|
|
62
64
|
sorted_fscore_perc = sorted_fscore[:, -n_score_features:]
|
63
65
|
iscore = np.mean(sorted_fscore_perc, axis=1)
|
64
66
|
|
65
|
-
return
|
67
|
+
return OODScoreOutput(iscore, fscore)
|
@@ -14,12 +14,13 @@ import keras
|
|
14
14
|
import tensorflow as tf
|
15
15
|
from numpy.typing import ArrayLike
|
16
16
|
|
17
|
-
from dataeval._internal.detectors.ood.base import OODGMMBase,
|
17
|
+
from dataeval._internal.detectors.ood.base import OODGMMBase, OODScoreOutput
|
18
18
|
from dataeval._internal.interop import to_numpy
|
19
19
|
from dataeval._internal.models.tensorflow.autoencoder import AEGMM
|
20
20
|
from dataeval._internal.models.tensorflow.gmm import gmm_energy
|
21
21
|
from dataeval._internal.models.tensorflow.losses import LossGMM
|
22
22
|
from dataeval._internal.models.tensorflow.utils import predict_batch
|
23
|
+
from dataeval._internal.output import set_metadata
|
23
24
|
|
24
25
|
|
25
26
|
class OOD_AEGMM(OODGMMBase):
|
@@ -49,7 +50,8 @@ class OOD_AEGMM(OODGMMBase):
|
|
49
50
|
loss_fn = LossGMM()
|
50
51
|
super().fit(x_ref, threshold_perc, loss_fn, optimizer, epochs, batch_size, verbose)
|
51
52
|
|
52
|
-
|
53
|
+
@set_metadata("dataeval.detectors")
|
54
|
+
def score(self, X: ArrayLike, batch_size: int = int(1e10)) -> OODScoreOutput:
|
53
55
|
"""
|
54
56
|
Compute the out-of-distribution (OOD) score for a given dataset.
|
55
57
|
|
@@ -63,7 +65,7 @@ class OOD_AEGMM(OODGMMBase):
|
|
63
65
|
|
64
66
|
Returns
|
65
67
|
-------
|
66
|
-
|
68
|
+
OODScoreOutput
|
67
69
|
An object containing the instance-level OOD score.
|
68
70
|
|
69
71
|
Note
|
@@ -73,4 +75,4 @@ class OOD_AEGMM(OODGMMBase):
|
|
73
75
|
self._validate(X := to_numpy(X))
|
74
76
|
_, z, _ = predict_batch(X, self.model, batch_size=batch_size)
|
75
77
|
energy, _ = gmm_energy(z, self.gmm_params, return_mean=False)
|
76
|
-
return
|
78
|
+
return OODScoreOutput(energy.numpy()) # type: ignore
|