dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. dataeval/__init__.py +4 -4
  2. dataeval/detectors/__init__.py +4 -3
  3. dataeval/detectors/drift/__init__.py +9 -10
  4. dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
  5. dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
  6. dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
  7. dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
  8. dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
  9. dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
  10. dataeval/detectors/drift/updates.py +61 -0
  11. dataeval/detectors/linters/__init__.py +3 -3
  12. dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
  13. dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
  14. dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
  15. dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
  16. dataeval/detectors/ood/__init__.py +6 -6
  17. dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
  18. dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
  19. dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
  20. dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
  21. dataeval/detectors/ood/metadata_ks_compare.py +99 -0
  22. dataeval/detectors/ood/metadata_least_likely.py +119 -0
  23. dataeval/detectors/ood/metadata_ood_mi.py +92 -0
  24. dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
  25. dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
  26. dataeval/{_internal/interop.py → interop.py} +12 -7
  27. dataeval/metrics/__init__.py +1 -1
  28. dataeval/metrics/bias/__init__.py +4 -4
  29. dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
  30. dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
  31. dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
  32. dataeval/metrics/bias/metadata.py +275 -0
  33. dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
  34. dataeval/metrics/estimators/__init__.py +3 -3
  35. dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
  36. dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
  37. dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
  38. dataeval/metrics/stats/__init__.py +7 -7
  39. dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
  40. dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
  41. dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
  42. dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
  43. dataeval/metrics/stats/hashstats.py +156 -0
  44. dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
  45. dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
  46. dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
  47. dataeval/{_internal/output.py → output.py} +26 -6
  48. dataeval/utils/__init__.py +7 -3
  49. dataeval/utils/image.py +71 -0
  50. dataeval/utils/shared.py +151 -0
  51. dataeval/{_internal → utils}/split_dataset.py +98 -33
  52. dataeval/utils/tensorflow/__init__.py +7 -6
  53. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
  54. dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
  55. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
  56. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
  57. dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
  58. dataeval/utils/tensorflow/loss/__init__.py +6 -2
  59. dataeval/utils/torch/__init__.py +7 -3
  60. dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
  61. dataeval/{_internal → utils/torch}/datasets.py +48 -42
  62. dataeval/utils/torch/models.py +138 -0
  63. dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
  64. dataeval/{_internal → utils/torch}/utils.py +3 -1
  65. dataeval/workflows/__init__.py +1 -1
  66. dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
  67. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
  68. dataeval-0.72.2.dist-info/RECORD +72 -0
  69. dataeval/_internal/detectors/__init__.py +0 -0
  70. dataeval/_internal/detectors/drift/__init__.py +0 -0
  71. dataeval/_internal/detectors/ood/__init__.py +0 -0
  72. dataeval/_internal/metrics/__init__.py +0 -0
  73. dataeval/_internal/metrics/stats/hashstats.py +0 -75
  74. dataeval/_internal/metrics/utils.py +0 -447
  75. dataeval/_internal/models/__init__.py +0 -0
  76. dataeval/_internal/models/pytorch/__init__.py +0 -0
  77. dataeval/_internal/models/pytorch/utils.py +0 -67
  78. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  79. dataeval/_internal/workflows/__init__.py +0 -0
  80. dataeval/detectors/drift/kernels/__init__.py +0 -10
  81. dataeval/detectors/drift/updates/__init__.py +0 -8
  82. dataeval/utils/tensorflow/models/__init__.py +0 -9
  83. dataeval/utils/tensorflow/recon/__init__.py +0 -3
  84. dataeval/utils/torch/datasets/__init__.py +0 -12
  85. dataeval/utils/torch/models/__init__.py +0 -11
  86. dataeval/utils/torch/trainer/__init__.py +0 -7
  87. dataeval-0.72.1.dist-info/RECORD +0 -81
  88. /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
  89. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
  90. {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
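Most of this release is a restructuring: modules move out of the private `dataeval._internal` tree into the public subpackages shown in the renames above. A minimal before/after sketch of two module paths taken from that list (the public re-exports each package provides may differ):

    # 0.72.1 (private paths)
    import dataeval._internal.detectors.drift.base
    import dataeval._internal.metrics.ber

    # 0.72.2 (public paths)
    import dataeval.detectors.drift.base
    import dataeval.metrics.estimators.ber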
dataeval/utils/tensorflow/_internal/pixelcnn.py
@@ -34,13 +34,9 @@ from tensorflow_probability.python.internal import (
     tensorshape_util,
 )
 
-__all__ = [
-    "Shift",
-]
-
 
 class WeightNorm(keras.layers.Wrapper):
-    def __init__(self, layer, data_init: bool = True, **kwargs):
+    def __init__(self, layer, data_init: bool = True, **kwargs) -> None:
         """Layer wrapper to decouple magnitude and direction of the layer's weights.
 
         This wrapper reparameterizes a layer by decoupling the weight's
@@ -187,7 +183,7 @@ class WeightNorm(keras.layers.Wrapper):
 
 
 class Shift(bijector.Bijector):
-    def __init__(self, shift, validate_args=False, name="shift"):
+    def __init__(self, shift, validate_args=False, name="shift") -> None:
         """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
         where `shift` is a numeric `Tensor`.
 
@@ -276,13 +272,13 @@ class PixelCNN(distribution.Distribution):
 
     def __init__(
         self,
-        image_shape: tuple,
-        conditional_shape: tuple | None = None,
+        image_shape: tuple[int, int, int],
+        conditional_shape: tuple[int, ...] | None = None,
         num_resnet: int = 5,
         num_hierarchies: int = 3,
         num_filters: int = 160,
         num_logistic_mix: int = 10,
-        receptive_field_dims: tuple = (3, 3),
+        receptive_field_dims: tuple[int, int] = (3, 3),
         dropout_p: float = 0.5,
         resnet_activation: str = "concat_elu",
         l2_weight: float = 0.0,
@@ -290,7 +286,7 @@ class PixelCNN(distribution.Distribution):
         use_data_init: bool = True,
         high: int = 255,
         low: int = 0,
-        dtype=tf.float32,
+        dtype: tf.DType = tf.float32,
     ) -> None:
         parameters = dict(locals())
         with tf.name_scope("PixelCNN") as name:
@@ -315,7 +311,7 @@ class PixelCNN(distribution.Distribution):
             self._high = tf.cast(high, self.dtype)
             self._low = tf.cast(low, self.dtype)
             self._num_logistic_mix = num_logistic_mix
-            self.network = _PixelCNNNetwork(
+            self._network = PixelCNNNetwork(
                 dropout_p=dropout_p,
                 num_resnet=num_resnet,
                 num_hierarchies=num_hierarchies,
@@ -338,7 +334,7 @@ class PixelCNN(distribution.Distribution):
 
         self.image_shape = image_shape
         self.conditional_shape = conditional_shape
-        self.network.build(input_shape)
+        self._network.build(input_shape)
 
     def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
         """Builds a mixture of quantized logistic distributions.
@@ -455,7 +451,7 @@ class PixelCNN(distribution.Distribution):
         transformed_value = (2.0 * (value - self._low) / (self._high - self._low)) - 1.0
         inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]
 
-        params = self.network(inputs, training=training)
+        params = self._network(inputs, training=training)
 
         num_channels = self.event_shape[-1]
         if num_channels == 1:
@@ -554,7 +550,7 @@ class PixelCNN(distribution.Distribution):
             seed=seed,
         )
         inputs = samples_0 if conditional_input is None else [samples_0, h]
-        params_0 = self.network(inputs, training=training)
+        params_0 = self._network(inputs, training=training)
         samples_0 = self._sample_channels(*params_0, seed=seed)
 
         image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)
@@ -579,7 +575,7 @@ class PixelCNN(distribution.Distribution):
             width, num_channels]`.
             """
             inputs = samples if conditional_input is None else [samples, h]
-            params = self.network(inputs, training=training)
+            params = self._network(inputs, training=training)
             samples_new = self._sample_channels(*params, seed=seed)
 
             # Update the current pixel
@@ -673,7 +669,7 @@ class PixelCNN(distribution.Distribution):
         return tf.TensorShape(self.image_shape)
 
 
-class _PixelCNNNetwork(keras.layers.Layer):
+class PixelCNNNetwork(keras.layers.Layer):
     """Keras `Layer` to parameterize a Pixel CNN++ distribution.
     This is a Keras implementation of the Pixel CNN++ network, as described in
     Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
@@ -699,12 +695,12 @@ class _PixelCNNNetwork(keras.layers.Layer):
         num_hierarchies: int = 3,
         num_filters: int = 160,
         num_logistic_mix: int = 10,
-        receptive_field_dims: tuple = (3, 3),
+        receptive_field_dims: tuple[int, int] = (3, 3),
         resnet_activation: str = "concat_elu",
         l2_weight: float = 0.0,
         use_weight_norm: bool = True,
         use_data_init: bool = True,
-        dtype=tf.float32,
+        dtype: tf.DType = tf.float32,
     ) -> None:
         """Initialize the :term:`neural network<Neural Network>` for the Pixel CNN++ distribution.
 
@@ -765,7 +761,7 @@ class _PixelCNNNetwork(keras.layers.Layer):
         else:
             self._layer_wrapper = lambda layer: layer
 
-    def build(self, input_shape):
+    def build(self, input_shape: tuple[int, ...]) -> None:
         dtype = self.dtype
         if len(input_shape) == 2:
             batch_image_shape, batch_conditional_shape = input_shape
@@ -1040,7 +1036,7 @@ class _PixelCNNNetwork(keras.layers.Layer):
         self._network = keras.Model(inputs=inputs, outputs=outputs)
         super().build(input_shape)
 
-    def call(self, inputs, training=None):
+    def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
         """Call the Pixel CNN network model.
 
         Parameters
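Construction of the distribution is unchanged apart from the tightened annotations. A minimal sketch against this internal module, illustrating the now-concrete shapes (the OOD detectors presumably wrap this class, so direct use is optional):

    from dataeval.utils.tensorflow._internal.pixelcnn import PixelCNN

    # image_shape is now tuple[int, int, int] rather than a bare tuple
    dist = PixelCNN(image_shape=(28, 28, 1), receptive_field_dims=(3, 3))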
dataeval/utils/tensorflow/_internal/trainer.py
@@ -60,7 +60,9 @@ def trainer(
     loss_fn = loss_fn() if isinstance(loss_fn, type) else loss_fn
     optimizer = optimizer() if isinstance(optimizer, type) else optimizer
 
-    train_data = x_train if y_train is None else (x_train, y_train)
+    train_data = (
+        x_train.astype(np.float32) if y_train is None else (x_train.astype(np.float32), y_train.astype(np.float32))
+    )
     dataset = tf.data.Dataset.from_tensor_slices(train_data)
     dataset = dataset.shuffle(buffer_size=buffer_size).batch(batch_size)
     n_minibatch = len(dataset)
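The new `astype(np.float32)` cast matters because `tf.data.Dataset.from_tensor_slices` preserves the NumPy dtype, and NumPy arrays default to float64 while Keras layers default to float32. A minimal sketch of the dtype behavior the cast guards against (illustration only, not DataEval code):

    import numpy as np
    import tensorflow as tf

    x = np.random.rand(8, 4)  # NumPy defaults to float64
    tf.data.Dataset.from_tensor_slices(x).element_spec.dtype                     # tf.float64
    tf.data.Dataset.from_tensor_slices(x.astype(np.float32)).element_spec.dtype  # tf.float32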
dataeval/utils/tensorflow/_internal/utils.py
@@ -9,7 +9,7 @@ Licensed under Apache Software License (Apache 2.0)
 from __future__ import annotations
 
 import math
-from typing import Callable, Union, cast
+from typing import Any, Callable, Literal, Union, cast
 
 import numpy as np
 import tensorflow as tf
@@ -26,8 +26,8 @@ from tf_keras.layers import (
     Reshape,
 )
 
-from dataeval._internal.models.tensorflow.autoencoder import AE, AEGMM, VAE, VAEGMM
-from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
+from dataeval.utils.tensorflow._internal.autoencoder import AE, AEGMM, VAE, VAEGMM
+from dataeval.utils.tensorflow._internal.pixelcnn import PixelCNN
 
 
 def predict_batch(
@@ -95,7 +95,7 @@ def predict_batch(
     return out
 
 
-def _get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
+def get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
     return Sequential(
         [
             InputLayer(input_shape=input_shape),
@@ -108,7 +108,7 @@ def _get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: in
     )
 
 
-def _get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
+def get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
     return Sequential(
         [
             InputLayer(input_shape=(encoding_dim,)),
@@ -124,18 +124,18 @@ def _get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: in
 
 
 def create_model(
-    model_type: AE | AEGMM | PixelCNN | VAE | VAEGMM,
+    model_type: Literal["AE", "AEGMM", "PixelCNN", "VAE", "VAEGMM"],
     input_shape: tuple[int, int, int],
     encoding_dim: int | None = None,
     n_gmm: int | None = None,
     gmm_latent_dim: int | None = None,
-):
+) -> Any:
     """
     Create a default model for the specified model type.
 
     Parameters
     ----------
-    model_type : Union[AE, AEGMM, PixelCNN, VAE, VAEGMM]
+    model_type : Literal["AE", "AEGMM", "PixelCNN", "VAE", "VAEGMM"]
        The model type to create.
    input_shape : Tuple[int, int, int]
        The input shape of the data used.
@@ -148,20 +148,20 @@ def create_model(
     """
     input_dim = math.prod(input_shape)
     encoding_dim = int(math.pow(2, int(input_dim.bit_length() * 0.8)) if encoding_dim is None else encoding_dim)
-    if model_type == AE:
+    if model_type == "AE":
         return AE(
-            _get_default_encoder_net(input_shape, encoding_dim),
-            _get_default_decoder_net(input_shape, encoding_dim),
+            get_default_encoder_net(input_shape, encoding_dim),
+            get_default_decoder_net(input_shape, encoding_dim),
         )
 
-    if model_type == VAE:
+    if model_type == "VAE":
         return VAE(
-            _get_default_encoder_net(input_shape, encoding_dim),
-            _get_default_decoder_net(input_shape, encoding_dim),
+            get_default_encoder_net(input_shape, encoding_dim),
+            get_default_decoder_net(input_shape, encoding_dim),
             encoding_dim,
         )
 
-    if model_type == AEGMM:
+    if model_type == "AEGMM":
         n_gmm = 2 if n_gmm is None else n_gmm
         gmm_latent_dim = 1 if gmm_latent_dim is None else gmm_latent_dim
         # The outlier detector is an encoder/decoder architecture
@@ -201,7 +201,7 @@ def create_model(
             n_gmm=n_gmm,
         )
 
-    if model_type == VAEGMM:
+    if model_type == "VAEGMM":
         n_gmm = 2 if n_gmm is None else n_gmm
         gmm_latent_dim = 2 if gmm_latent_dim is None else gmm_latent_dim
         # The outlier detector is an encoder/decoder architecture
@@ -242,7 +242,7 @@ def create_model(
             latent_dim=gmm_latent_dim,
         )
 
-    if model_type == PixelCNN:
+    if model_type == "PixelCNN":
         return PixelCNN(
             image_shape=input_shape,
             num_resnet=5,
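`create_model` now selects the architecture by string literal instead of by class object, which avoids importing the model classes just to choose one. A hedged usage sketch against this internal helper (the stable entry point may be a public wrapper):

    from dataeval.utils.tensorflow._internal.utils import create_model

    # 0.72.1 passed the class itself: create_model(AE, input_shape=(32, 32, 3))
    model = create_model("AE", input_shape=(32, 32, 3))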
dataeval/utils/tensorflow/loss/__init__.py
@@ -1,7 +1,11 @@
 from dataeval import _IS_TENSORFLOW_AVAILABLE
-from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
 
 __all__ = []
 
+
 if _IS_TENSORFLOW_AVAILABLE:
-    __all__ += ["Elbo", "LossGMM"]
+    from dataeval.utils.tensorflow._internal.loss import Elbo, LossGMM
+
+    __all__ = ["Elbo", "LossGMM"]
+
+del _IS_TENSORFLOW_AVAILABLE
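The loss classes are now imported lazily, so the names only exist when the TensorFlow extra is installed; the public import path itself is unchanged:

    from dataeval.utils.tensorflow.loss import Elbo, LossGMM  # requires tensorflow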
dataeval/utils/torch/__init__.py
@@ -6,16 +6,20 @@ to create a seamless integration between custom models and DataEval's metrics.
 """
 
 from dataeval import _IS_TORCH_AVAILABLE, _IS_TORCHVISION_AVAILABLE
-from dataeval._internal.utils import read_dataset
 
 __all__ = []
 
 if _IS_TORCH_AVAILABLE:
-    from . import models, trainer
+    from dataeval.utils.torch import models, trainer
+    from dataeval.utils.torch.utils import read_dataset
 
     __all__ += ["read_dataset", "models", "trainer"]
 
 if _IS_TORCHVISION_AVAILABLE:
-    from . import datasets
+    from dataeval.utils.torch import datasets
 
     __all__ += ["datasets"]
+
+
+del _IS_TORCH_AVAILABLE
+del _IS_TORCHVISION_AVAILABLE
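With the guarded imports, the optional symbols resolve from the public package when the matching extra is installed; for example:

    from dataeval.utils.torch import models, trainer, read_dataset  # torch installed
    from dataeval.utils.torch import datasets                       # torchvision also installed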
dataeval/utils/torch/blocks.py
@@ -1,3 +1,7 @@
+from typing import Any
+
+__all__ = []
+
 import torch.nn as nn
 
 
@@ -8,21 +12,22 @@ class Conv(nn.Module):
 
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        k=1,
-        s=1,
-        p=0,
-        activation="relu",
-        norm="instance",
-    ):
+        in_channels: int,
+        out_channels: int,
+        k: int = 1,
+        s: int = 1,
+        p: int = 0,
+        activation: str = "relu",
+        norm: str = "instance",
+    ) -> None:
         super().__init__()
-        conv = nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p)
-        norm = self.get_norm_func(norm=norm, out_channels=out_channels)
-        act = self.get_activation_func(activation=activation)
-        self.module = nn.Sequential(conv, norm, act)
+        self.module: nn.Sequential = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p),
+            self.get_norm_func(norm=norm, out_channels=out_channels),
+            self.get_activation_func(activation=activation),
+        )
 
-    def get_norm_func(self, norm: str, out_channels) -> nn.Module:
+    def get_norm_func(self, norm: str, out_channels: int) -> nn.Module:
         if norm == "batch":
             return nn.BatchNorm2d(out_channels)
         if norm == "instance":
@@ -42,5 +47,5 @@ class Conv(nn.Module):
             return nn.Tanh()
         return nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: Any) -> Any:
         return self.module(x)
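The typed `Conv` block behaves as before: a Conv2d followed by the selected normalization and activation. A minimal usage sketch (module path per the rename list; output shape follows standard Conv2d arithmetic):

    import torch

    from dataeval.utils.torch.blocks import Conv

    block = Conv(3, 16, k=3, s=1, p=1, activation="relu", norm="batch")
    out = block(torch.randn(1, 3, 32, 32))  # -> torch.Size([1, 16, 32, 32])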
dataeval/utils/torch/datasets.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+__all__ = ["MNIST", "CIFAR10", "VOCDetection"]
+
 import hashlib
 import os
 import zipfile
@@ -11,7 +13,7 @@ import numpy as np
 import requests
 from numpy.typing import NDArray
 from torch.utils.data import Dataset
-from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401
+from torchvision.datasets import CIFAR10, VOCDetection
 
 ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
@@ -50,6 +52,7 @@ def _get_file(
     file_hash: str | None = None,
     verbose: bool = True,
     md5: bool = False,
+    timeout: int = 60,
 ):
     fpath = os.path.join(root, fname)
     download = True
@@ -64,16 +67,16 @@ def _get_file(
     try:
         error_msg = "URL fetch failure on {}: {} -- {}"
         try:
-            with requests.get(origin, stream=True, timeout=60) as r:
+            with requests.get(origin, stream=True, timeout=timeout) as r:
                 r.raise_for_status()
                 with open(fpath, "wb") as f:
                     for chunk in r.iter_content(chunk_size=8192):
                         if chunk:
                             f.write(chunk)
        except requests.exceptions.HTTPError as e:
-            raise Exception(f"{error_msg.format(origin, e.response.status_code)} -- {e.response.reason}") from e
+            raise RuntimeError(f"{error_msg.format(origin, e.response.status_code, e.response.reason)}") from e
         except requests.exceptions.RequestException as e:
-            raise Exception(f"{error_msg.format(origin, 'Unknown error')} -- {str(e)}") from e
+            raise ValueError(f"{error_msg.format(origin, 'Unknown error', str(e))}") from e
     except (Exception, KeyboardInterrupt):
         if os.path.exists(fpath):
             os.remove(fpath)
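Because the bare `Exception` raises are narrowed, callers of the download helpers can now distinguish HTTP status failures (`RuntimeError`) from other request failures (`ValueError`). A hedged sketch; the helper is private and its positional argument order is assumed:

    try:
        _get_file(root, fname, origin, timeout=30)
    except RuntimeError:
        ...  # mirror returned an HTTP error status
    except ValueError:
        ...  # connection, timeout, or other request failure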
@@ -89,7 +92,7 @@ def _get_file(
     return fpath
 
 
-def check_exists(
+def _check_exists(
     folder: str | Path,
     url: str,
     root: str | Path,
@@ -103,7 +106,7 @@ def check_exists(
     location = str(folder)
     if not os.path.exists(folder):
         if download:
-            location = download_dataset(url, root, fname, file_hash, verbose, md5)
+            location = _download_dataset(url, root, fname, file_hash, verbose, md5)
         else:
             raise RuntimeError("Dataset not found. You can use download=True to download it")
     else:
@@ -112,7 +115,7 @@ def check_exists(
     return location
 
 
-def download_dataset(
+def _download_dataset(
     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
 ) -> str:
     """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
@@ -131,11 +134,11 @@ def download_dataset(
         md5=md5,
     )
     if md5:
-        folder = extract_archive(fpath, root, remove_finished=True)
+        folder = _extract_archive(fpath, root, remove_finished=True)
     return folder
 
 
-def extract_archive(
+def _extract_archive(
     from_path: str | Path,
     to_path: str | Path | None = None,
     remove_finished: bool = False,
@@ -163,13 +166,13 @@ def extract_archive(
     return str(to_path)
 
 
-def subselect(arr: NDArray, count: int, from_back: bool = False):
+def _subselect(arr: NDArray, count: int, from_back: bool = False):
     if from_back:
         return arr[-count:]
     return arr[:count]
 
 
-class MNIST(Dataset):
+class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
     """MNIST Dataset and Corruptions.
 
     Args:
@@ -211,17 +214,17 @@ class MNIST(Dataset):
         If True, outputs print statements.
     """
 
-    mirror = [
+    _mirrors: tuple[str, ...] = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
         "https://zenodo.org/record/3239543/files/",
-    ]
+    )
 
-    resources = [
+    _resources: tuple[tuple[str, str], ...] = (
         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
-    ]
+    )
 
-    class_dict = {
+    class_dict: dict[str, int] = {
         "zero": 0,
         "one": 1,
         "two": 2,
@@ -267,43 +270,46 @@ class MNIST(Dataset):
         self.randomize = randomize
         self.from_back = slice_back
         self.verbose = verbose
+        self.data: NDArray[np.float64]
+        self.targets: NDArray[np.int_]
+        self.size: int
 
-        self.class_set = []
+        self._class_set = []
         if classes is not None:
             if not isinstance(classes, list):
                 classes = [classes]  # type: ignore
 
            for val in classes:  # type: ignore
                if isinstance(val, int) and 0 <= val < 10:
-                    self.class_set.append(val)
+                    self._class_set.append(val)
                elif isinstance(val, str):
-                    self.class_set.append(self.class_dict[val])
-            self.class_set = set(self.class_set)
+                    self._class_set.append(self.class_dict[val])
+            self._class_set = set(self._class_set)
 
-        if not self.class_set:
-            self.class_set = set(self.class_dict.values())
+        if not self._class_set:
+            self._class_set = set(self.class_dict.values())
 
-        self.num_classes = len(self.class_set)
+        self._num_classes = len(self._class_set)
 
         if self.corruption is None:
-            file_resource = self.resources[0]
-            mirror = self.mirror[0]
+            file_resource = self._resources[0]
+            mirror = self._mirrors[0]
             md5 = False
         else:
             if self.corruption == "identity" and verbose:
                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
-            file_resource = self.resources[1]
-            mirror = self.mirror[1]
+            file_resource = self._resources[1]
+            mirror = self._mirrors[1]
             md5 = True
-        check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
+        _check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
 
         self.data, self.targets = self._load_data()
 
         self._augmentations()
 
-    def _load_data(self):
+    def _load_data(self) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
         if self.corruption is None:
-            image_file = self.resources[0][0]
+            image_file = self._resources[0][0]
             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
         else:
             image_file = f"{'train' if self.train else 'test'}_images.npy"
@@ -329,27 +335,27 @@ class MNIST(Dataset):
             self.data = self.data[shuffled_indices]
             self.targets = self.targets[shuffled_indices]
 
-        if not self.balance and self.num_classes > self.size:
+        if not self.balance and self._num_classes > self.size:
             if self.size > 0:
-                self.data = subselect(self.data, self.size, self.from_back)
-                self.targets = subselect(self.targets, self.size, self.from_back)
+                self.data = _subselect(self.data, self.size, self.from_back)
+                self.targets = _subselect(self.targets, self.size, self.from_back)
         else:
-            label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+            label_dict = {label: np.where(self.targets == label)[0] for label in self._class_set}
             min_label_count = min(len(indices) for indices in label_dict.values())
 
-            self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+            self._per_class_count = int(np.ceil(self.size / self._num_classes)) if self.size > 0 else min_label_count
 
-            if self.per_class_count > min_label_count:
-                self.per_class_count = min_label_count
+            if self._per_class_count > min_label_count:
+                self._per_class_count = min_label_count
                 if not self.balance and self.verbose:
                     warn(
-                        f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                        f"Because of dataset limitations, only {min_label_count*self._num_classes} samples "
                         f"will be returned, instead of the desired {self.size}."
                     )
 
-            all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=int)
-            for i, label in enumerate(self.class_set):
-                all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+            all_indices: NDArray[np.int_] = np.empty(shape=(self._num_classes, self._per_class_count), dtype=np.int_)
+            for i, label in enumerate(self._class_set):
+                all_indices[i] = _subselect(label_dict[label], self._per_class_count, self.from_back)
             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
 
@@ -370,7 +376,7 @@ class MNIST(Dataset):
         if self.flatten and self.channels is None:
             self.data = self.data.reshape(self.data.shape[0], -1)
 
-    def __getitem__(self, index: int) -> tuple[NDArray, int]:
+    def __getitem__(self, index: int) -> tuple[NDArray[np.float64], int]:
         """
         Args:
             index (int): Index
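The new annotations pin down what iteration yields: each item is a `(NDArray[np.float64], int)` pair. A minimal usage sketch; constructor arguments beyond those visible in this diff are assumptions from the class docstring:

    from dataeval.utils.torch.datasets import MNIST

    ds = MNIST(root="./data", train=True, download=True)
    image, label = ds[0]  # NDArray[np.float64], int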