dataeval 0.69.3__py3-none-any.whl → 0.69.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.69.3"
1
+ __version__ = "0.69.4"
2
2
 
3
3
  from importlib.util import find_spec
4
4
 
@@ -0,0 +1,300 @@
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import os
5
+ import zipfile
6
+ from pathlib import Path
7
+ from typing import Literal
8
+ from urllib.error import HTTPError, URLError
9
+ from urllib.request import urlretrieve
10
+
11
+ import numpy as np
12
+ from numpy.typing import NDArray
13
+ from torch.utils.data import Dataset
14
+ from torchvision.datasets import CIFAR10, VOCDetection # noqa: F401
15
+
16
+
17
def _validate_file(fpath: str | os.PathLike, file_md5: str, chunk_size: int = 65535) -> bool:
    """Check whether the MD5 digest of a file matches an expected value.

    The file is read in chunks of ``chunk_size`` bytes so large archives are
    never loaded into memory at once.

    Args:
        fpath: Path of the file to hash.
        file_md5: Expected MD5 digest as a hexadecimal string. Surrounding
            whitespace and letter case are ignored.
        chunk_size: Number of bytes read from the file per iteration.

    Returns:
        True when the computed digest equals ``file_md5``, otherwise False.
    """
    hasher = hashlib.md5()
    with open(fpath, "rb") as fpath_file:
        while chunk := fpath_file.read(chunk_size):
            hasher.update(chunk)
    # hexdigest() is always lowercase; normalise the expected value so an
    # uppercase or whitespace-padded checksum is not reported as a mismatch.
    return hasher.hexdigest() == file_md5.strip().lower()
23
+
24
+
25
def _get_file(
    root: str | Path,
    fname: str,
    origin: str,
    file_md5: str | None = None,
):
    """Return the local path of ``fname`` under ``root``, downloading it if needed.

    A file that already exists is reused when no checksum is given or when its
    MD5 digest matches ``file_md5``; otherwise it is fetched from ``origin``.

    Args:
        root: Directory in which the file is stored.
        fname: Name of the file inside ``root``.
        origin: URL the file is retrieved from.
        file_md5: Expected MD5 hex digest, or None to skip verification.

    Returns:
        The path of the local file (``root`` joined with ``fname``).

    Raises:
        Exception: If the download fails with an ``HTTPError`` or ``URLError``.
            Any partially written file is removed before re-raising.
        ValueError: If the downloaded file does not match ``file_md5``.
    """
    fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
    fpath = os.path.join(root, fname)

    # Reuse an existing file when it is acceptable; this also avoids hashing
    # a valid file a second time further down.
    if os.path.exists(fpath) and (file_md5 is None or _validate_file(fpath, file_md5)):
        print("Files already downloaded and verified")
        return fpath

    # urlretrieve does not create missing directories for its destination.
    os.makedirs(os.path.dirname(fpath) or os.curdir, exist_ok=True)

    error_msg = "URL fetch failure on {}: {} -- {}"
    try:
        try:
            urlretrieve(origin, fpath)
        # HTTPError is a subclass of URLError, so it has to be handled first.
        except HTTPError as e:
            raise Exception(error_msg.format(origin, e.code, e.msg)) from e
        except URLError as e:
            raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
    except (Exception, KeyboardInterrupt):
        # Never leave a partial or stale file behind after a failed fetch.
        if os.path.exists(fpath):
            os.remove(fpath)
        raise

    if os.path.exists(fpath) and file_md5 is not None and not _validate_file(fpath, file_md5):
        raise ValueError(
            "Incomplete or corrupted file detected. "
            f"The md5 file hash does not match the provided value "
            f"of {file_md5}.",
        )
    return fpath
64
+
65
+
66
def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
    """Download the archive ``fname`` from ``url`` into ``root`` and unpack it.

    Code to download mnist and corruptions, originates from tensorflow_datasets (tfds):
    https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py

    Args:
        url: Base URL; ``fname`` is appended to it verbatim to form the download URL.
        root: Directory that receives the archive and its extracted content.
        fname: File name of the archive.
        md5: Expected MD5 hex digest of the archive.

    Returns:
        The path the archive was downloaded to. The archive is extracted with
        ``remove_finished=True``, so the caller should not expect the archive
        file itself to still be present at that path.
    """
    # A folder named after the archive (extension stripped) is created up front.
    stem = os.path.splitext(fname)[0]
    os.makedirs(os.path.join(root, stem), exist_ok=True)

    archive_path = _get_file(root, fname, origin=url + fname, file_md5=md5)
    extract_archive(archive_path, remove_finished=True)
    return archive_path
82
+
83
+
84
def extract_archive(
    from_path: str | Path,
    to_path: str | Path | None = None,
    remove_finished: bool = False,
):
    """Extract a zip archive.

    Args:
        from_path: Path of the zip archive. A relative path is resolved to an
            absolute one before use.
        to_path: Directory to extract into. Defaults to the directory that
            contains the archive.
        remove_finished: If True, delete the archive once extraction is done.
    """
    archive = Path(from_path)
    if not archive.is_absolute():
        archive = archive.resolve()

    destination = os.path.dirname(archive) if to_path is None else to_path

    # Only the zip format is handled here.
    with zipfile.ZipFile(archive, "r", compression=zipfile.ZIP_STORED) as bundle:
        bundle.extractall(destination)

    if remove_finished:
        os.remove(archive)
106
+
107
+
108
class MNIST(Dataset):
    """MNIST Dataset and Corruptions.

    Args:
        root : str | ``pathlib.Path``
            Root directory of dataset where the ``mnist_c/`` folder exists.
            A leading ``~`` is expanded to the user's home directory.
        train : bool, default True
            If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``,
            otherwise from ``test_images.npy`` and ``test_labels.npy``.
        download : bool, default False
            If True, downloads the dataset from the internet and puts it in root
            directory. If dataset is already downloaded, it is not downloaded again.
        size : int, default -1
            Limit the dataset size. Values below 1 (the default) keep the full dataset.
            When ``size`` is at least the number of classes, an equal number of samples
            (``size // number of classes``) is taken per class and the result is
            shuffled; otherwise the first ``size`` samples are used.
        unit_interval : bool, default False
            Shift the data values to the unit interval [0-1].
        dtype : type | None, default None
            Change the numpy dtype - data is loaded as np.uint8
        channels : Literal['channels_first', 'channels_last'] | None, default None
            Location of channel axis if desired, default has no channels (N, 28, 28)
        flatten : bool, default False
            Flatten data into single dimension (N, 784) - cannot use both channels and flatten,
            channels takes priority over flatten.
        normalize : tuple[mean, std] | None, default None
            Normalize images according to provided mean and standard deviation
        corruption : Literal['identity', 'shot_noise', 'impulse_noise', 'glass_blur',
            'motion_blur', 'shear', 'scale', 'rotate', 'brightness', 'translate', 'stripe',
            'fog', 'spatter', 'dotted_line', 'zigzag', 'canny_edges'] | None, default None
            The desired corruption style or None. None selects ``'identity'``.
    """

    # Base URL the archive named in ``resources`` is downloaded from.
    mirror = "https://zenodo.org/record/3239543/files/"

    # (archive file name, expected MD5 hex digest)
    resources = ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3")

    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    # The four properties below are aliases of ``targets`` / ``data``; they
    # return the loaded split regardless of the value of ``train``.
    @property
    def train_labels(self):
        return self.targets

    @property
    def test_labels(self):
        return self.targets

    @property
    def train_data(self):
        return self.data

    @property
    def test_data(self):
        return self.data

    def __init__(
        self,
        root: str | Path,
        train: bool = True,
        download: bool = False,
        size: int = -1,
        unit_interval: bool = False,
        dtype: type | None = None,
        channels: Literal["channels_first", "channels_last"] | None = None,
        flatten: bool = False,
        normalize: tuple[float, float] | None = None,
        corruption: Literal[
            "identity",
            "shot_noise",
            "impulse_noise",
            "glass_blur",
            "motion_blur",
            "shear",
            "scale",
            "rotate",
            "brightness",
            "translate",
            "stripe",
            "fog",
            "spatter",
            "dotted_line",
            "zigzag",
            "canny_edges",
        ]
        | None = None,
    ) -> None:
        """Locate (or download) the requested dataset variant and load it.

        Raises:
            RuntimeError: If the dataset folder is missing and ``download`` is False.
        """
        super().__init__()

        # Expand "~" for both accepted root types, keeping the caller's type.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        elif isinstance(root, Path):
            root = root.expanduser()
        self.root = root  # location of stored dataset
        self.train = train  # training set or test set
        self.size = size
        self.unit_interval = unit_interval
        self.dtype = dtype
        self.channels = channels
        self.flatten = flatten
        self.normalize = normalize

        if corruption is None:
            corruption = "identity"
        elif corruption == "identity":
            print("Identity is not a corrupted dataset but the original MNIST dataset")
        self.corruption = corruption

        if os.path.exists(self.mnist_folder):
            print("Files already downloaded and verified")
        elif download:
            download_dataset(self.mirror, self.root, self.resources[0], self.resources[1])
        else:
            raise RuntimeError("Dataset not found. You can use download=True to download it")

        self.data, self.targets = self._load_data()

    def _load_data(self) -> tuple[NDArray, NDArray]:
        """Read the selected split from disk and apply the configured options.

        Order of operations: size limit, unit-interval scaling, normalization,
        dtype conversion, channel-axis handling, flattening.

        Returns:
            tuple: (data, targets) arrays.
        """
        split = "train" if self.train else "test"
        data = self._read_image_file(os.path.join(self.mnist_folder, f"{split}_images.npy"))
        targets = self._read_label_file(os.path.join(self.mnist_folder, f"{split}_labels.npy"))

        num_classes = len(self.classes)
        if self.size >= 1 and self.size >= num_classes:
            # Class-balanced subset: the same number of samples for each label.
            per_class = self.size // num_classes
            final_data = []
            final_targets = []
            for label in range(num_classes):
                selected_indices = np.where(targets == label)[0][:per_class]
                final_data.append(data[selected_indices])
                final_targets.append(targets[selected_indices])
            data = np.concatenate(final_data)
            targets = np.concatenate(final_targets)
            # Shuffle images and labels with one shared permutation so the
            # pairs stay aligned (uses numpy's global random state).
            shuffled_indices = np.random.permutation(data.shape[0])
            data = data[shuffled_indices]
            targets = targets[shuffled_indices]
        elif self.size >= 1:
            # Too few samples requested to cover every class: take a prefix.
            data = data[: self.size]
            targets = targets[: self.size]

        if self.unit_interval:
            data = data / 255

        if self.normalize is not None:
            mean, std = self.normalize[0], self.normalize[1]
            data = (data - mean) / std

        if self.dtype is not None:
            data = data.astype(self.dtype)

        # The indexing below treats the stored images as 4-D with the channel
        # axis last.
        if self.channels == "channels_first":
            data = np.moveaxis(data, -1, 1)
        elif self.channels is None:
            data = data[:, :, :, 0]

        # Flattening is only applied when no channel axis was requested.
        if self.flatten and self.channels is None:
            data = data.reshape(data.shape[0], -1)

        return data, targets

    def __getitem__(self, index: int) -> tuple[NDArray, int]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])

        return img, target

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.data)

    @property
    def mnist_folder(self) -> str:
        """Folder holding the ``.npy`` files of the selected corruption."""
        return os.path.join(self.root, "mnist_c", self.corruption)

    @property
    def class_to_idx(self) -> dict[str, int]:
        """Mapping from class name to its position in ``classes``."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def _read_label_file(self, path: str) -> NDArray:
        """Load a label array from ``path`` with pickle support disabled."""
        x = np.load(path, allow_pickle=False)
        return x

    def _read_image_file(self, path: str) -> NDArray:
        """Load an image array from ``path`` with pickle support disabled."""
        x = np.load(path, allow_pickle=False)
        return x
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: dataeval
3
- Version: 0.69.3
3
+ Version: 0.69.4
4
4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
5
5
  Home-page: https://dataeval.ai/
6
6
  License: MIT
@@ -33,6 +33,7 @@ Requires-Dist: tensorflow (>=2.14.1,<2.16) ; extra == "tensorflow" or extra == "
33
33
  Requires-Dist: tensorflow-io-gcs-filesystem (>=0.35.0,<0.37) ; extra == "tensorflow" or extra == "all"
34
34
  Requires-Dist: tensorflow_probability (>=0.22.1,<0.24) ; extra == "tensorflow" or extra == "all"
35
35
  Requires-Dist: torch (>=2.0.1,!=2.2.0) ; extra == "torch" or extra == "all"
36
+ Requires-Dist: torchvision (>=0.16.0) ; extra == "torch" or extra == "all"
36
37
  Requires-Dist: xxhash (>=3.3)
37
38
  Project-URL: Documentation, https://dataeval.readthedocs.io/
38
39
  Project-URL: Repository, https://github.com/aria-ml/dataeval/
@@ -1,4 +1,5 @@
1
- dataeval/__init__.py,sha256=4JtJRUfhO_kYbjWDhzY5niIvmLb8K_3sCL-wbcZ_mUU,590
1
+ dataeval/__init__.py,sha256=KOZnb9SovSSuD2UrqV-NS_b5vpfWdQlsweB55fned58,590
2
+ dataeval/_internal/datasets.py,sha256=MwN6xgZW1cA5yIxXZ05qBBz4aO3bjKzIEbZZfa1HkQo,9790
2
3
  dataeval/_internal/detectors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
4
  dataeval/_internal/detectors/clusterer.py,sha256=hJwELUeAdZZ3OVLIfwalw2P7Zz13q2ZqrV6gx90s44E,20695
4
5
  dataeval/_internal/detectors/drift/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -67,7 +68,7 @@ dataeval/torch/models/__init__.py,sha256=YnDnePYpRIKHyYn3F5qR1OObMSb-g0FGvI8X-uT
67
68
  dataeval/torch/trainer/__init__.py,sha256=Te-qElt8h-Zv8NN0r-VJOEdCPHTQ2yO3rd2MhRiZGZs,93
68
69
  dataeval/utils/__init__.py,sha256=ExQ1xj62MjcM9uIu1-g1P2fW0EPJpcIofnvxjQ908c4,172
69
70
  dataeval/workflows/__init__.py,sha256=gkU2B6yUiefexcYrBwqfZKNl8BvX8abUjfeNvVBXF4E,186
70
- dataeval-0.69.3.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
71
- dataeval-0.69.3.dist-info/METADATA,sha256=dyyl60cjz6n7gRgYMZs9gCOdqpc9UbSV4LFCD8rJNCM,4217
72
- dataeval-0.69.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
73
- dataeval-0.69.3.dist-info/RECORD,,
71
+ dataeval-0.69.4.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
72
+ dataeval-0.69.4.dist-info/METADATA,sha256=R_YlthIsAkOizGWkgXiOCEsD_6F5wJm8qjU4hjhL_c8,4292
73
+ dataeval-0.69.4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
74
+ dataeval-0.69.4.dist-info/RECORD,,