dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
- dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
- dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
- dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
- dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
- dataeval/{_internal/interop.py → interop.py} +12 -7
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
- dataeval/metrics/bias/metadata.py +275 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +7 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
- dataeval-0.72.2.dist-info/RECORD +72 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py
CHANGED
@@ -34,13 +34,9 @@ from tensorflow_probability.python.internal import (
     tensorshape_util,
 )
 
-__all__ = [
-    "Shift",
-]
-
 
 class WeightNorm(keras.layers.Wrapper):
-    def __init__(self, layer, data_init: bool = True, **kwargs):
+    def __init__(self, layer, data_init: bool = True, **kwargs) -> None:
         """Layer wrapper to decouple magnitude and direction of the layer's weights.
 
         This wrapper reparameterizes a layer by decoupling the weight's
@@ -187,7 +183,7 @@ class WeightNorm(keras.layers.Wrapper):
 
 
 class Shift(bijector.Bijector):
-    def __init__(self, shift, validate_args=False, name="shift"):
+    def __init__(self, shift, validate_args=False, name="shift") -> None:
         """Instantiates the `Shift` bijector which computes `Y = g(X; shift) = X + shift`
         where `shift` is a numeric `Tensor`.
 
@@ -276,13 +272,13 @@ class PixelCNN(distribution.Distribution):
 
     def __init__(
         self,
-        image_shape: tuple,
-        conditional_shape: tuple | None = None,
+        image_shape: tuple[int, int, int],
+        conditional_shape: tuple[int, ...] | None = None,
         num_resnet: int = 5,
         num_hierarchies: int = 3,
         num_filters: int = 160,
         num_logistic_mix: int = 10,
-        receptive_field_dims: tuple = (3, 3),
+        receptive_field_dims: tuple[int, int] = (3, 3),
         dropout_p: float = 0.5,
         resnet_activation: str = "concat_elu",
         l2_weight: float = 0.0,
@@ -290,7 +286,7 @@ class PixelCNN(distribution.Distribution):
         use_data_init: bool = True,
         high: int = 255,
         low: int = 0,
-        dtype=tf.float32,
+        dtype: tf.DType = tf.float32,
     ) -> None:
         parameters = dict(locals())
         with tf.name_scope("PixelCNN") as name:
@@ -315,7 +311,7 @@ class PixelCNN(distribution.Distribution):
         self._high = tf.cast(high, self.dtype)
         self._low = tf.cast(low, self.dtype)
         self._num_logistic_mix = num_logistic_mix
-        self.network = _PixelCNNNetwork(
+        self._network = PixelCNNNetwork(
             dropout_p=dropout_p,
             num_resnet=num_resnet,
             num_hierarchies=num_hierarchies,
@@ -338,7 +334,7 @@ class PixelCNN(distribution.Distribution):
 
         self.image_shape = image_shape
         self.conditional_shape = conditional_shape
-        self.network.build(input_shape)
+        self._network.build(input_shape)
 
     def _make_mixture_dist(self, component_logits, locs, scales, return_per_feature: bool = False):
         """Builds a mixture of quantized logistic distributions.
@@ -455,7 +451,7 @@ class PixelCNN(distribution.Distribution):
         transformed_value = (2.0 * (value - self._low) / (self._high - self._low)) - 1.0
         inputs = transformed_value if conditional_input is None else [transformed_value, conditional_input]
 
-        params = self.network(inputs, training=training)
+        params = self._network(inputs, training=training)
 
         num_channels = self.event_shape[-1]
         if num_channels == 1:
@@ -554,7 +550,7 @@ class PixelCNN(distribution.Distribution):
             seed=seed,
         )
         inputs = samples_0 if conditional_input is None else [samples_0, h]
-        params_0 = self.network(inputs, training=training)
+        params_0 = self._network(inputs, training=training)
         samples_0 = self._sample_channels(*params_0, seed=seed)
 
         image_height, image_width, _ = tensorshape_util.as_list(self.event_shape)
@@ -579,7 +575,7 @@ class PixelCNN(distribution.Distribution):
             width, num_channels]`.
             """
             inputs = samples if conditional_input is None else [samples, h]
-            params = self.network(inputs, training=training)
+            params = self._network(inputs, training=training)
             samples_new = self._sample_channels(*params, seed=seed)
 
             # Update the current pixel
@@ -673,7 +669,7 @@ class PixelCNN(distribution.Distribution):
         return tf.TensorShape(self.image_shape)
 
 
-class _PixelCNNNetwork(keras.layers.Layer):
+class PixelCNNNetwork(keras.layers.Layer):
     """Keras `Layer` to parameterize a Pixel CNN++ distribution.
     This is a Keras implementation of the Pixel CNN++ network, as described in
     Salimans et al. (2017)[1] and van den Oord et al. (2016)[2].
@@ -699,12 +695,12 @@ class _PixelCNNNetwork(keras.layers.Layer):
         num_hierarchies: int = 3,
         num_filters: int = 160,
         num_logistic_mix: int = 10,
-        receptive_field_dims: tuple = (3, 3),
+        receptive_field_dims: tuple[int, int] = (3, 3),
         resnet_activation: str = "concat_elu",
         l2_weight: float = 0.0,
         use_weight_norm: bool = True,
         use_data_init: bool = True,
-        dtype=tf.float32,
+        dtype: tf.DType = tf.float32,
     ) -> None:
         """Initialize the :term:`neural network<Neural Network>` for the Pixel CNN++ distribution.
 
@@ -765,7 +761,7 @@ class _PixelCNNNetwork(keras.layers.Layer):
         else:
             self._layer_wrapper = lambda layer: layer
 
-    def build(self, input_shape):
+    def build(self, input_shape: tuple[int, ...]) -> None:
         dtype = self.dtype
         if len(input_shape) == 2:
             batch_image_shape, batch_conditional_shape = input_shape
|
|
1040
1036
|
self._network = keras.Model(inputs=inputs, outputs=outputs)
|
1041
1037
|
super().build(input_shape)
|
1042
1038
|
|
1043
|
-
def call(self, inputs, training=None):
|
1039
|
+
def call(self, inputs: tf.Tensor, training: bool | None = None, mask: tf.Tensor | None = None) -> tf.Tensor:
|
1044
1040
|
"""Call the Pixel CNN network model.
|
1045
1041
|
|
1046
1042
|
Parameters
|
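The changes above tighten the constructor's shape annotations without changing runtime behavior. As a hedged sketch of instantiating the distribution under the new annotations (the shape value is illustrative, not taken from this diff):

```python
# Illustrative only: image_shape below is an assumption; the other values are
# the defaults visible in the diff (num_resnet=5, receptive_field_dims=(3, 3)).
import tensorflow as tf
from dataeval.utils.tensorflow._internal.pixelcnn import PixelCNN

dist = PixelCNN(
    image_shape=(28, 28, 1),        # now tuple[int, int, int] instead of bare tuple
    conditional_shape=None,         # now tuple[int, ...] | None
    receptive_field_dims=(3, 3),    # now tuple[int, int]
    dtype=tf.float32,               # now annotated as tf.DType
)
```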
dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py
CHANGED
@@ -60,7 +60,9 @@ def trainer(
     loss_fn = loss_fn() if isinstance(loss_fn, type) else loss_fn
     optimizer = optimizer() if isinstance(optimizer, type) else optimizer
 
-    train_data = x_train if y_train is None else (x_train, y_train)
+    train_data = (
+        x_train.astype(np.float32) if y_train is None else (x_train.astype(np.float32), y_train.astype(np.float32))
+    )
     dataset = tf.data.Dataset.from_tensor_slices(train_data)
     dataset = dataset.shuffle(buffer_size=buffer_size).batch(batch_size)
     n_minibatch = len(dataset)
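The added lines force the training arrays to float32 before they reach the tf.data pipeline. A numpy-only sketch of the same coercion (input values are illustrative) that runs standalone:

```python
# Mirrors the coercion added above, outside of TensorFlow.
import numpy as np

x_train = np.zeros((4, 8, 8, 1), dtype=np.uint8)  # e.g. raw image bytes
y_train = None

train_data = (
    x_train.astype(np.float32) if y_train is None else (x_train.astype(np.float32), y_train.astype(np.float32))
)
assert train_data.dtype == np.float32  # integer inputs no longer leak into the dataset
```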
dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py
CHANGED
@@ -9,7 +9,7 @@ Licensed under Apache Software License (Apache 2.0)
 from __future__ import annotations
 
 import math
-from typing import Callable, Union, cast
+from typing import Any, Callable, Literal, Union, cast
 
 import numpy as np
 import tensorflow as tf
@@ -26,8 +26,8 @@ from tf_keras.layers import (
     Reshape,
 )
 
-from dataeval._internal.models.tensorflow.autoencoder import AE, AEGMM, VAE, VAEGMM
-from dataeval._internal.models.tensorflow.pixelcnn import PixelCNN
+from dataeval.utils.tensorflow._internal.autoencoder import AE, AEGMM, VAE, VAEGMM
+from dataeval.utils.tensorflow._internal.pixelcnn import PixelCNN
 
 
 def predict_batch(
@@ -95,7 +95,7 @@ def predict_batch(
     return out
 
 
-def _get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
+def get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
     return Sequential(
         [
             InputLayer(input_shape=input_shape),
@@ -108,7 +108,7 @@ def _get_default_encoder_net(input_shape: tuple[int, int, int], encoding_dim: in
     )
 
 
-def _get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
+def get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: int):
     return Sequential(
         [
             InputLayer(input_shape=(encoding_dim,)),
@@ -124,18 +124,18 @@ def _get_default_decoder_net(input_shape: tuple[int, int, int], encoding_dim: in
 
 
 def create_model(
-    model_type: AE | AEGMM | PixelCNN | VAE | VAEGMM,
+    model_type: Literal["AE", "AEGMM", "PixelCNN", "VAE", "VAEGMM"],
     input_shape: tuple[int, int, int],
     encoding_dim: int | None = None,
     n_gmm: int | None = None,
     gmm_latent_dim: int | None = None,
-):
+) -> Any:
     """
     Create a default model for the specified model type.
 
     Parameters
     ----------
-    model_type : AE | AEGMM | PixelCNN | VAE | VAEGMM
+    model_type : Literal["AE", "AEGMM", "PixelCNN", "VAE", "VAEGMM"]
         The model type to create.
     input_shape : Tuple[int, int, int]
         The input shape of the data used.
@@ -148,20 +148,20 @@ def create_model(
     """
     input_dim = math.prod(input_shape)
     encoding_dim = int(math.pow(2, int(input_dim.bit_length() * 0.8)) if encoding_dim is None else encoding_dim)
-    if model_type == AE:
+    if model_type == "AE":
         return AE(
-            _get_default_encoder_net(input_shape, encoding_dim),
-            _get_default_decoder_net(input_shape, encoding_dim),
+            get_default_encoder_net(input_shape, encoding_dim),
+            get_default_decoder_net(input_shape, encoding_dim),
         )
 
-    if model_type == VAE:
+    if model_type == "VAE":
         return VAE(
-            _get_default_encoder_net(input_shape, encoding_dim),
-            _get_default_decoder_net(input_shape, encoding_dim),
+            get_default_encoder_net(input_shape, encoding_dim),
+            get_default_decoder_net(input_shape, encoding_dim),
             encoding_dim,
         )
 
-    if model_type == AEGMM:
+    if model_type == "AEGMM":
         n_gmm = 2 if n_gmm is None else n_gmm
         gmm_latent_dim = 1 if gmm_latent_dim is None else gmm_latent_dim
         # The outlier detector is an encoder/decoder architecture
@@ -201,7 +201,7 @@ def create_model(
             n_gmm=n_gmm,
         )
 
-    if model_type == VAEGMM:
+    if model_type == "VAEGMM":
         n_gmm = 2 if n_gmm is None else n_gmm
         gmm_latent_dim = 2 if gmm_latent_dim is None else gmm_latent_dim
         # The outlier detector is an encoder/decoder architecture
@@ -242,7 +242,7 @@ def create_model(
             latent_dim=gmm_latent_dim,
         )
 
-    if model_type == PixelCNN:
+    if model_type == "PixelCNN":
         return PixelCNN(
             image_shape=input_shape,
             num_resnet=5,
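`create_model` now dispatches on string literals instead of model classes, so callers no longer need to import the model types. A minimal usage sketch against the new signature (the input shape is illustrative):

```python
from dataeval.utils.tensorflow._internal.utils import create_model

# Was: create_model(AE, input_shape=...); the model classes no longer need importing.
model = create_model("AE", input_shape=(32, 32, 3))
```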
dataeval/utils/tensorflow/loss/__init__.py
CHANGED
@@ -1,7 +1,11 @@
 from dataeval import _IS_TENSORFLOW_AVAILABLE
-from dataeval._internal.models.tensorflow.losses import Elbo, LossGMM
 
 __all__ = []
 
+
 if _IS_TENSORFLOW_AVAILABLE:
-    __all__ = ["Elbo", "LossGMM"]
+    from dataeval.utils.tensorflow._internal.loss import Elbo, LossGMM
+
+    __all__ = ["Elbo", "LossGMM"]
+
+del _IS_TENSORFLOW_AVAILABLE
dataeval/utils/torch/__init__.py
CHANGED
@@ -6,16 +6,20 @@ to create a seamless integration between custom models and DataEval's metrics.
 """
 
 from dataeval import _IS_TORCH_AVAILABLE, _IS_TORCHVISION_AVAILABLE
-from dataeval._internal.utils import read_dataset
 
 __all__ = []
 
 if _IS_TORCH_AVAILABLE:
-    from . import models, trainer
+    from dataeval.utils.torch import models, trainer
+    from dataeval.utils.torch.utils import read_dataset
 
     __all__ += ["read_dataset", "models", "trainer"]
 
 if _IS_TORCHVISION_AVAILABLE:
-    from . import datasets
+    from dataeval.utils.torch import datasets
 
     __all__ += ["datasets"]
+
+
+del _IS_TORCH_AVAILABLE
+del _IS_TORCHVISION_AVAILABLE
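The public import surface is unchanged by this reorganization. A quick sketch, assuming the torch extra is installed:

```python
# read_dataset, models, and trainer are still re-exported from the package root.
from dataeval.utils.torch import models, read_dataset, trainer
```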
dataeval/{_internal/models/pytorch → utils/torch}/blocks.py
CHANGED
@@ -1,3 +1,7 @@
+from typing import Any
+
+__all__ = []
+
 import torch.nn as nn
 
 
@@ -8,21 +12,22 @@ class Conv(nn.Module):
 
     def __init__(
         self,
-        in_channels,
-        out_channels,
-        k=1,
-        s=1,
-        p=0,
-        activation="relu",
-        norm="instance",
-    ):
+        in_channels: int,
+        out_channels: int,
+        k: int = 1,
+        s: int = 1,
+        p: int = 0,
+        activation: str = "relu",
+        norm: str = "instance",
+    ) -> None:
         super().__init__()
-        conv = nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p)
-        norm = self.get_norm_func(norm=norm, out_channels=out_channels)
-        act = self.get_activation_func(activation=activation)
-        self.module = nn.Sequential(conv, norm, act)
+        self.module: nn.Sequential = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=k, stride=s, padding=p),
+            self.get_norm_func(norm=norm, out_channels=out_channels),
+            self.get_activation_func(activation=activation),
+        )
 
-    def get_norm_func(self, norm: str, out_channels) -> nn.Module:
+    def get_norm_func(self, norm: str, out_channels: int) -> nn.Module:
         if norm == "batch":
             return nn.BatchNorm2d(out_channels)
         if norm == "instance":
@@ -42,5 +47,5 @@ class Conv(nn.Module):
             return nn.Tanh()
         return nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: Any) -> Any:
         return self.module(x)
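A hedged sketch of the now fully annotated `Conv` block in use (the channel sizes and input shape are made up for the example):

```python
import torch
from dataeval.utils.torch.blocks import Conv

layer = Conv(in_channels=3, out_channels=16, k=3, s=1, p=1, activation="relu", norm="instance")
out = layer(torch.randn(1, 3, 32, 32))  # Conv2d -> norm -> activation
assert out.shape == (1, 16, 32, 32)     # 3x3 kernel with padding 1 keeps spatial dims
```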
dataeval/{_internal → utils/torch}/datasets.py
CHANGED
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+__all__ = ["MNIST", "CIFAR10", "VOCDetection"]
+
 import hashlib
 import os
 import zipfile
@@ -11,7 +13,7 @@ import numpy as np
 import requests
 from numpy.typing import NDArray
 from torch.utils.data import Dataset
-from torchvision.datasets import CIFAR10, VOCDetection
+from torchvision.datasets import CIFAR10, VOCDetection
 
 ClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TClassMap = TypeVar("TClassMap", ClassStringMap, int, list[ClassStringMap], list[int])
@@ -50,6 +52,7 @@ def _get_file(
     file_hash: str | None = None,
     verbose: bool = True,
     md5: bool = False,
+    timeout: int = 60,
 ):
     fpath = os.path.join(root, fname)
     download = True
@@ -64,16 +67,16 @@ def _get_file(
     try:
         error_msg = "URL fetch failure on {}: {} -- {}"
         try:
-            with requests.get(origin, stream=True, timeout=60) as r:
+            with requests.get(origin, stream=True, timeout=timeout) as r:
                 r.raise_for_status()
                 with open(fpath, "wb") as f:
                     for chunk in r.iter_content(chunk_size=8192):
                         if chunk:
                             f.write(chunk)
         except requests.exceptions.HTTPError as e:
-            raise Exception(error_msg.format(origin, e.response.status_code, e.response.reason))
+            raise RuntimeError(f"{error_msg.format(origin, e.response.status_code, e.response.reason)}") from e
         except requests.exceptions.RequestException as e:
-            raise Exception(error_msg.format(origin, "Unknown error", str(e)))
+            raise ValueError(f"{error_msg.format(origin, 'Unknown error', str(e))}") from e
     except (Exception, KeyboardInterrupt):
         if os.path.exists(fpath):
             os.remove(fpath)
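For reference, the download-and-clean-up pattern used above, as a standalone sketch (the names here are local to the sketch, not dataeval's API):

```python
import os
import requests

def fetch(origin: str, fpath: str, timeout: int = 60) -> None:
    """Stream a file to disk, removing partial output on any failure."""
    try:
        with requests.get(origin, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(fpath, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
    except (Exception, KeyboardInterrupt):
        if os.path.exists(fpath):
            os.remove(fpath)  # never leave a truncated download behind
        raise
```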
@@ -89,7 +92,7 @@ def _get_file(
     return fpath
 
 
-def check_exists(
+def _check_exists(
     folder: str | Path,
     url: str,
     root: str | Path,
@@ -103,7 +106,7 @@ def check_exists(
     location = str(folder)
     if not os.path.exists(folder):
         if download:
-            location = download_dataset(url, root, fname, file_hash, verbose, md5)
+            location = _download_dataset(url, root, fname, file_hash, verbose, md5)
         else:
             raise RuntimeError("Dataset not found. You can use download=True to download it")
     else:
@@ -112,7 +115,7 @@ def check_exists(
     return location
 
 
-def download_dataset(
+def _download_dataset(
     url: str, root: str | Path, fname: str, file_hash: str, verbose: bool = True, md5: bool = False
 ) -> str:
     """Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
@@ -131,11 +134,11 @@ def download_dataset(
         md5=md5,
     )
     if md5:
-        folder = extract_archive(fpath, root, remove_finished=True)
+        folder = _extract_archive(fpath, root, remove_finished=True)
     return folder
 
 
-def extract_archive(
+def _extract_archive(
     from_path: str | Path,
     to_path: str | Path | None = None,
     remove_finished: bool = False,
@@ -163,13 +166,13 @@ def extract_archive(
     return str(to_path)
 
 
-def subselect(arr: NDArray, count: int, from_back: bool = False):
+def _subselect(arr: NDArray, count: int, from_back: bool = False):
     if from_back:
         return arr[-count:]
     return arr[:count]
 
 
-class MNIST(Dataset):
+class MNIST(Dataset[tuple[NDArray[np.float64], int]]):
     """MNIST Dataset and Corruptions.
 
     Args:
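The renamed `_subselect` keeps its behavior: take `count` items from the front, or from the back when `from_back=True`. A self-contained check (the function body is copied from the diff):

```python
import numpy as np
from numpy.typing import NDArray

def _subselect(arr: NDArray, count: int, from_back: bool = False):
    if from_back:
        return arr[-count:]
    return arr[:count]

arr = np.arange(10)
assert _subselect(arr, 3).tolist() == [0, 1, 2]                  # front slice
assert _subselect(arr, 3, from_back=True).tolist() == [7, 8, 9]  # back slice
```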
@@ -211,17 +214,17 @@ class MNIST(Dataset):
         If True, outputs print statements.
     """
 
-    mirrors = [
+    _mirrors: tuple[str, ...] = (
         "https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
         "https://zenodo.org/record/3239543/files/",
-    ]
+    )
 
-    resources = [
+    _resources: tuple[tuple[str, str], ...] = (
         ("mnist.npz", "731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1"),
         ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3"),
-    ]
+    )
 
-    class_dict = {
+    class_dict: dict[str, int] = {
         "zero": 0,
         "one": 1,
         "two": 2,
@@ -267,43 +270,46 @@ class MNIST(Dataset):
         self.randomize = randomize
         self.from_back = slice_back
         self.verbose = verbose
+        self.data: NDArray[np.float64]
+        self.targets: NDArray[np.int_]
+        self.size: int
 
-        self.class_set = []
+        self._class_set = []
         if classes is not None:
             if not isinstance(classes, list):
                 classes = [classes]  # type: ignore
 
             for val in classes:  # type: ignore
                 if isinstance(val, int) and 0 <= val < 10:
-                    self.class_set.append(val)
+                    self._class_set.append(val)
                 elif isinstance(val, str):
-                    self.class_set.append(self.class_dict[val])
-            self.class_set = set(self.class_set)
+                    self._class_set.append(self.class_dict[val])
+            self._class_set = set(self._class_set)
 
-        if not self.class_set:
-            self.class_set = set(self.class_dict.values())
+        if not self._class_set:
+            self._class_set = set(self.class_dict.values())
 
-        self.num_classes = len(self.class_set)
+        self._num_classes = len(self._class_set)
 
         if self.corruption is None:
-            file_resource = self.resources[0]
-            mirror = self.mirrors[0]
+            file_resource = self._resources[0]
+            mirror = self._mirrors[0]
             md5 = False
         else:
             if self.corruption == "identity" and verbose:
                 print("Identity is not a corrupted dataset but the original MNIST dataset.")
-            file_resource = self.resources[1]
-            mirror = self.mirrors[1]
+            file_resource = self._resources[1]
+            mirror = self._mirrors[1]
             md5 = True
-        check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
+        _check_exists(self.mnist_folder, mirror, self.root, file_resource[0], file_resource[1], download, verbose, md5)
 
         self.data, self.targets = self._load_data()
 
         self._augmentations()
 
-    def _load_data(self):
+    def _load_data(self) -> tuple[NDArray[np.float64], NDArray[np.int64]]:
         if self.corruption is None:
-            image_file = self.resources[0][0]
+            image_file = self._resources[0][0]
             data, targets = self._read_normal_file(os.path.join(self.mnist_folder, image_file))
         else:
             image_file = f"{'train' if self.train else 'test'}_images.npy"
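The constructor logic above normalizes `classes` (ints or digit names) into a set of label ids. Sketched standalone with a trimmed `class_dict`, assuming the behavior shown in the diff:

```python
# Standalone rendering of the normalization in MNIST.__init__ above; the
# trimmed class_dict and the helper name are local to this sketch.
class_dict = {"zero": 0, "one": 1, "two": 2}

def normalize(classes) -> set[int]:
    class_set = []
    if classes is not None:
        if not isinstance(classes, list):
            classes = [classes]
        for val in classes:
            if isinstance(val, int) and 0 <= val < 10:
                class_set.append(val)
            elif isinstance(val, str):
                class_set.append(class_dict[val])
    # An empty selection falls back to all known classes
    return set(class_set) if class_set else set(class_dict.values())

assert normalize("one") == {1}
assert normalize([0, "two"]) == {0, 2}
assert normalize(None) == {0, 1, 2}
```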
@@ -329,27 +335,27 @@ class MNIST(Dataset):
             self.data = self.data[shuffled_indices]
             self.targets = self.targets[shuffled_indices]
 
-        if not self.balance and self.num_classes > self.size:
+        if not self.balance and self._num_classes > self.size:
             if self.size > 0:
-                self.data = subselect(self.data, self.size, self.from_back)
-                self.targets = subselect(self.targets, self.size, self.from_back)
+                self.data = _subselect(self.data, self.size, self.from_back)
+                self.targets = _subselect(self.targets, self.size, self.from_back)
         else:
-            label_dict = {label: np.where(self.targets == label)[0] for label in self.class_set}
+            label_dict = {label: np.where(self.targets == label)[0] for label in self._class_set}
             min_label_count = min(len(indices) for indices in label_dict.values())
 
-            self.per_class_count = int(np.ceil(self.size / self.num_classes)) if self.size > 0 else min_label_count
+            self._per_class_count = int(np.ceil(self.size / self._num_classes)) if self.size > 0 else min_label_count
 
-            if self.per_class_count > min_label_count:
-                self.per_class_count = min_label_count
+            if self._per_class_count > min_label_count:
+                self._per_class_count = min_label_count
                 if not self.balance and self.verbose:
                     warn(
-                        f"Because of dataset limitations, only {min_label_count*self.num_classes} samples "
+                        f"Because of dataset limitations, only {min_label_count*self._num_classes} samples "
                         f"will be returned, instead of the desired {self.size}."
                     )
 
-            all_indices = np.empty(shape=(self.num_classes, self.per_class_count), dtype=np.int_)
-            for i, label in enumerate(self.class_set):
-                all_indices[i] = subselect(label_dict[label], self.per_class_count, self.from_back)
+            all_indices: NDArray[np.int_] = np.empty(shape=(self._num_classes, self._per_class_count), dtype=np.int_)
+            for i, label in enumerate(self._class_set):
+                all_indices[i] = _subselect(label_dict[label], self._per_class_count, self.from_back)
             self.data = np.vstack(self.data[all_indices.T])  # type: ignore
             self.targets = np.hstack(self.targets[all_indices.T])  # type: ignore
 
@@ -370,7 +376,7 @@ class MNIST(Dataset):
         if self.flatten and self.channels is None:
             self.data = self.data.reshape(self.data.shape[0], -1)
 
-    def __getitem__(self, index: int) -> tuple[NDArray, int]:
+    def __getitem__(self, index: int) -> tuple[NDArray[np.float64], int]:
         """
         Args:
             index (int): Index
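Finally, a hedged sketch of the relocated dataset in use; only parameters visible in this diff are assumed, and the values are illustrative:

```python
from dataeval.utils.torch.datasets import MNIST

ds = MNIST(root="./data", train=True, download=True)  # fetches mnist.npz on first use
img, label = ds[0]                                    # tuple[NDArray[np.float64], int]
```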