dataeval 0.72.0__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registries.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +10 -11
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +51 -102
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +9 -8
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +11 -10
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +33 -34
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +15 -13
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +12 -9
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +47 -45
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +20 -10
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +19 -26
  16. dataeval/detectors/ood/__init__.py +8 -16
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +9 -9
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +10 -30
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +27 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +27 -23
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +11 -13
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +70 -4
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +10 -8
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +54 -20
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +21 -17
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +31 -28
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +15 -16
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +8 -6
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +66 -40
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +19 -15
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +19 -17
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +12 -10
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +8 -6
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +12 -11
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +14 -13
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -4
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/utils/split_dataset.py +486 -0
  52. dataeval/utils/tensorflow/__init__.py +9 -7
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +64 -68
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +10 -9
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +18 -22
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +18 -18
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +49 -43
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +12 -141
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +42 -37
  67. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/METADATA +7 -5
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -7
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.0.dist-info/RECORD +0 -80
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
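
The bulk of this release moves modules out of the private dataeval._internal tree into the public package layout shown in the rename entries above. A hedged sketch of the resulting import migration for downstream code (the re-exported names are an assumption based on the file moves; this diff does not show every __init__.py):

    # 0.72.0 private paths, removed in 0.72.2 -- these imports now fail:
    #   from dataeval._internal.metrics.ber import ber
    #   from dataeval._internal.models.pytorch.autoencoder import AriaAutoencoder

    # 0.72.2 public paths, inferred from the file list above:
    from dataeval.metrics.estimators import ber              # entry 35: _internal/metrics/ber.py moved
    from dataeval.utils.torch.models import AriaAutoencoder  # entry 62: utils/torch/models.py (new)
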
dataeval/{_internal/models/pytorch → utils/torch}/blocks.py

@@ -1,3 +1,7 @@
+from typing import Any
+
+__all__ = []
+
 import torch.nn as nn
 
 
@@ -8,21 +12,22 @@ class Conv(nn.Module):
 
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        k=1,
-        s=1,
-        p=0,
-        activation="relu",
-        norm="instance",
-    ):
+        in_channels: int,
+        out_channels: int,
+        k: int = 1,
+        s: int = 1,
+        p: int = 0,
+        activation: str = "relu",
+        norm: str = "instance",
+    ) -> None:
         super().__init__()
-        conv = nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p)
-        norm = self.get_norm_func(norm=norm, out_channels=out_channels)
-        act = self.get_activation_func(activation=activation)
-        self.module = nn.Sequential(conv, norm, act)
+        self.module: nn.Sequential = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p),
+            self.get_norm_func(norm=norm, out_channels=out_channels),
+            self.get_activation_func(activation=activation),
+        )
 
-    def get_norm_func(self, norm: str, out_channels) -> nn.Module:
+    def get_norm_func(self, norm: str, out_channels: int) -> nn.Module:
         if norm == "batch":
             return nn.BatchNorm2d(out_channels)
         if norm == "instance":
@@ -42,5 +47,5 @@ class Conv(nn.Module):
             return nn.Tanh()
         return nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: Any) -> Any:
         return self.module(x)
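
Behaviorally, Conv is unchanged: it still chains convolution, normalization, and activation; the rewrite adds annotations and inlines the nn.Sequential. Note the new empty __all__, which keeps these blocks out of the documented surface even at the public-looking path. A minimal sketch of the typed signature (import path assumed from the rename above, not shown re-exported):

    import torch

    from dataeval.utils.torch.blocks import Conv  # path assumed from the file move

    # k=3, s=1, p=1 preserves spatial size; norm and activation are selected by name
    block = Conv(in_channels=3, out_channels=16, k=3, s=1, p=1, activation="relu", norm="batch")
    y = block(torch.randn(8, 3, 32, 32))
    print(y.shape)  # torch.Size([8, 16, 32, 32])
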
dataeval/{_internal → utils/torch}/datasets.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+__all__ = ["MNIST", "CIFAR10", "VOCDetection"]
+
 import hashlib
 import os
 import zipfile
@@ -11,7 +13,7 @@ import numpy as np
 import requests
 from numpy.typing import NDArray
 from torch.utils.data import Dataset
-from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401
+from torchvision.datasets import CIFAR10, VOCDetection
 
 ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
@@ -50,6 +52,7 @@ def _get_file(
     file_hash: str | None = None,
     verbose: bool = True,
     md5: bool = False,
+    timeout: int = 60,
 ):
     fpath = os.path.join(root, fname)
     download = True
@@ -64,16 +67,16 @@ def _get_file(
     try:
         error_msg = "URL fetch failure on {}: {} -- {}"
         try:
-            with requests.get(origin, stream=True, timeout=60) as r:
+            with requests.get(origin, stream=True, timeout=timeout) as r:
                 r.raise_for_status()
                 with open(fpath, "wb") as f:
                     for chunk in r.iter_content(chunk_size=8192):
                         if chunk:
                             f.write(chunk)
         except requests.exceptions.HTTPError as e:
-            raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
+            raise RuntimeError(f"{error_msg.format(origin, e.response.status_code, e.response.reason)}") from e
         except requests.exceptions.RequestException as e:
-            raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
+            raise ValueError(f"{error_msg.format(origin, 'Unknown error', str(e))}") from e
     except (Exception, KeyboardInterrupt):
         if os.path.exists(fpath):
             os.remove(fpath)
@@ -89,7 +92,7 @@ def _get_file(
     return fpath
 
 
-def check_exists(
+def _check_exists(
     folder: str | Path,
     url: str,
     root: str | Path,
@@ -103,7 +106,7 @@ def check_exists(
     location = str(folder)
     if not os.path.exists(folder):
         if download:
-            location = download_dataset(url, root, fname, file_hash, verbose, md5)
+            location = _download_dataset(url, root, fname, file_hash, verbose, md5)
         else:
             raise RuntimeError("Dataset not found. You can use download=True to download it")
     else:
@@ -112,7 +115,7 @@ def check_exists(
     return location
 
 
-def download_dataset(
+def _download_dataset(
     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
 ) -> str:
     """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
@@ -131,11 +134,11 @@ def download_dataset(
         md5=md5,
     )
     if md5:
-        folder = extract_archive(fpath, root, remove_finished=True)
+        folder = _extract_archive(fpath, root, remove_finished=True)
     return folder
 
 
-def extract_archive(
+def _extract_archive(
     from_path: str | Path,
     to_path: str | Path | None = None,
     remove_finished: bool = False,
@@ -163,13 +166,13 @@ def extract_archive(
     return str(to_path)
 
 
-def subselect(arr: NDArray, count: int, from_back: bool = False):
+def _subselect(arr: NDArray, count: int, from_back: bool = False):
     if from_back:
         return arr[-count:]
     return arr[:count]
 
 
-class MNIST(Dataset):
+class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
     """MNIST Dataset and Corruptions.
 
     Args:
@@ -185,7 +188,7 @@ class MNIST(Dataset):
    unit_interval : bool, default False
        Shift the data values to the unit interval [0-1].
    dtype : type | None, default None
-        Change the numpy dtype - data is loaded as np.uint8
+        Change the :term:`NumPy` dtype - data is loaded as np.uint8
    channels : Literal['channels_first' | 'channels_last'] | None, default None
        Location of channel axis if desired, default has no channels (N, 28, 28)
    flatten : bool, default False
@@ -211,17 +214,17 @@ class MNIST(Dataset):
        If True, outputs print statements.
    """
 
-    mirror = [
+    _mirrors: tuple[str, ...] = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
         "https://zenodo.org/record/3239543/files/",
-    ]
+    )
 
-    resources = [
+    _resources: tuple[tuple[str, str], ...] = (
         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
-    ]
+    )
 
-    class_dict = {
+    class_dict: dict[str, int] = {
         "zero": 0,
         "one": 1,
         "two": 2,
@@ -267,43 +270,46 @@ class MNIST(Dataset):
         self.randomize = randomize
         self.from_back = slice_back
         self.verbose = verbose
+        self.data: NDArray[np.float64]
+        self.targets: NDArray[np.int_]
+        self.size: int
 
-        self.class_set = []
+        self._class_set = []
         if classes is not None:
             if not isinstance(classes, list):
                 classes = [classes]  # type: ignore
 
             for val in classes:  # type: ignore
                 if isinstance(val, int) and 0 <= val < 10:
-                    self.class_set.append(val)
+                    self._class_set.append(val)
                 elif isinstance(val, str):
-                    self.class_set.append(self.class_dict[val])
-            self.class_set = set(self.class_set)
+                    self._class_set.append(self.class_dict[val])
+            self._class_set = set(self._class_set)
 
-        if not self.class_set:
-            self.class_set = set(self.class_dict.values())
+        if not self._class_set:
+            self._class_set = set(self.class_dict.values())
 
-        self.num_classes = len(self.class_set)
+        self._num_classes = len(self._class_set)
 
         if self.corruption is None:
-            file_resource = self.resources[0]
-            mirror = self.mirror[0]
+            file_resource = self._resources[0]
+            mirror = self._mirrors[0]
             md5 = False
         else:
             if self.corruption == "identity" and verbose:
                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
-            file_resource = self.resources[1]
-            mirror = self.mirror[1]
+            file_resource = self._resources[1]
+            mirror = self._mirrors[1]
             md5 = True
-        check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
+        _check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
 
         self.data, self.targets = self._load_data()
 
         self._augmentations()
 
-    def _load_data(self):
+    def _load_data(self) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
         if self.corruption is None:
-            image_file = self.resources[0][0]
+            image_file = self._resources[0][0]
             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
         else:
             image_file = f"{'train' if self.train else 'test'}_images.npy"
@@ -329,27 +335,27 @@ class MNIST(Dataset):
             self.data = self.data[shuffled_indices]
             self.targets = self.targets[shuffled_indices]
 
-        if not self.balance and self.num_classes > self.size:
+        if not self.balance and self._num_classes > self.size:
             if self.size > 0:
-                self.data = subselect(self.data, self.size, self.from_back)
-                self.targets = subselect(self.targets, self.size, self.from_back)
+                self.data = _subselect(self.data, self.size, self.from_back)
+                self.targets = _subselect(self.targets, self.size, self.from_back)
         else:
-            label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+            label_dict = {label: np.where(self.targets == label)[0] for label in self._class_set}
             min_label_count = min(len(indices) for indices in label_dict.values())
 
-            self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+            self._per_class_count = int(np.ceil(self.size / self._num_classes)) if self.size > 0 else min_label_count
 
-            if self.per_class_count > min_label_count:
-                self.per_class_count = min_label_count
+            if self._per_class_count > min_label_count:
+                self._per_class_count = min_label_count
                 if not self.balance and self.verbose:
                     warn(
-                        f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                        f"Because of dataset limitations, only {min_label_count*self._num_classes} samples "
                         f"will be returned, instead of the desired {self.size}."
                     )
 
-            all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
-            for i, label in enumerate(self.class_set):
-                all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+            all_indices: NDArray[np.int_] = np.empty(shape=(self._num_classes, self._per_class_count), dtype=np.int_)
+            for i, label in enumerate(self._class_set):
+                all_indices[i] = _subselect(label_dict[label], self._per_class_count, self.from_back)
             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
 
@@ -370,7 +376,7 @@ class MNIST(Dataset):
         if self.flatten and self.channels is None:
             self.data = self.data.reshape(self.data.shape[0], -1)
 
-    def __getitem__(self, index: int) -> tuple[NDArray, int]:
+    def __getitem__(self, index: int) -> tuple[NDArray[np.float64], int]:
         """
         Args:
             index (int): Index
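
The renames above (_check_exists, _download_dataset, _extract_archive, _subselect, _mirrors, _resources) privatize the download machinery while MNIST itself joins __all__. A usage sketch built only from the constructor arguments and docstring fragments visible in these hunks (defaults are assumed where the diff does not show them):

    import numpy as np

    from dataeval.utils.torch.datasets import MNIST

    # Balanced 300-sample subset of three digit classes, scaled to [0, 1],
    # cast to float32, with the channel axis first: items are (1, 28, 28) arrays.
    train = MNIST(
        root="./data",
        train=True,
        download=True,
        size=300,
        classes=["zero", "one", "two"],
        unit_interval=True,
        dtype=np.float32,
        channels="channels_first",
    )
    image, label = train[0]  # (NDArray, int) pair, per the __getitem__ hunk above
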
dataeval/utils/torch/models.py (new file)

@@ -0,0 +1,138 @@
+from __future__ import annotations
+
+__all__ = ["AriaAutoencoder", "Encoder", "Decoder"]
+
+from typing import Any
+
+import torch.nn as nn
+
+
+class AriaAutoencoder(nn.Module):
+    """
+    An autoencoder model with a separate encoder and decoder.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
+    def __init__(self, channels: int = 3) -> None:
+        super().__init__()
+        self.encoder: Encoder = Encoder(channels)
+        self.decoder: Decoder = Decoder(channels)
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the encoder and decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        x = self.encoder(x)
+        x = self.decoder(x)
+        return x
+
+    def encode(self, x: Any) -> Any:
+        """
+        Encode the input tensor using the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Encoder(nn.Module):
+    """
+    A simple encoder to be used in an autoencoder model.
+
+    This is the encoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int, default 3
+        Number of input channels
+    """
+
+    def __init__(self, channels: int = 3) -> None:
+        super().__init__()
+        self.encoder: nn.Sequential = nn.Sequential(
+            nn.Conv2d(channels, 256, 2, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Conv2d(256, 128, 2, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Conv2d(128, 64, 2, stride=1),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the encoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor
+
+        Returns
+        -------
+        torch.Tensor
+            The encoded representation of the input tensor.
+        """
+        return self.encoder(x)
+
+
+class Decoder(nn.Module):
+    """
+    A simple decoder to be used in an autoencoder model.
+
+    This is the decoder used by the AriaAutoencoder model.
+
+    Parameters
+    ----------
+    channels : int
+        Number of output channels
+    """
+
+    def __init__(self, channels: int) -> None:
+        super().__init__()
+        self.decoder: nn.Sequential = nn.Sequential(
+            nn.ConvTranspose2d(64, 128, 2, stride=1),
+            nn.ReLU(),
+            nn.ConvTranspose2d(128, 256, 2, stride=2),
+            nn.ReLU(),
+            nn.ConvTranspose2d(256, channels, 2, stride=2),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: Any) -> Any:
+        """
+        Perform a forward pass through the decoder.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            The encoded tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            The reconstructed output tensor.
+        """
+        return self.decoder(x)
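
Because the encoder downsamples twice (two MaxPool2d(2) stages) and the decoder upsamples twice (two stride-2 transposed convolutions), spatial sizes round-trip for inputs such as 28x28 or 32x32. A quick shape check, derived directly from the layer parameters above:

    import torch

    from dataeval.utils.torch.models import AriaAutoencoder

    model = AriaAutoencoder(channels=3)
    x = torch.randn(4, 3, 32, 32)

    z = model.encode(x)
    print(z.shape)      # torch.Size([4, 64, 7, 7]): 32 -> 33 -> 16 -> 17 -> 8 -> 7

    x_hat = model(x)
    print(x_hat.shape)  # torch.Size([4, 3, 32, 32]): 7 -> 8 -> 16 -> 32, sigmoid keeps values in [0, 1]
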
dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py}

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+__all__ = ["AETrainer"]
+
 from typing import Any
 
 import torch
@@ -17,7 +19,7 @@ def get_images_from_batch(batch: Any) -> Any:
 
 class AETrainer:
     """
-    A class to train and evaluate an autoencoder model.
+    A class to train and evaluate an :term:`autoencoder<Autoencoder>` model.
 
     Parameters
     ----------
@@ -38,13 +40,13 @@ class AETrainer:
     ):
         if device == "auto":
             device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device = device
-        self.model = model.to(device)
+        self.device: torch.device = torch.device(device)
+        self.model: nn.Module = model.to(device)
         self.batch_size = batch_size
 
-    def train(self, dataset: Dataset, epochs: int = 25) -> list[float]:
+    def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
         """
-        Basic image reconstruction training function for Autoencoder models
+        Basic image reconstruction training function for :term:`Autoencoder` models
 
         Uses `torch.optim.Adam` and `torch.nn.MSELoss` as default hyperparameters
 
@@ -59,7 +61,7 @@ class AETrainer:
         Returns
         -------
         List[float]
-            A list of average loss values for each epoch.
+            A list of average loss values for each :term:`epoch<Epoch>`.
 
         Note
         ----
@@ -101,9 +103,9 @@ class AETrainer:
         return loss_history
 
     @torch.no_grad
-    def eval(self, dataset: Dataset) -> float:
+    def eval(self, dataset: Dataset[Any]) -> float:
         """
-        Basic image reconstruction evaluation function for Autoencoder models
+        Basic image reconstruction evaluation function for :term:`autoencoder<Autoencoder>` models
 
         Uses `torch.nn.MSELoss` as default loss function.
 
@@ -137,9 +139,9 @@ class AETrainer:
         return total_loss / len(dataloader)
 
     @torch.no_grad
-    def encode(self, dataset: Dataset) -> torch.Tensor:
+    def encode(self, dataset: Dataset[Any]) -> torch.Tensor:
         """
-        Create image embeddings for the dataset using the model's encoder.
+        Create image :term:`embeddings<Embeddings>` for the dataset using the model's encoder.
 
         If the model has an `encode` method, it will be used; otherwise,
         `model.forward` will be used.
@@ -174,134 +176,3 @@ class AETrainer:
             encodings = torch.vstack((encodings, embeddings)) if len(encodings) else embeddings
 
         return encodings
-
-
-class AriaAutoencoder(nn.Module):
-    """
-    An autoencoder model with a separate encoder and decoder.
-
-    Parameters
-    ----------
-    channels : int, default 3
-        Number of input channels
-    """
-
-    def __init__(self, channels=3):
-        super().__init__()
-        self.encoder = Encoder(channels)
-        self.decoder = Decoder(channels)
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the encoder and decoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            The reconstructed output tensor.
-        """
-        x = self.encoder(x)
-        x = self.decoder(x)
-        return x
-
-    def encode(self, x):
-        """
-        Encode the input tensor using the encoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded representation of the input tensor.
-        """
-        return self.encoder(x)
-
-
-class Encoder(nn.Module):
-    """
-    A simple encoder to be used in an autoencoder model.
-
-    This is the encoder used by the AriaAutoencoder model.
-
-    Parameters
-    ----------
-    channels : int, default 3
-        Number of input channels
-    """
-
-    def __init__(self, channels=3):
-        super().__init__()
-        self.encoder = nn.Sequential(
-            nn.Conv2d(channels, 256, 2, stride=1, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Conv2d(256, 128, 2, stride=1, padding=1),
-            nn.ReLU(),
-            nn.MaxPool2d(2),
-            nn.Conv2d(128, 64, 2, stride=1),
-        )
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the encoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor
-
-        Returns
-        -------
-        torch.Tensor
-            The encoded representation of the input tensor.
-        """
-        return self.encoder(x)
-
-
-class Decoder(nn.Module):
-    """
-    A simple decoder to be used in an autoencoder model.
-
-    This is the decoder used by the AriaAutoencoder model.
-
-    Parameters
-    ----------
-    channels : int
-        Number of output channels
-    """
-
-    def __init__(self, channels):
-        super().__init__()
-        self.decoder = nn.Sequential(
-            nn.ConvTranspose2d(64, 128, 2, stride=1),
-            nn.ReLU(),
-            nn.ConvTranspose2d(128, 256, 2, stride=2),
-            nn.ReLU(),
-            nn.ConvTranspose2d(256, channels, 2, stride=2),
-            nn.Sigmoid(),
-        )
-
-    def forward(self, x):
-        """
-        Perform a forward pass through the decoder.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            The encoded tensor.
-
-        Returns
-        -------
-        torch.Tensor
-            The reconstructed output tensor.
-        """
-        return self.decoder(x)
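
With the model classes split out into models.py, trainer.py now holds only AETrainer. A sketch of the intended pairing, reusing the MNIST sketch above (public paths assumed from the file moves; parameter names as shown in the hunks):

    import numpy as np

    from dataeval.utils.torch.datasets import MNIST
    from dataeval.utils.torch.models import AriaAutoencoder
    from dataeval.utils.torch.trainer import AETrainer

    # Single-channel float32 images so the default MSE reconstruction loss applies directly
    train_ds = MNIST(root="./data", train=True, download=True, size=512,
                     unit_interval=True, dtype=np.float32, channels="channels_first")

    model = AriaAutoencoder(channels=1)
    trainer = AETrainer(model, device="auto", batch_size=64)  # "auto" -> cuda when available

    loss_history = trainer.train(train_ds, epochs=5)  # list[float], one mean loss per epoch
    embeddings = trainer.encode(train_ds)             # uses model.encode, since AriaAutoencoder defines one
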
dataeval/{_internal → utils/torch}/utils.py

@@ -1,12 +1,14 @@
 from __future__ import annotations
 
+__all__ = ["read_dataset"]
+
 from collections import defaultdict
 from typing import Any
 
 from torch.utils.data import Dataset
 
 
-def read_dataset(dataset: Dataset) -> list[list[Any]]:
+def read_dataset(dataset: Dataset[Any]) -> list[list[Any]]:
     """
     Extract information from a dataset at each index into individual lists of each information position
 
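
Per its docstring, read_dataset transposes a dataset's per-index tuples into per-position lists (all images in one list, all labels in another). A sketch of that documented behavior (import path assumed from the file move above):

    from torch.utils.data import Dataset

    from dataeval.utils.torch.utils import read_dataset

    class Pairs(Dataset):
        # Toy dataset yielding (image, label) tuples
        def __init__(self):
            self.items = [("img0", 0), ("img1", 1), ("img2", 0)]
        def __len__(self):
            return len(self.items)
        def __getitem__(self, idx):
            return self.items[idx]

    print(read_dataset(Pairs()))  # expected: [['img0', 'img1', 'img2'], [0, 1, 0]]
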
dataeval/workflows/__init__.py

@@ -5,6 +5,6 @@ Workflows perform a sequence of actions to analyze the dataset and make predictions.
 from dataeval import _IS_TORCH_AVAILABLE
 
 if _IS_TORCH_AVAILABLE:  # pragma: no cover
-    from dataeval._internal.workflows.sufficiency import Sufficiency, SufficiencyOutput
+    from dataeval.workflows.sufficiency import Sufficiency, SufficiencyOutput
 
 __all__ = ["Sufficiency", "SufficiencyOutput"]
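
The only change here is the import path behind the conditional; the public entry point is stable across both versions:

    # Unchanged public surface (torch must be installed for the conditional import)
    from dataeval.workflows import Sufficiency, SufficiencyOutput
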