dataeval 0.72.1__py3-none-any.whl → 0.73.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95):
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +20 -12
  18. dataeval/detectors/ood/aegmm.py +66 -0
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +33 -21
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +43 -33
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +23 -17
  25. dataeval/detectors/ood/vaegmm.py +75 -0
  26. dataeval/interop.py +56 -0
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -13
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +41 -7
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +75 -18
  32. dataeval/metrics/bias/metadata.py +358 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +54 -44
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +8 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/lazy.py +26 -0
  51. dataeval/utils/metadata.py +258 -0
  52. dataeval/utils/shared.py +151 -0
  53. dataeval/{_internal → utils}/split_dataset.py +98 -33
  54. dataeval/utils/tensorflow/__init__.py +7 -6
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +8 -2
  56. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +28 -18
  57. dataeval/{_internal/models/tensorflow/pixelcnn.py → utils/tensorflow/_internal/models.py} +387 -97
  58. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +15 -6
  59. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +84 -85
  60. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  61. dataeval/utils/torch/__init__.py +7 -3
  62. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  63. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  64. dataeval/utils/torch/models.py +138 -0
  65. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  66. dataeval/{_internal → utils/torch}/utils.py +3 -1
  67. dataeval/workflows/__init__.py +1 -1
  68. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  69. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/METADATA +4 -3
  70. dataeval-0.73.0.dist-info/RECORD +73 -0
  71. dataeval/_internal/detectors/__init__.py +0 -0
  72. dataeval/_internal/detectors/drift/__init__.py +0 -0
  73. dataeval/_internal/detectors/ood/__init__.py +0 -0
  74. dataeval/_internal/detectors/ood/aegmm.py +0 -78
  75. dataeval/_internal/detectors/ood/vaegmm.py +0 -89
  76. dataeval/_internal/interop.py +0 -49
  77. dataeval/_internal/metrics/__init__.py +0 -0
  78. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  79. dataeval/_internal/metrics/utils.py +0 -447
  80. dataeval/_internal/models/__init__.py +0 -0
  81. dataeval/_internal/models/pytorch/__init__.py +0 -0
  82. dataeval/_internal/models/pytorch/utils.py +0 -67
  83. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  84. dataeval/_internal/models/tensorflow/autoencoder.py +0 -320
  85. dataeval/_internal/workflows/__init__.py +0 -0
  86. dataeval/detectors/drift/kernels/__init__.py +0 -10
  87. dataeval/detectors/drift/updates/__init__.py +0 -8
  88. dataeval/utils/tensorflow/models/__init__.py +0 -9
  89. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  90. dataeval/utils/torch/datasets/__init__.py +0 -12
  91. dataeval/utils/torch/models/__init__.py +0 -11
  92. dataeval/utils/torch/trainer/__init__.py +0 -7
  93. dataeval-0.72.1.dist-info/RECORD +0 -81
  94. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/LICENSE.txt +0 -0
  95. {dataeval-0.72.1.dist-info → dataeval-0.73.0.dist-info}/WHEEL +0 -0
@@ -8,20 +8,27 @@ Licensed under Apache Software License (Apache 2.0)
8
8
 
9
9
  from __future__ import annotations
10
10
 
11
- from typing import Callable, Iterable, cast
11
+ from typing import TYPE_CHECKING, Callable, Iterable, cast
12
12
 
13
13
  import numpy as np
14
- import tensorflow as tf
15
- import tf_keras as keras
16
14
  from numpy.typing import NDArray
17
15
 
16
+ from dataeval.utils.lazy import lazyload
17
+
18
+ if TYPE_CHECKING:
19
+ import tensorflow as tf
20
+ import tf_keras as keras
21
+ else:
22
+ tf = lazyload("tensorflow")
23
+ keras = lazyload("tf_keras")
24
+
18
25
 
19
26
  def trainer(
20
27
  model: keras.Model,
21
28
  x_train: NDArray,
22
29
  y_train: NDArray | None = None,
23
30
  loss_fn: Callable[..., tf.Tensor] | None = None,
24
- optimizer: keras.optimizers.Optimizer = keras.optimizers.Adam,
31
+ optimizer: keras.optimizers.Optimizer | None = None,
25
32
  preprocess_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
26
33
  epochs: int = 20,
27
34
  reg_loss_fn: Callable[[keras.Model], tf.Tensor] = (lambda _: cast(tf.Tensor, tf.Variable(0, dtype=tf.float32))),
@@ -58,9 +65,11 @@ def trainer(
58
65
  Whether to print training progress.
59
66
  """
60
67
  loss_fn = loss_fn() if isinstance(loss_fn, type) else loss_fn
61
- optimizer = optimizer() if isinstance(optimizer, type) else optimizer
68
+ optimizer = keras.optimizers.Adam() if optimizer is None else optimizer
62
69
 
63
- train_data = x_train if y_train is None else (x_train, y_train)
70
+ train_data = (
71
+ x_train.astype(np.float32) if y_train is None else (x_train.astype(np.float32), y_train.astype(np.float32))
72
+ )
64
73
  dataset = tf.data.Dataset.from_tensor_slices(train_data)
65
74
  dataset = dataset.shuffle(buffer_size=buffer_size).batch(batch_size)
66
75
  n_minibatch = len(dataset)
@@ -9,25 +9,24 @@ Licensed under Apache Software License (Apache 2.0)
9
9
  from __future__ import annotations
10
10
 
11
11
  import math
12
- from typing import Callable, Union, cast
12
+ from typing import TYPE_CHECKING, Any, Callable, Literal, Union, cast
13
13
 
14
14
  import numpy as np
15
- import tensorflow as tf
16
- import tf_keras as keras
17
15
  from numpy.typing import NDArray
18
- from tensorflow._api.v2.nn import relu, softmax, tanh
19
- from tf_keras import Sequential
20
- from tf_keras.layers import (
21
- Conv2D,
22
- Conv2DTranspose,
23
- Dense,
24
- Flatten,
25
- InputLayer,
26
- Reshape,
27
- )
28
16
 
29
- from dataeval._internal.models.tensorflow.autoencoder import AE, AEGMM, VAE, VAEGMM
30
- from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
17
+ from dataeval.utils.lazy import lazyload
18
+
19
+ if TYPE_CHECKING:
20
+ import tensorflow as tf
21
+ import tensorflow._api.v2.nn as nn
22
+ import tf_keras as keras
23
+
24
+ import dataeval.utils.tensorflow._internal.models as tf_models
25
+ else:
26
+ tf = lazyload("tensorflow")
27
+ nn = lazyload("tensorflow._api.v2.nn")
28
+ keras = lazyload("tf_keras")
29
+ tf_models = lazyload("dataeval.utils.tensorflow._internal.models")
31
30
 
32
31
 
33
32
  def predict_batch(
@@ -95,47 +94,47 @@ def predict_batch(
95
94
  return out
96
95
 
97
96
 
98
- def _get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
99
- return Sequential(
97
+ def get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
98
+ return keras.Sequential(
100
99
  [
101
- InputLayer(input_shape=input_shape),
102
- Conv2D(64, 4, strides=2, padding="same", activation=relu),
103
- Conv2D(128, 4, strides=2, padding="same", activation=relu),
104
- Conv2D(512, 4, strides=2, padding="same", activation=relu),
105
- Flatten(),
106
- Dense(encoding_dim),
100
+ keras.layers.InputLayer(input_shape=input_shape),
101
+ keras.layers.Conv2D(64, 4, strides=2, padding="same", activation=nn.relu),
102
+ keras.layers.Conv2D(128, 4, strides=2, padding="same", activation=nn.relu),
103
+ keras.layers.Conv2D(512, 4, strides=2, padding="same", activation=nn.relu),
104
+ keras.layers.Flatten(),
105
+ keras.layers.Dense(encoding_dim),
107
106
  ]
108
107
  )
109
108
 
110
109
 
111
- def _get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
112
- return Sequential(
110
+ def get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
111
+ return keras.Sequential(
113
112
  [
114
- InputLayer(input_shape=(encoding_dim,)),
115
- Dense(4 * 4 * 128),
116
- Reshape(target_shape=(4, 4, 128)),
117
- Conv2DTranspose(256, 4, strides=2, padding="same", activation=relu),
118
- Conv2DTranspose(64, 4, strides=2, padding="same", activation=relu),
119
- Flatten(),
120
- Dense(math.prod(input_shape)),
121
- Reshape(target_shape=input_shape),
113
+ keras.layers.InputLayer(input_shape=(encoding_dim,)),
114
+ keras.layers.Dense(4 * 4 * 128),
115
+ keras.layers.Reshape(target_shape=(4, 4, 128)),
116
+ keras.layers.Conv2DTranspose(256, 4, strides=2, padding="same", activation=nn.relu),
117
+ keras.layers.Conv2DTranspose(64, 4, strides=2, padding="same", activation=nn.relu),
118
+ keras.layers.Flatten(),
119
+ keras.layers.Dense(math.prod(input_shape)),
120
+ keras.layers.Reshape(target_shape=input_shape),
122
121
  ]
123
122
  )
124
123
 
125
124
 
126
125
  def create_model(
127
- model_type: AE | AEGMM | PixelCNN | VAE | VAEGMM,
126
+ model_type: Literal["AE", "AEGMM", "PixelCNN", "VAE", "VAEGMM"],
128
127
  input_shape: tuple[int, int, int],
129
128
  encoding_dim: int | None = None,
130
129
  n_gmm: int | None = None,
131
130
  gmm_latent_dim: int | None = None,
132
- ):
131
+ ) -> Any:
133
132
  """
134
133
  Create a default model for the specified model type.
135
134
 
136
135
  Parameters
137
136
  ----------
138
- model_type : Union[AE, AEGMM, PixelCNN, VAE, VAEGMM]
137
+ model_type : Literal["AE", "AEGMM", "PixelCNN", "VAE", "VAEGMM"]
139
138
  The model type to create.
140
139
  input_shape : Tuple[int, int, int]
141
140
  The input shape of the data used.
@@ -148,93 +147,93 @@ def create_model(
148
147
  """
149
148
  input_dim = math.prod(input_shape)
150
149
  encoding_dim = int(math.pow(2, int(input_dim.bit_length() * 0.8)) if encoding_dim is None else encoding_dim)
151
- if model_type == AE:
152
- return AE(
153
- _get_default_encoder_net(input_shape, encoding_dim),
154
- _get_default_decoder_net(input_shape, encoding_dim),
150
+ if model_type == "AE":
151
+ return tf_models.AE(
152
+ get_default_encoder_net(input_shape, encoding_dim),
153
+ get_default_decoder_net(input_shape, encoding_dim),
155
154
  )
156
155
 
157
- if model_type == VAE:
158
- return VAE(
159
- _get_default_encoder_net(input_shape, encoding_dim),
160
- _get_default_decoder_net(input_shape, encoding_dim),
156
+ if model_type == "VAE":
157
+ return tf_models.VAE(
158
+ get_default_encoder_net(input_shape, encoding_dim),
159
+ get_default_decoder_net(input_shape, encoding_dim),
161
160
  encoding_dim,
162
161
  )
163
162
 
164
- if model_type == AEGMM:
163
+ if model_type == "AEGMM":
165
164
  n_gmm = 2 if n_gmm is None else n_gmm
166
165
  gmm_latent_dim = 1 if gmm_latent_dim is None else gmm_latent_dim
167
166
  # The outlier detector is an encoder/decoder architecture
168
- encoder_net = Sequential(
167
+ encoder_net = keras.Sequential(
169
168
  [
170
- Flatten(),
171
- InputLayer(input_shape=(input_dim,)),
172
- Dense(60, activation=tanh),
173
- Dense(30, activation=tanh),
174
- Dense(10, activation=tanh),
175
- Dense(gmm_latent_dim, activation=None),
169
+ keras.layers.Flatten(),
170
+ keras.layers.InputLayer(input_shape=(input_dim,)),
171
+ keras.layers.Dense(60, activation=nn.tanh),
172
+ keras.layers.Dense(30, activation=nn.tanh),
173
+ keras.layers.Dense(10, activation=nn.tanh),
174
+ keras.layers.Dense(gmm_latent_dim, activation=None),
176
175
  ]
177
176
  )
178
177
  # Here we define the decoder
179
- decoder_net = Sequential(
178
+ decoder_net = keras.Sequential(
180
179
  [
181
- InputLayer(input_shape=(gmm_latent_dim,)),
182
- Dense(10, activation=tanh),
183
- Dense(30, activation=tanh),
184
- Dense(60, activation=tanh),
185
- Dense(input_dim, activation=None),
186
- Reshape(target_shape=input_shape),
180
+ keras.layers.InputLayer(input_shape=(gmm_latent_dim,)),
181
+ keras.layers.Dense(10, activation=nn.tanh),
182
+ keras.layers.Dense(30, activation=nn.tanh),
183
+ keras.layers.Dense(60, activation=nn.tanh),
184
+ keras.layers.Dense(input_dim, activation=None),
185
+ keras.layers.Reshape(target_shape=input_shape),
187
186
  ]
188
187
  )
189
188
  # GMM autoencoders have a density network too
190
- gmm_density_net = Sequential(
189
+ gmm_density_net = keras.Sequential(
191
190
  [
192
- InputLayer(input_shape=(gmm_latent_dim + 2,)),
193
- Dense(10, activation=tanh),
194
- Dense(n_gmm, activation=softmax),
191
+ keras.layers.InputLayer(input_shape=(gmm_latent_dim + 2,)),
192
+ keras.layers.Dense(10, activation=nn.tanh),
193
+ keras.layers.Dense(n_gmm, activation=nn.softmax),
195
194
  ]
196
195
  )
197
- return AEGMM(
196
+ return tf_models.AEGMM(
198
197
  encoder_net=encoder_net,
199
198
  decoder_net=decoder_net,
200
199
  gmm_density_net=gmm_density_net,
201
200
  n_gmm=n_gmm,
202
201
  )
203
202
 
204
- if model_type == VAEGMM:
203
+ if model_type == "VAEGMM":
205
204
  n_gmm = 2 if n_gmm is None else n_gmm
206
205
  gmm_latent_dim = 2 if gmm_latent_dim is None else gmm_latent_dim
207
206
  # The outlier detector is an encoder/decoder architecture
208
207
  # Here we define the encoder
209
- encoder_net = Sequential(
208
+ encoder_net = keras.Sequential(
210
209
  [
211
- Flatten(),
212
- InputLayer(input_shape=(input_dim,)),
213
- Dense(20, activation=relu),
214
- Dense(15, activation=relu),
215
- Dense(7, activation=relu),
210
+ keras.layers.Flatten(),
211
+ keras.layers.InputLayer(input_shape=(input_dim,)),
212
+ keras.layers.Dense(20, activation=nn.relu),
213
+ keras.layers.Dense(15, activation=nn.relu),
214
+ keras.layers.Dense(7, activation=nn.relu),
216
215
  ]
217
216
  )
218
217
  # Here we define the decoder
219
- decoder_net = Sequential(
218
+ decoder_net = keras.Sequential(
220
219
  [
221
- InputLayer(input_shape=(gmm_latent_dim,)),
222
- Dense(7, activation=relu),
223
- Dense(15, activation=relu),
224
- Dense(20, activation=relu),
225
- Dense(input_dim, activation=None),
226
- Reshape(target_shape=input_shape),
220
+ keras.layers.InputLayer(input_shape=(gmm_latent_dim,)),
221
+ keras.layers.Dense(7, activation=nn.relu),
222
+ keras.layers.Dense(15, activation=nn.relu),
223
+ keras.layers.Dense(20, activation=nn.relu),
224
+ keras.layers.Dense(input_dim, activation=None),
225
+ keras.layers.Reshape(target_shape=input_shape),
227
226
  ]
228
227
  )
229
228
  # GMM autoencoders have a density network too
230
- gmm_density_net = Sequential(
229
+ gmm_density_net = keras.Sequential(
231
230
  [
232
- InputLayer(input_shape=(gmm_latent_dim + 2,)),
233
- Dense(10, activation=relu),
234
- Dense(n_gmm, activation=softmax),
231
+ keras.layers.InputLayer(input_shape=(gmm_latent_dim + 2,)),
232
+ keras.layers.Dense(10, activation=nn.relu),
233
+ keras.layers.Dense(n_gmm, activation=nn.softmax),
235
234
  ]
236
235
  )
237
- return VAEGMM(
236
+ return tf_models.VAEGMM(
238
237
  encoder_net=encoder_net,
239
238
  decoder_net=decoder_net,
240
239
  gmm_density_net=gmm_density_net,
@@ -242,8 +241,8 @@ def create_model(
242
241
  latent_dim=gmm_latent_dim,
243
242
  )
244
243
 
245
- if model_type == PixelCNN:
246
- return PixelCNN(
244
+ if model_type == "PixelCNN":
245
+ return tf_models.PixelCNN(
247
246
  image_shape=input_shape,
248
247
  num_resnet=5,
249
248
  num_hierarchies=2,
@@ -1,7 +1,11 @@
1
1
  from dataeval import _IS_TENSORFLOW_AVAILABLE
2
- from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
3
2
 
4
3
  __all__ = []
5
4
 
5
+
6
6
  if _IS_TENSORFLOW_AVAILABLE:
7
- __all__ += ["Elbo", "LossGMM"]
7
+ from dataeval.utils.tensorflow._internal.loss import Elbo, LossGMM
8
+
9
+ __all__ = ["Elbo", "LossGMM"]
10
+
11
+ del _IS_TENSORFLOW_AVAILABLE
@@ -6,16 +6,20 @@ to create a seamless integration between custom models and DataEval's metrics.
6
6
  """
7
7
 
8
8
  from dataeval import _IS_TORCH_AVAILABLE, _IS_TORCHVISION_AVAILABLE
9
- from dataeval._internal.utils import read_dataset
10
9
 
11
10
  __all__ = []
12
11
 
13
12
  if _IS_TORCH_AVAILABLE:
14
- from . import models, trainer
13
+ from dataeval.utils.torch import models, trainer
14
+ from dataeval.utils.torch.utils import read_dataset
15
15
 
16
16
  __all__ += ["read_dataset", "models", "trainer"]
17
17
 
18
18
  if _IS_TORCHVISION_AVAILABLE:
19
- from . import datasets
19
+ from dataeval.utils.torch import datasets
20
20
 
21
21
  __all__ += ["datasets"]
22
+
23
+
24
+ del _IS_TORCH_AVAILABLE
25
+ del _IS_TORCHVISION_AVAILABLE
@@ -1,3 +1,7 @@
1
+ from typing import Any
2
+
3
+ __all__ = []
4
+
1
5
  import torch.nn as nn
2
6
 
3
7
 
@@ -8,21 +12,22 @@ class Conv(nn.Module):
8
12
 
9
13
  def __init__(
10
14
  self,
11
- in_channels,
12
- out_channels,
13
- k=1,
14
- s=1,
15
- p=0,
16
- activation="relu",
17
- norm="instance",
18
- ):
15
+ in_channels: int,
16
+ out_channels: int,
17
+ k: int = 1,
18
+ s: int = 1,
19
+ p: int = 0,
20
+ activation: str = "relu",
21
+ norm: str = "instance",
22
+ ) -> None:
19
23
  super().__init__()
20
- conv = nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p)
21
- norm = self.get_norm_func(norm=norm, out_channels=out_channels)
22
- act = self.get_activation_func(activation=activation)
23
- self.module = nn.Sequential(conv, norm, act)
24
+ self.module: nn.Sequential = nn.Sequential(
25
+ nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p),
26
+ self.get_norm_func(norm=norm, out_channels=out_channels),
27
+ self.get_activation_func(activation=activation),
28
+ )
24
29
 
25
- def get_norm_func(self, norm: str, out_channels) -> nn.Module:
30
+ def get_norm_func(self, norm: str, out_channels: int) -> nn.Module:
26
31
  if norm == "batch":
27
32
  return nn.BatchNorm2d(out_channels)
28
33
  if norm == "instance":
@@ -42,5 +47,5 @@ class Conv(nn.Module):
42
47
  return nn.Tanh()
43
48
  return nn.Identity()
44
49
 
45
- def forward(self, x):
50
+ def forward(self, x: Any) -> Any:
46
51
  return self.module(x)
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
+ __all__ = ["MNIST", "CIFAR10", "VOCDetection"]
4
+
3
5
  import hashlib
4
6
  import os
5
7
  import zipfile
@@ -11,7 +13,7 @@ import numpy as np
11
13
  import requests
12
14
  from numpy.typing import NDArray
13
15
  from torch.utils.data import Dataset
14
- from torchvision.datasets import CIFAR10, VOCDetection # noqa: F401
16
+ from torchvision.datasets import CIFAR10, VOCDetection
15
17
 
16
18
  ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
17
19
  TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
@@ -50,6 +52,7 @@ def _get_file(
50
52
  file_hash: str | None = None,
51
53
  verbose: bool = True,
52
54
  md5: bool = False,
55
+ timeout: int = 60,
53
56
  ):
54
57
  fpath = os.path.join(root, fname)
55
58
  download = True
@@ -64,16 +67,16 @@ def _get_file(
64
67
  try:
65
68
  error_msg = "URL fetch failure on {}: {} -- {}"
66
69
  try:
67
- with requests.get(origin, stream=True, timeout=60) as r:
70
+ with requests.get(origin, stream=True, timeout=timeout) as r:
68
71
  r.raise_for_status()
69
72
  with open(fpath, "wb") as f:
70
73
  for chunk in r.iter_content(chunk_size=8192):
71
74
  if chunk:
72
75
  f.write(chunk)
73
76
  except requests.exceptions.HTTPError as e:
74
- raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
77
+ raise RuntimeError(f"{error_msg.format(origin, e.response.status_code, e.response.reason)}") from e
75
78
  except requests.exceptions.RequestException as e:
76
- raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
79
+ raise ValueError(f"{error_msg.format(origin, 'Unknown error', str(e))}") from e
77
80
  except (Exception, KeyboardInterrupt):
78
81
  if os.path.exists(fpath):
79
82
  os.remove(fpath)
@@ -89,7 +92,7 @@ def _get_file(
89
92
  return fpath
90
93
 
91
94
 
92
- def check_exists(
95
+ def _check_exists(
93
96
  folder: str | Path,
94
97
  url: str,
95
98
  root: str | Path,
@@ -103,7 +106,7 @@ def check_exists(
103
106
  location = str(folder)
104
107
  if not os.path.exists(folder):
105
108
  if download:
106
- location = download_dataset(url, root, fname, file_hash, verbose, md5)
109
+ location = _download_dataset(url, root, fname, file_hash, verbose, md5)
107
110
  else:
108
111
  raise RuntimeError("Dataset not found. You can use download=True to download it")
109
112
  else:
@@ -112,7 +115,7 @@ def check_exists(
112
115
  return location
113
116
 
114
117
 
115
- def download_dataset(
118
+ def _download_dataset(
116
119
  url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
117
120
  ) -> str:
118
121
  """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
@@ -131,11 +134,11 @@ def download_dataset(
131
134
  md5=md5,
132
135
  )
133
136
  if md5:
134
- folder = extract_archive(fpath, root, remove_finished=True)
137
+ folder = _extract_archive(fpath, root, remove_finished=True)
135
138
  return folder
136
139
 
137
140
 
138
- def extract_archive(
141
+ def _extract_archive(
139
142
  from_path: str | Path,
140
143
  to_path: str | Path | None = None,
141
144
  remove_finished: bool = False,
@@ -163,13 +166,13 @@ def extract_archive(
163
166
  return str(to_path)
164
167
 
165
168
 
166
- def subselect(arr: NDArray, count: int, from_back: bool = False):
169
+ def _subselect(arr: NDArray, count: int, from_back: bool = False):
167
170
  if from_back:
168
171
  return arr[-count:]
169
172
  return arr[:count]
170
173
 
171
174
 
172
- class MNIST(Dataset):
175
+ class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
173
176
  """MNIST Dataset and Corruptions.
174
177
 
175
178
  Args:
@@ -211,17 +214,17 @@ class MNIST(Dataset):
211
214
  If True, outputs print statements.
212
215
  """
213
216
 
214
- mirror = [
217
+ _mirrors: tuple[str, ...] = (
215
218
  "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
216
219
  "https://zenodo.org/record/3239543/files/",
217
- ]
220
+ )
218
221
 
219
- resources = [
222
+ _resources: tuple[tuple[str, str], ...] = (
220
223
  ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
221
224
  ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
222
- ]
225
+ )
223
226
 
224
- class_dict = {
227
+ class_dict: dict[str, int] = {
225
228
  "zero": 0,
226
229
  "one": 1,
227
230
  "two": 2,
@@ -267,43 +270,46 @@ class MNIST(Dataset):
267
270
  self.randomize = randomize
268
271
  self.from_back = slice_back
269
272
  self.verbose = verbose
273
+ self.data: NDArray[np.float64]
274
+ self.targets: NDArray[np.int_]
275
+ self.size: int
270
276
 
271
- self.class_set = []
277
+ self._class_set = []
272
278
  if classes is not None:
273
279
  if not isinstance(classes, list):
274
280
  classes = [classes] # type: ignore
275
281
 
276
282
  for val in classes: # type: ignore
277
283
  if isinstance(val, int) and 0 <= val < 10:
278
- self.class_set.append(val)
284
+ self._class_set.append(val)
279
285
  elif isinstance(val, str):
280
- self.class_set.append(self.class_dict[val])
281
- self.class_set = set(self.class_set)
286
+ self._class_set.append(self.class_dict[val])
287
+ self._class_set = set(self._class_set)
282
288
 
283
- if not self.class_set:
284
- self.class_set = set(self.class_dict.values())
289
+ if not self._class_set:
290
+ self._class_set = set(self.class_dict.values())
285
291
 
286
- self.num_classes = len(self.class_set)
292
+ self._num_classes = len(self._class_set)
287
293
 
288
294
  if self.corruption is None:
289
- file_resource = self.resources[0]
290
- mirror = self.mirror[0]
295
+ file_resource = self._resources[0]
296
+ mirror = self._mirrors[0]
291
297
  md5 = False
292
298
  else:
293
299
  if self.corruption == "identity" and verbose:
294
300
  print("Identity is not a corrupted dataset but the original MNIST dataset.")
295
- file_resource = self.resources[1]
296
- mirror = self.mirror[1]
301
+ file_resource = self._resources[1]
302
+ mirror = self._mirrors[1]
297
303
  md5 = True
298
- check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
304
+ _check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
299
305
 
300
306
  self.data, self.targets = self._load_data()
301
307
 
302
308
  self._augmentations()
303
309
 
304
- def _load_data(self):
310
+ def _load_data(self) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
305
311
  if self.corruption is None:
306
- image_file = self.resources[0][0]
312
+ image_file = self._resources[0][0]
307
313
  data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
308
314
  else:
309
315
  image_file = f"{'train' if self.train else 'test'}_images.npy"
@@ -329,27 +335,27 @@ class MNIST(Dataset):
329
335
  self.data = self.data[shuffled_indices]
330
336
  self.targets = self.targets[shuffled_indices]
331
337
 
332
- if not self.balance and self.num_classes > self.size:
338
+ if not self.balance and self._num_classes > self.size:
333
339
  if self.size > 0:
334
- self.data = subselect(self.data, self.size, self.from_back)
335
- self.targets = subselect(self.targets, self.size, self.from_back)
340
+ self.data = _subselect(self.data, self.size, self.from_back)
341
+ self.targets = _subselect(self.targets, self.size, self.from_back)
336
342
  else:
337
- label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
343
+ label_dict = {label: np.where(self.targets == label)[0] for label in self._class_set}
338
344
  min_label_count = min(len(indices) for indices in label_dict.values())
339
345
 
340
- self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
346
+ self._per_class_count = int(np.ceil(self.size / self._num_classes)) if self.size > 0 else min_label_count
341
347
 
342
- if self.per_class_count > min_label_count:
343
- self.per_class_count = min_label_count
348
+ if self._per_class_count > min_label_count:
349
+ self._per_class_count = min_label_count
344
350
  if not self.balance and self.verbose:
345
351
  warn(
346
- f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
352
+ f"Because of dataset limitations, only {min_label_count*self._num_classes} samples "
347
353
  f"will be returned, instead of the desired {self.size}."
348
354
  )
349
355
 
350
- all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
351
- for i, label in enumerate(self.class_set):
352
- all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
356
+ all_indices: NDArray[np.int_] = np.empty(shape=(self._num_classes, self._per_class_count), dtype=np.int_)
357
+ for i, label in enumerate(self._class_set):
358
+ all_indices[i] = _subselect(label_dict[label], self._per_class_count, self.from_back)
353
359
  self.data = np.vstack(self.data[all_indices.T]) # type: ignore
354
360
  self.targets = np.hstack(self.targets[all_indices.T]) # type: ignore
355
361
 
@@ -370,7 +376,7 @@ class MNIST(Dataset):
370
376
  if self.flatten and self.channels is None:
371
377
  self.data = self.data.reshape(self.data.shape[0], -1)
372
378
 
373
- def __getitem__(self, index: int) -> tuple[NDArray, int]:
379
+ def __getitem__(self, index: int) -> tuple[NDArray[np.float64], int]:
374
380
  """
375
381
  Args:
376
382
  index (int): Index