dataeval 0.69.3__py3-none-any.whl → 0.69.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_internal/datasets.py +300 -0
- {dataeval-0.69.3.dist-info → dataeval-0.69.4.dist-info}/METADATA +2 -1
- {dataeval-0.69.3.dist-info → dataeval-0.69.4.dist-info}/RECORD +6 -5
- {dataeval-0.69.3.dist-info → dataeval-0.69.4.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.69.3.dist-info → dataeval-0.69.4.dist-info}/WHEEL +0 -0
dataeval/_internal/datasets.py
ADDED
@@ -0,0 +1,300 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import hashlib
|
4
|
+
import os
|
5
|
+
import zipfile
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Literal
|
8
|
+
from urllib.error import HTTPError, URLError
|
9
|
+
from urllib.request import urlretrieve
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
from numpy.typing import NDArray
|
13
|
+
from torch.utils.data import Dataset
|
14
|
+
from torchvision.datasets import CIFAR10, VOCDetection # noqa: F401
|
15
|
+
|
16
|
+
|
17
|
+
def _validate_file(fpath, file_md5, chunk_size=65535):
    """Return True when the MD5 hex digest of ``fpath`` equals ``file_md5``.

    The file is streamed in ``chunk_size``-byte pieces, so large archives are
    hashed without reading the whole file into memory at once.
    """
    digest = hashlib.md5()
    with open(fpath, "rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if not piece:
                # An empty read means end of file.
                break
            digest.update(piece)
    return file_md5 == digest.hexdigest()
|
23
|
+
|
24
|
+
|
25
|
+
def _get_file(
    root: str | Path,
    fname: str,
    origin: str,
    file_md5: str | None = None,
):
    """Ensure ``root/fname`` exists locally, downloading it from ``origin`` if needed.

    Args:
        root: Directory the file is stored in (expected to already exist).
        fname: File name joined onto ``root``; an ``os.PathLike`` is also accepted.
        origin: Full URL handed to ``urllib.request.urlretrieve``.
        file_md5: Optional expected MD5 hex digest. When given, an existing file
            whose hash does not match is downloaded again, and the freshly
            downloaded file is verified as well.

    Returns:
        The local path ``os.path.join(root, fname)``.

    Raises:
        Exception: If the URL cannot be fetched (HTTP or URL error). Any
            partially written file is removed before the error propagates.
        ValueError: If the downloaded file does not match ``file_md5``. The
            corrupted file is removed so it is not mistaken for a valid copy.
    """
    fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
    fpath = os.path.join(root, fname)

    # Reuse an existing file when no hash was requested or the hash matches.
    # Returning here also avoids hashing the same file a second time.
    if os.path.exists(fpath) and (file_md5 is None or _validate_file(fpath, file_md5)):
        print("Files already downloaded and verified")
        return fpath

    error_msg = "URL fetch failure on {}: {} -- {}"
    try:
        try:
            urlretrieve(origin, fpath)
        except HTTPError as e:
            raise Exception(error_msg.format(origin, e.code, e.msg)) from e
        except URLError as e:
            raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
    except (Exception, KeyboardInterrupt):
        # Never leave a partially written (or stale, corrupted) file behind.
        if os.path.exists(fpath):
            os.remove(fpath)
        raise

    if os.path.exists(fpath) and file_md5 is not None and not _validate_file(fpath, file_md5):
        # Drop the bad download so the next call does not find a corrupted file.
        os.remove(fpath)
        raise ValueError(
            "Incomplete or corrupted file detected. "
            "The md5 file hash does not match the provided value "
            f"of {file_md5}.",
        )
    return fpath
|
64
|
+
|
65
|
+
|
66
|
+
def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
    """Download ``url + fname`` into ``root``, verify its MD5, and unpack it in place.

    Adapted from the MNIST-corrupted loader in tensorflow_datasets (tfds):
    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py

    Args:
        url: Base URL; ``fname`` is appended to it verbatim.
        root: Directory that receives the archive.
        fname: Archive file name (e.g. ``"mnist_c.zip"``).
        md5: Expected MD5 hex digest of the archive.

    Returns:
        Local path of the archive. The archive is deleted after extraction,
        so this path no longer exists when the function returns.
    """
    # Creating <root>/<stem> also creates ``root`` itself when missing,
    # so the download target directory is guaranteed to exist.
    stem = os.path.splitext(fname)[0]
    os.makedirs(os.path.join(root, stem), exist_ok=True)

    archive_path = _get_file(root, fname, origin=f"{url}{fname}", file_md5=md5)
    extract_archive(archive_path, remove_finished=True)
    return archive_path
|
82
|
+
|
83
|
+
|
84
|
+
def extract_archive(
    from_path: str | Path,
    to_path: str | Path | None = None,
    remove_finished: bool = False,
):
    """Extract a ZIP archive.

    Only ZIP archives are supported: the format is not inferred from the file
    name, and a file that is not a ZIP archive makes ``zipfile.ZipFile`` raise
    ``zipfile.BadZipFile``.

    Args:
        from_path: Path to the archive. Relative paths are resolved against
            the current working directory.
        to_path: Destination directory. Defaults to the directory containing
            ``from_path``.
        remove_finished: If True, delete the archive after extraction succeeds.
    """
    from_path = Path(from_path)
    if not from_path.is_absolute():
        from_path = from_path.resolve()

    if to_path is None:
        to_path = from_path.parent

    # Each member records its own compression method, so no ``compression``
    # argument is needed here (that argument only affects writing).
    with zipfile.ZipFile(from_path, "r") as archive:
        archive.extractall(to_path)

    if remove_finished:
        os.remove(from_path)
|
106
|
+
|
107
|
+
|
108
|
+
class MNIST(Dataset):
    """MNIST Dataset and Corruptions.

    Loads the MNIST-C ``.npy`` files stored under
    ``<root>/mnist_c/<corruption>/`` and optionally subsamples, rescales,
    normalizes and reshapes them.

    Args:
        root : str | ``pathlib.Path``
            Root directory of dataset where the ``mnist_c/`` folder exists. A
            leading ``~`` is expanded for both ``str`` and ``Path`` values.
        train : bool, default True
            If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``,
            otherwise from ``test_images.npy`` and ``test_labels.npy``.
        download : bool, default False
            If True, downloads the dataset from the internet and puts it in root
            directory. If dataset is already downloaded, it is not downloaded again.
        size : int, default -1
            Limit the dataset size. Values below 1 keep the whole split. Values
            from 1 up to one less than the number of classes keep the first
            ``size`` samples in file order. Values of at least the number of
            classes draw a class-balanced subset of ``size`` samples (a class
            with too few samples contributes all it has), shuffled with numpy's
            global random state.
        unit_interval : bool, default False
            Shift the data values to the unit interval [0-1].
        dtype : type | None, default None
            Change the numpy dtype - data is loaded as np.uint8
        channels : Literal['channels_first' | 'channels_last'] | None, default None
            Location of channel axis if desired, default has no channels (N, 28, 28)
        flatten : bool, default False
            Flatten data into single dimension (N, 784) - cannot use both channels and flatten,
            channels takes priority over flatten.
        normalize : tuple[mean, std] | None, default None
            Normalize images according to provided mean and standard deviation
        corruption : Literal['identity' | 'shot_noise' | 'impulse_noise' | 'glass_blur' |
            'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
            'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
            The desired corruption style or None.
    """

    # Base URL of the record hosting the MNIST-C archive.
    mirror = "https://zenodo.org/record/3239543/files/"

    # (archive file name, expected MD5 hex digest of that archive)
    resources = ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3")

    # Human-readable class names; the list position is the integer label.
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    # Aliases for ``targets`` / ``data``: both the train_* and test_* names
    # return whichever split was loaded.
    @property
    def train_labels(self):
        return self.targets

    @property
    def test_labels(self):
        return self.targets

    @property
    def train_data(self):
        return self.data

    @property
    def test_data(self):
        return self.data

    def __init__(
        self,
        root: str | Path,
        train: bool = True,
        download: bool = False,
        size: int = -1,
        unit_interval: bool = False,
        dtype: type | None = None,
        channels: Literal["channels_first", "channels_last"] | None = None,
        flatten: bool = False,
        normalize: tuple[float, float] | None = None,
        corruption: Literal[
            "identity",
            "shot_noise",
            "impulse_noise",
            "glass_blur",
            "motion_blur",
            "shear",
            "scale",
            "rotate",
            "brightness",
            "translate",
            "stripe",
            "fog",
            "spatter",
            "dotted_line",
            "zigzag",
            "canny_edges",
        ]
        | None = None,
    ) -> None:
        # Expand a leading "~" whether a str or a pathlib.Path was given,
        # keeping the caller's type.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        elif isinstance(root, Path):
            root = root.expanduser()
        self.root = root  # location of stored dataset
        self.train = train  # training set or test set
        self.size = size
        self.unit_interval = unit_interval
        self.dtype = dtype
        self.channels = channels
        self.flatten = flatten
        self.normalize = normalize

        if corruption is None:
            corruption = "identity"
        elif corruption == "identity":
            print("Identity is not a corrupted dataset but the original MNIST dataset")
        self.corruption = corruption

        # Only the presence of the corruption folder is checked here; the
        # file contents themselves are not re-hashed.
        if os.path.exists(self.mnist_folder):
            print("Files already downloaded and verified")
        elif download:
            download_dataset(self.mirror, self.root, self.resources[0], self.resources[1])
        else:
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _load_data(self):
        """Load the selected split and apply the size/scale/dtype/shape options, in that order."""
        split = "train" if self.train else "test"
        data = self._read_image_file(os.path.join(self.mnist_folder, f"{split}_images.npy"))
        targets = self._read_label_file(os.path.join(self.mnist_folder, f"{split}_labels.npy"))

        if self.size >= 1 and self.size >= len(self.classes):
            data, targets = self._balanced_subset(data, targets)
        elif self.size >= 1:
            data = data[: self.size]
            targets = targets[: self.size]

        if self.unit_interval:
            data = data / 255

        if self.normalize:
            data = (data - self.normalize[0]) / self.normalize[1]

        if self.dtype:
            data = data.astype(self.dtype)

        # The code assumes images are stored with a trailing channel axis
        # (indexing ``[:, :, :, 0]`` requires rank 4).
        if self.channels == "channels_first":
            data = np.moveaxis(data, -1, 1)
        elif self.channels is None:
            data = data[:, :, :, 0]

        if self.flatten and self.channels is None:
            data = data.reshape(data.shape[0], -1)

        return data, targets

    def _balanced_subset(self, data: NDArray, targets: NDArray) -> tuple[NDArray, NDArray]:
        """Draw ``self.size`` samples spread as evenly as possible over the classes.

        Each class gets ``size // n_classes`` samples, and the first
        ``size % n_classes`` classes get one extra, so the quotas sum to
        ``size``. Within a class, samples are taken in file order. The result
        is shuffled with ``np.random.permutation``.
        """
        n_classes = len(self.classes)
        per_class, remainder = divmod(self.size, n_classes)
        picked = [
            np.where(targets == label)[0][: per_class + (1 if label < remainder else 0)]
            for label in range(n_classes)
        ]
        indices = np.concatenate(picked)
        indices = indices[np.random.permutation(indices.shape[0])]
        return data[indices], targets[indices]

    def __getitem__(self, index: int) -> tuple[NDArray, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        return img, target

    def __len__(self) -> int:
        return len(self.data)

    @property
    def mnist_folder(self) -> str:
        """Directory holding the ``.npy`` files for the selected corruption."""
        return os.path.join(self.root, "mnist_c", self.corruption)

    @property
    def class_to_idx(self) -> dict[str, int]:
        """Map each class name to its integer label."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _read_label_file(self, path: str) -> NDArray:
        # allow_pickle=False refuses object arrays, so a tampered file cannot run code.
        x = np.load(path, allow_pickle=False)
        return x

    def _read_image_file(self, path: str) -> NDArray:
        # allow_pickle=False refuses object arrays, so a tampered file cannot run code.
        x = np.load(path, allow_pickle=False)
        return x
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: dataeval
|
3
|
-
Version: 0.69.3
|
3
|
+
Version: 0.69.4
|
4
4
|
Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
|
5
5
|
Home-page: https://dataeval.ai/
|
6
6
|
License: MIT
|
@@ -33,6 +33,7 @@ Requires-Dist: tensorflow (>=2.14.1,<2.16) ; extra == "tensorflow" or extra == "
|
|
33
33
|
Requires-Dist: tensorflow-io-gcs-filesystem (>=0.35.0,<0.37) ; extra == "tensorflow" or extra == "all"
|
34
34
|
Requires-Dist: tensorflow_probability (>=0.22.1,<0.24) ; extra == "tensorflow" or extra == "all"
|
35
35
|
Requires-Dist: torch (>=2.0.1,!=2.2.0) ; extra == "torch" or extra == "all"
|
36
|
+
Requires-Dist: torchvision (>=0.16.0) ; extra == "torch" or extra == "all"
|
36
37
|
Requires-Dist: xxhash (>=3.3)
|
37
38
|
Project-URL: Documentation, https://dataeval.readthedocs.io/
|
38
39
|
Project-URL: Repository, https://github.com/aria-ml/dataeval/
|
@@ -1,4 +1,5 @@
|
|
1
|
-
dataeval/__init__.py,sha256=
|
1
|
+
dataeval/__init__.py,sha256=KOZnb9SovSSuD2UrqV-NS_b5vpfWdQlsweB55fned58,590
|
2
|
+
dataeval/_internal/datasets.py,sha256=MwN6xgZW1cA5yIxXZ05qBBz4aO3bjKzIEbZZfa1HkQo,9790
|
2
3
|
dataeval/_internal/detectors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
4
|
dataeval/_internal/detectors/clusterer.py,sha256=hJwELUeAdZZ3OVLIfwalw2P7Zz13q2ZqrV6gx90s44E,20695
|
4
5
|
dataeval/_internal/detectors/drift/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -67,7 +68,7 @@ dataeval/torch/models/__init__.py,sha256=YnDnePYpRIKHyYn3F5qR1OObMSb-g0FGvI8X-uT
|
|
67
68
|
dataeval/torch/trainer/__init__.py,sha256=Te-qElt8h-Zv8NN0r-VJOEdCPHTQ2yO3rd2MhRiZGZs,93
|
68
69
|
dataeval/utils/__init__.py,sha256=ExQ1xj62MjcM9uIu1-g1P2fW0EPJpcIofnvxjQ908c4,172
|
69
70
|
dataeval/workflows/__init__.py,sha256=gkU2B6yUiefexcYrBwqfZKNl8BvX8abUjfeNvVBXF4E,186
|
70
|
-
dataeval-0.69.
|
71
|
-
dataeval-0.69.
|
72
|
-
dataeval-0.69.
|
73
|
-
dataeval-0.69.
|
71
|
+
dataeval-0.69.4.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
|
72
|
+
dataeval-0.69.4.dist-info/METADATA,sha256=R_YlthIsAkOizGWkgXiOCEsD_6F5wJm8qjU4hjhL_c8,4292
|
73
|
+
dataeval-0.69.4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
74
|
+
dataeval-0.69.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|