dataeval 0.72.0__py3-none-any.whl → 0.72.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +10 -11
- dataeval/{_internal/detectors → detectors}/drift/base.py +51 -102
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +9 -8
- dataeval/{_internal/detectors → detectors}/drift/ks.py +11 -10
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +33 -34
- dataeval/{_internal/detectors → detectors}/drift/torch.py +15 -13
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +12 -9
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +47 -45
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +20 -10
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +19 -26
- dataeval/detectors/ood/__init__.py +8 -16
- dataeval/{_internal/detectors → detectors}/ood/ae.py +9 -9
- dataeval/{_internal/detectors → detectors}/ood/aegmm.py +10 -30
- dataeval/{_internal/detectors → detectors}/ood/base.py +27 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +27 -23
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +11 -13
- dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
- dataeval/{_internal/interop.py → interop.py} +12 -7
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +70 -4
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +10 -8
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +54 -20
- dataeval/metrics/bias/metadata.py +275 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +21 -17
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +31 -28
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +15 -16
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +8 -6
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +66 -40
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +19 -15
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +19 -17
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +12 -10
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +8 -6
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +12 -11
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +14 -13
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -4
- dataeval/utils/image.py +71 -0
- dataeval/utils/shared.py +151 -0
- dataeval/utils/split_dataset.py +486 -0
- dataeval/utils/tensorflow/__init__.py +9 -7
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +64 -68
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +10 -9
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +18 -22
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +18 -18
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +49 -43
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +12 -141
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +42 -37
- {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/METADATA +7 -5
- dataeval-0.72.2.dist-info/RECORD +72 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -7
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.0.dist-info/RECORD +0 -80
- /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
- {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
(`…` marks line text that the diff view truncated.)

dataeval/{_internal/models/pytorch → utils/torch}/blocks.py
RENAMED
```diff
@@ -1,3 +1,7 @@
+from typing import Any
+
+__all__ = []
+
 import torch.nn as nn
 
 
@@ -8,21 +12,22 @@ class Conv(nn.Module):
 
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        k=1,
-        s=1,
-        p=0,
-        activation="relu",
-        norm="instance",
-    ):
+        in_channels: int,
+        out_channels: int,
+        k: int = 1,
+        s: int = 1,
+        p: int = 0,
+        activation: str = "relu",
+        norm: str = "instance",
+    ) -> None:
         super().__init__()
-        …
-        …
-        …
-        …
+        self.module: nn.Sequential = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p),
+            self.get_norm_func(norm=norm, out_channels=out_channels),
+            self.get_activation_func(activation=activation),
+        )
 
-    def get_norm_func(self, norm: str, out_channels) -> nn.Module:
+    def get_norm_func(self, norm: str, out_channels: int) -> nn.Module:
         if norm == "batch":
             return nn.BatchNorm2d(out_channels)
         if norm == "instance":
@@ -42,5 +47,5 @@ class Conv(nn.Module):
             return nn.Tanh()
         return nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: Any) -> Any:
         return self.module(x)
```
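The rename gives `Conv` a fully annotated signature. A minimal smoke-test sketch; the import path is assumed from the rename target `dataeval/utils/torch/blocks.py`, and whether `Conv` is intended for external use is an assumption:

```python
# Sketch: exercising the newly annotated Conv block (conv -> norm -> activation,
# per the nn.Sequential built in __init__). Import path assumed from the rename.
import torch

from dataeval.utils.torch.blocks import Conv

block = Conv(in_channels=3, out_channels=16, k=3, s=1, p=1)  # defaults: relu + instance norm
out = block(torch.randn(2, 3, 28, 28))
print(out.shape)  # torch.Size([2, 16, 28, 28])
```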
dataeval/{_internal → utils/torch}/datasets.py
RENAMED
```diff
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+__all__ = ["MNIST", "CIFAR10", "VOCDetection"]
+
 import hashlib
 import os
 import zipfile
@@ -11,7 +13,7 @@ import numpy as np
 import requests
 from numpy.typing import NDArray
 from torch.utils.data import Dataset
-from torchvision.datasets import CIFAR10, VOCDetection
+from torchvision.datasets import CIFAR10, VOCDetection
 
 ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
@@ -50,6 +52,7 @@ def _get_file(
     file_hash: str | None = None,
     verbose: bool = True,
     md5: bool = False,
+    timeout: int = 60,
 ):
     fpath = os.path.join(root, fname)
     download = True
@@ -64,16 +67,16 @@ def _get_file(
     try:
         error_msg = "URL fetch failure on {}: {} -- {}"
         try:
-            with requests.get(origin, stream=True, timeout=…
+            with requests.get(origin, stream=True, timeout=timeout) as r:
                 r.raise_for_status()
                 with open(fpath, "wb") as f:
                     for chunk in r.iter_content(chunk_size=8192):
                         if chunk:
                             f.write(chunk)
         except requests.exceptions.HTTPError as e:
-            raise …
+            raise RuntimeError(f"{error_msg.format(origin, e.response.status_code, e.response.reason)}") from e
         except requests.exceptions.RequestException as e:
-            raise …
+            raise ValueError(f"{error_msg.format(origin, 'Unknown error', str(e))}") from e
     except (Exception, KeyboardInterrupt):
         if os.path.exists(fpath):
             os.remove(fpath)
@@ -89,7 +92,7 @@ def _get_file(
     return fpath
 
 
-def check_exists(
+def _check_exists(
     folder: str | Path,
     url: str,
     root: str | Path,
@@ -103,7 +106,7 @@ def check_exists(
     location = str(folder)
     if not os.path.exists(folder):
         if download:
-            location = …
+            location = _download_dataset(url, root, fname, file_hash, verbose, md5)
         else:
             raise RuntimeError("Dataset not found. You can use download=True to download it")
     else:
@@ -112,7 +115,7 @@ def check_exists(
     return location
 
 
-def download_dataset(
+def _download_dataset(
     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
 ) -> str:
     """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
@@ -131,11 +134,11 @@ def download_dataset(
         md5=md5,
     )
     if md5:
-        folder = …
+        folder = _extract_archive(fpath, root, remove_finished=True)
     return folder
 
 
-def extract_archive(
+def _extract_archive(
     from_path: str | Path,
     to_path: str | Path | None = None,
     remove_finished: bool = False,
@@ -163,13 +166,13 @@ def extract_archive(
     return str(to_path)
 
 
-def …
+def _subselect(arr: NDArray, count: int, from_back: bool = False):
     if from_back:
         return arr[-count:]
     return arr[:count]
 
 
-class MNIST(Dataset):
+class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
     """MNIST Dataset and Corruptions.
 
     Args:
@@ -185,7 +188,7 @@ class MNIST(Dataset):
     unit_interval : bool, default False
         Shift the data values to the unit interval [0-1].
     dtype : type | None, default None
-        Change the …
+        Change the :term:`NumPy` dtype - data is loaded as np.uint8
     channels : Literal['channels_first' | 'channels_last'] | None, default None
         Location of channel axis if desired, default has no channels (N, 28, 28)
     flatten : bool, default False
@@ -211,17 +214,17 @@ class MNIST(Dataset):
         If True, outputs print statements.
     """
 
-    …
+    _mirrors: tuple[str, ...] = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
         "https://zenodo.org/record/3239543/files/",
-    …
+    )
 
-    …
+    _resources: tuple[tuple[str, str], ...] = (
         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
-    …
+    )
 
-    class_dict = {
+    class_dict: dict[str, int] = {
         "zero": 0,
         "one": 1,
         "two": 2,
@@ -267,43 +270,46 @@ class MNIST(Dataset):
         self.randomize = randomize
         self.from_back = slice_back
         self.verbose = verbose
+        self.data: NDArray[np.float64]
+        self.targets: NDArray[np.int_]
+        self.size: int
 
-        self.…
+        self._class_set = []
         if classes is not None:
             if not isinstance(classes, list):
                 classes = [classes]  # type: ignore
 
             for val in classes:  # type: ignore
                 if isinstance(val, int) and 0 <= val < 10:
-                    self.…
+                    self._class_set.append(val)
                 elif isinstance(val, str):
-                    self.…
-            self.…
+                    self._class_set.append(self.class_dict[val])
+            self._class_set = set(self._class_set)
 
-        if not self.…
-            self.…
+        if not self._class_set:
+            self._class_set = set(self.class_dict.values())
 
-        self.…
+        self._num_classes = len(self._class_set)
 
         if self.corruption is None:
-            file_resource = self.…
-            mirror = self.…
+            file_resource = self._resources[0]
+            mirror = self._mirrors[0]
             md5 = False
         else:
             if self.corruption == "identity" and verbose:
                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
-            file_resource = self.…
-            mirror = self.…
+            file_resource = self._resources[1]
+            mirror = self._mirrors[1]
             md5 = True
-        …
+        _check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
 
         self.data, self.targets = self._load_data()
 
         self._augmentations()
 
-    def _load_data(self):
+    def _load_data(self) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
         if self.corruption is None:
-            image_file = self.…
+            image_file = self._resources[0][0]
             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
         else:
             image_file = f"{'train' if self.train else 'test'}_images.npy"
@@ -329,27 +335,27 @@ class MNIST(Dataset):
         self.data = self.data[shuffled_indices]
         self.targets = self.targets[shuffled_indices]
 
-        if not self.balance and self.…
+        if not self.balance and self._num_classes > self.size:
             if self.size > 0:
-                self.data = …
-                self.targets = …
+                self.data = _subselect(self.data, self.size, self.from_back)
+                self.targets = _subselect(self.targets, self.size, self.from_back)
         else:
-            label_dict = {label: np.where(self.targets == label)[0] for label in self.…
+            label_dict = {label: np.where(self.targets == label)[0] for label in self._class_set}
             min_label_count = min(len(indices) for indices in label_dict.values())
 
-            self.…
+            self._per_class_count = int(np.ceil(self.size / self._num_classes)) if self.size > 0 else min_label_count
 
-            if self.…
-                self.…
+            if self._per_class_count > min_label_count:
+                self._per_class_count = min_label_count
                 if not self.balance and self.verbose:
                     warn(
-                        f"Because of dataset limitations, only {min_label_count*self.…
+                        f"Because of dataset limitations, only {min_label_count*self._num_classes} samples "
                         f"will be returned, instead of the desired {self.size}."
                     )
 
-            all_indices = np.empty(shape=(self.…
-            for i, label in enumerate(self.…
-                all_indices[i] = …
+            all_indices: NDArray[np.int_] = np.empty(shape=(self._num_classes, self._per_class_count), dtype=np.int_)
+            for i, label in enumerate(self._class_set):
+                all_indices[i] = _subselect(label_dict[label], self._per_class_count, self.from_back)
             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
@@ -370,7 +376,7 @@ class MNIST(Dataset):
         if self.flatten and self.channels is None:
            self.data = self.data.reshape(self.data.shape[0], -1)
 
-    def __getitem__(self, index: int) -> tuple[NDArray, int]:
+    def __getitem__(self, index: int) -> tuple[NDArray[np.float64], int]:
        """
        Args:
            index (int): Index
```
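Beyond the underscore-prefixing of module-level helpers, `MNIST` now subclasses the generic `Dataset[tuple[NDArray[np.float64], int]]`, so indexing is fully typed. A usage sketch; the constructor arguments are inferred from the docstring fragments above, and the exact signature and defaults are assumptions:

```python
# Sketch: loading the exported MNIST dataset. Arguments shown are inferred
# from the docstring in the diff and may not match the full signature.
from dataeval.utils.torch.datasets import MNIST

mnist = MNIST(root="./data", train=True, download=True, unit_interval=True, channels="channels_first")
image, label = mnist[0]  # __getitem__ -> tuple[NDArray[np.float64], int]
print(image.shape, label)
```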
dataeval/utils/torch/models.py
ADDED
```diff
@@ -0,0 +1,138 @@
+from __future__ import annotations
+
+__all__ = ["AriaAutoencoder", "Encoder", "Decoder"]
+
+from typing import Any
+
+import torch.nn as nn
+
+
+class AriaAutoencoder(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
+    def __init__(self, channels: int = 3) -> None:
+        super().__init__()
+        self.encoder: Encoder = Encoder(channels)
+        self.decoder: Decoder = Decoder(channels)
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+    def encode(self, x: Any) -> Any:
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Encoder(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
+    def __init__(self, channels: int = 3) -> None:
+        super().__init__()
+        self.encoder: nn.Sequential = nn.Sequential(
+            nn.Conv2d(channels, 256, 2, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Conv2d(256, 128, 2, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Conv2d(128, 64, 2, stride=1),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Decoder(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int
+        Number of output channels
+    """
+
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        self.decoder: nn.Sequential = nn.Sequential(
+            nn.ConvTranspose2d(64, 128, 2, stride=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(128, 256, 2, stride=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(256, channels, 2, stride=2),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        return self.decoder(x)
```
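These three classes were previously private tail definitions in `_internal/models/pytorch/autoencoder.py` (see the removed block in the trainer diff below) and are now a standalone, typed public module. A round-trip sketch; the import path follows the new file location:

```python
# Round-trip sketch for the relocated AriaAutoencoder.
import torch

from dataeval.utils.torch.models import AriaAutoencoder

model = AriaAutoencoder(channels=1)
x = torch.randn(8, 1, 28, 28)

embedding = model.encode(x)  # encoder only: (8, 64, 6, 6) for 28x28 inputs
reconstruction = model(x)    # encoder + decoder: back to (8, 1, 28, 28)
print(embedding.shape, reconstruction.shape)
```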
dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py}
RENAMED
```diff
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+__all__ = ["AETrainer"]
+
 from typing import Any
 
 import torch
@@ -17,7 +19,7 @@ def get_images_from_batch(batch: Any) -> Any:
 
 class AETrainer:
     """
-    A class to train and evaluate an autoencoder model.
+    A class to train and evaluate an :term:`autoencoder<Autoencoder>` model.
 
     Parameters
     ----------
@@ -38,13 +40,13 @@ class AETrainer:
     ):
         if device == "auto":
             device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        self.model = model.to(device)
+        self.device: torch.device = torch.device(device)
+        self.model: nn.Module = model.to(device)
         self.batch_size = batch_size
 
-    def train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
+    def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
         """
-        Basic image reconstruction training function for Autoencoder models
+        Basic image reconstruction training function for :term:`Autoencoder` models
 
         Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters
 
@@ -59,7 +61,7 @@ class AETrainer:
         Returns
         -------
         List[float]
-            A list of average loss values for each epoch
+            A list of average loss values for each :term:`epoch<Epoch>`.
 
         Note
         ----
@@ -101,9 +103,9 @@ class AETrainer:
         return loss_history
 
     @torch.no_grad
-    def eval(self, dataset: Dataset) -> float:
+    def eval(self, dataset: Dataset[Any]) -> float:
         """
-        Basic image reconstruction evaluation function for Autoencoder models
+        Basic image reconstruction evaluation function for :term:`autoencoder<Autoencoder>` models
 
         Uses `torch.nn.MSELoss` as default loss function.
 
@@ -137,9 +139,9 @@ class AETrainer:
         return total_loss / len(dataloader)
 
     @torch.no_grad
-    def encode(self, dataset: Dataset) -> torch.Tensor:
+    def encode(self, dataset: Dataset[Any]) -> torch.Tensor:
         """
-        Create image embeddings for the dataset using the model's encoder.
+        Create image :term:`embeddings<Embeddings>` for the dataset using the model's encoder.
 
         If the model has an `encode` method, it will be used; otherwise,
         `model.forward` will be used.
@@ -174,134 +176,3 @@ class AETrainer:
         encodings = torch.vstack((encodings, embeddings)) if len(encodings) else embeddings
 
         return encodings
-
-
-class AriaAutoencoder(nn.Module):
-    """
-    An autoencoder model with a separate encoder and decoder.
-
-    Parameters
-    ----------
-    channels : int, default 3
-        Number of input channels
-    """
-
-    def __init__(self, channels=3):
-        super().__init__()
-        self.encoder = Encoder(channels)
-        self.decoder = Decoder(channels)
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the encoder and decoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            The reconstructed output tensor.
-        """
-        x = self.encoder(x)
-        x = self.decoder(x)
-        return x
-
-    def encode(self, x):
-        """
-        Encode the input tensor using the encoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded representation of the input tensor.
-        """
-        return self.encoder(x)
-
-
-class Encoder(nn.Module):
-    """
-    A simple encoder to be used in an autoencoder model.
-
-    This is the encoder used by the AriaAutoencoder model.
-
-    Parameters
-    ----------
-    channels : int, default 3
-        Number of input channels
-    """
-
-    def __init__(self, channels=3):
-        super().__init__()
-        self.encoder = nn.Sequential(
-            nn.Conv2d(channels, 256, 2, stride=1, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Conv2d(256, 128, 2, stride=1, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Conv2d(128, 64, 2, stride=1),
-        )
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the encoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded representation of the input tensor.
-        """
-        return self.encoder(x)
-
-
-class Decoder(nn.Module):
-    """
-    A simple decoder to be used in an autoencoder model.
-
-    This is the decoder used by the AriaAutoencoder model.
-
-    Parameters
-    ----------
-    channels : int
-        Number of output channels
-    """
-
-    def __init__(self, channels):
-        super().__init__()
-        self.decoder = nn.Sequential(
-            nn.ConvTranspose2d(64, 128, 2, stride=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(128, 256, 2, stride=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(256, channels, 2, stride=2),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the decoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            The encoded tensor.
-
-        Returns
-        -------
-        torch.Tensor
-            The reconstructed output tensor.
-        """
-        return self.decoder(x)
```
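With the model classes split out, `trainer.py` keeps only `AETrainer`, now holding a real `torch.device` and accepting `Dataset[Any]`. A sketch pairing it with the relocated model; the positional argument order and `batch_size` keyword are assumptions based on the attributes set in `__init__`:

```python
# Sketch: AETrainer with the relocated AriaAutoencoder. Constructor argument
# order and keywords are assumed from the attributes assigned in __init__.
from dataeval.utils.torch.datasets import MNIST
from dataeval.utils.torch.models import AriaAutoencoder
from dataeval.utils.torch.trainer import AETrainer

mnist = MNIST(root="./data", train=True, download=True, unit_interval=True, channels="channels_first")
trainer = AETrainer(AriaAutoencoder(channels=1), device="auto", batch_size=32)

loss_history = trainer.train(mnist, epochs=10)  # list[float], one mean loss per epoch
mse = trainer.eval(mnist)                       # mean reconstruction MSE
embeddings = trainer.encode(mnist)              # torch.Tensor of encoder outputs
```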
dataeval/{_internal → utils/torch}/utils.py
RENAMED
```diff
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
+__all__ = ["read_dataset"]
+
 from collections import defaultdict
 from typing import Any
 
 from torch.utils.data import Dataset
 
 
-def read_dataset(dataset: Dataset) -> list[list[Any]]:
+def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
     """
     Extract information from a dataset at each index into individual lists of each information position
 
```
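Per its docstring, `read_dataset` transposes a dataset's per-index tuples into one list per tuple position. A small sketch, with the import path assumed from the rename target:

```python
# Sketch: read_dataset groups element i of every dataset item into list i.
from torch.utils.data import Dataset

from dataeval.utils.torch.utils import read_dataset


class Pairs(Dataset):
    """Toy dataset yielding (value, label) tuples."""

    def __init__(self) -> None:
        self.items = [(0.1, "a"), (0.2, "b"), (0.3, "c")]

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int):
        return self.items[index]


values, labels = read_dataset(Pairs())
print(values)  # [0.1, 0.2, 0.3]
print(labels)  # ['a', 'b', 'c']
```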
dataeval/workflows/__init__.py
CHANGED
```diff
@@ -5,6 +5,6 @@ Workflows perform a sequence of actions to analyze the dataset and make predictions.
 from dataeval import _IS_TORCH_AVAILABLE
 
 if _IS_TORCH_AVAILABLE:  # pragma: no cover
-    from dataeval.…
+    from dataeval.workflows.sufficiency import Sufficiency, SufficiencyOutput
 
 __all__ = ["Sufficiency", "SufficiencyOutput"]
```
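Because `__all__` still re-exports both names, the public import surface is unchanged by the internal move; only PyTorch must be installed for the guarded import to run:

```python
# Public import path is unchanged; only the internal module location moved.
from dataeval.workflows import Sufficiency, SufficiencyOutput
```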
|