dataeval-0.69.2-py3-none-any.whl → dataeval-0.69.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.69.2"
+ __version__ = "0.69.4"
 
  from importlib.util import find_spec
 
dataeval/_internal/datasets.py ADDED
@@ -0,0 +1,300 @@
+ from __future__ import annotations
+
+ import hashlib
+ import os
+ import zipfile
+ from pathlib import Path
+ from typing import Literal
+ from urllib.error import HTTPError, URLError
+ from urllib.request import urlretrieve
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from torch.utils.data import Dataset
+ from torchvision.datasets import CIFAR10, VOCDetection  # noqa: F401
+
+
+ def _validate_file(fpath, file_md5, chunk_size=65535):
+     hasher = hashlib.md5()
+     with open(fpath, "rb") as fpath_file:
+         while chunk := fpath_file.read(chunk_size):
+             hasher.update(chunk)
+     return hasher.hexdigest() == file_md5
+
+
+ def _get_file(
+     root: str | Path,
+     fname: str,
+     origin: str,
+     file_md5: str | None = None,
+ ):
+     fname = os.fspath(fname) if isinstance(fname, os.PathLike) else fname
+     fpath = os.path.join(root, fname)
+
+     download = False
+     if os.path.exists(fpath):
+         if file_md5 is not None and not _validate_file(fpath, file_md5):
+             download = True
+         else:
+             print("Files already downloaded and verified")
+     else:
+         download = True
+
+     if download:
+         try:
+             error_msg = "URL fetch failure on {}: {} -- {}"
+             try:
+                 urlretrieve(origin, fpath)
+             except HTTPError as e:
+                 raise Exception(error_msg.format(origin, e.code, e.msg)) from e
+             except URLError as e:
+                 raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
+         except (Exception, KeyboardInterrupt):
+             if os.path.exists(fpath):
+                 os.remove(fpath)
+             raise
+
+     if os.path.exists(fpath) and file_md5 is not None and not _validate_file(fpath, file_md5):
+         raise ValueError(
+             "Incomplete or corrupted file detected. "
+             f"The md5 file hash does not match the provided value "
+             f"of {file_md5}.",
+         )
+     return fpath
+
+
+ def download_dataset(url: str, root: str | Path, fname: str, md5: str) -> str:
+     """Code to download MNIST and its corruptions; originates from tensorflow_datasets (tfds):
+     https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/image_classification/mnist_corrupted.py
+     """
+     name, _ = os.path.splitext(fname)
+     folder = os.path.join(root, name)
+     os.makedirs(folder, exist_ok=True)
+
+     path = _get_file(
+         root,
+         fname,
+         origin=url + fname,
+         file_md5=md5,
+     )
+     extract_archive(path, remove_finished=True)
+     return path
+
+
+ def extract_archive(
+     from_path: str | Path,
+     to_path: str | Path | None = None,
+     remove_finished: bool = False,
+ ):
+     """Extract an archive.
+
+     The archive type and any compression are automatically detected from the file name.
+     """
+     from_path = Path(from_path)
+     if not from_path.is_absolute():
+         from_path = from_path.resolve()
+
+     if to_path is None:
+         to_path = os.path.dirname(from_path)
+
+     # Extracting zip
+     with zipfile.ZipFile(from_path, "r", compression=zipfile.ZIP_STORED) as zzip:
+         zzip.extractall(to_path)
+
+     if remove_finished:
+         os.remove(from_path)
+
+
+ class MNIST(Dataset):
+     """MNIST Dataset and Corruptions.
+
+     Args:
+         root : str | ``pathlib.Path``
+             Root directory of dataset where the ``mnist_c/`` folder exists.
+         train : bool, default True
+             If True, creates dataset from ``train_images.npy`` and ``train_labels.npy``.
+         download : bool, default False
+             If True, downloads the dataset from the internet and puts it in the root
+             directory. If the dataset is already downloaded, it is not downloaded again.
+         size : int, default -1
+             Limit the dataset size; must be a value greater than 0.
+         unit_interval : bool, default False
+             Shift the data values to the unit interval [0-1].
+         dtype : type | None, default None
+             Change the numpy dtype; data is loaded as np.uint8.
+         channels : Literal['channels_first' | 'channels_last'] | None, default None
+             Location of the channel axis if desired; the default has no channel axis (N, 28, 28).
+         flatten : bool, default False
+             Flatten data into a single dimension (N, 784); cannot use both channels and flatten,
+             channels takes priority over flatten.
+         normalize : tuple[mean, std] | None, default None
+             Normalize images according to the provided mean and standard deviation.
+         corruption : Literal['identity' | 'shot_noise' | 'impulse_noise' | 'glass_blur' |
+             'motion_blur' | 'shear' | 'scale' | 'rotate' | 'brightness' | 'translate' | 'stripe' |
+             'fog' | 'spatter' | 'dotted_line' | 'zigzag' | 'canny_edges'] | None, default None
+             The desired corruption style or None.
+     """
+
+     mirror = "https://zenodo.org/record/3239543/files/"
+
+     resources = ("mnist_c.zip", "4b34b33045869ee6d424616cd3a65da3")
+
+     classes = [
+         "0 - zero",
+         "1 - one",
+         "2 - two",
+         "3 - three",
+         "4 - four",
+         "5 - five",
+         "6 - six",
+         "7 - seven",
+         "8 - eight",
+         "9 - nine",
+     ]
+
+     @property
+     def train_labels(self):
+         return self.targets
+
+     @property
+     def test_labels(self):
+         return self.targets
+
+     @property
+     def train_data(self):
+         return self.data
+
+     @property
+     def test_data(self):
+         return self.data
+
+     def __init__(
+         self,
+         root: str | Path,
+         train: bool = True,
+         download: bool = False,
+         size: int = -1,
+         unit_interval: bool = False,
+         dtype: type | None = None,
+         channels: Literal["channels_first", "channels_last"] | None = None,
+         flatten: bool = False,
+         normalize: tuple[float, float] | None = None,
+         corruption: Literal[
+             "identity",
+             "shot_noise",
+             "impulse_noise",
+             "glass_blur",
+             "motion_blur",
+             "shear",
+             "scale",
+             "rotate",
+             "brightness",
+             "translate",
+             "stripe",
+             "fog",
+             "spatter",
+             "dotted_line",
+             "zigzag",
+             "canny_edges",
+         ]
+         | None = None,
+     ) -> None:
+         if isinstance(root, str):
+             root = os.path.expanduser(root)
+         self.root = root  # location of stored dataset
+         self.train = train  # training set or test set
+         self.size = size
+         self.unit_interval = unit_interval
+         self.dtype = dtype
+         self.channels = channels
+         self.flatten = flatten
+         self.normalize = normalize
+
+         if corruption is None:
+             corruption = "identity"
+         elif corruption == "identity":
+             print("Identity is not a corrupted dataset but the original MNIST dataset")
+         self.corruption = corruption
+
+         if os.path.exists(self.mnist_folder):
+             print("Files already downloaded and verified")
+         elif download:
+             download_dataset(self.mirror, self.root, self.resources[0], self.resources[1])
+         else:
+             raise RuntimeError("Dataset not found. You can use download=True to download it")
+
+         self.data, self.targets = self._load_data()
+
+     def _load_data(self):
+         image_file = f"{'train' if self.train else 'test'}_images.npy"
+         data = self._read_image_file(os.path.join(self.mnist_folder, image_file))
+
+         label_file = f"{'train' if self.train else 'test'}_labels.npy"
+         targets = self._read_label_file(os.path.join(self.mnist_folder, label_file))
+
+         if self.size >= 1 and self.size >= len(self.classes):
+             final_data = []
+             final_targets = []
+             for label in range(len(self.classes)):
+                 indices = np.where(targets == label)[0]
+                 selected_indices = indices[: int(self.size / len(self.classes))]
+                 final_data.append(data[selected_indices])
+                 final_targets.append(targets[selected_indices])
+             data = np.concatenate(final_data)
+             targets = np.concatenate(final_targets)
+             shuffled_indices = np.random.permutation(data.shape[0])
+             data = data[shuffled_indices]
+             targets = targets[shuffled_indices]
+         elif self.size >= 1:
+             data = data[: self.size]
+             targets = targets[: self.size]
+
+         if self.unit_interval:
+             data = data / 255
+
+         if self.normalize:
+             data = (data - self.normalize[0]) / self.normalize[1]
+
+         if self.dtype:
+             data = data.astype(self.dtype)
+
+         if self.channels == "channels_first":
+             data = np.moveaxis(data, -1, 1)
+         elif self.channels is None:
+             data = data[:, :, :, 0]
+
+         if self.flatten and self.channels is None:
+             data = data.reshape(data.shape[0], -1)
+
+         return data, targets
+
+     def __getitem__(self, index: int) -> tuple[NDArray, int]:
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (image, target) where target is the index of the target class.
+         """
+         img, target = self.data[index], int(self.targets[index])
+
+         return img, target
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+     @property
+     def mnist_folder(self) -> str:
+         return os.path.join(self.root, "mnist_c", self.corruption)
+
+     @property
+     def class_to_idx(self) -> dict[str, int]:
+         return {_class: i for i, _class in enumerate(self.classes)}
+
+     def _read_label_file(self, path: str) -> NDArray:
+         x = np.load(path, allow_pickle=False)
+         return x
+
+     def _read_image_file(self, path: str) -> NDArray:
+         x = np.load(path, allow_pickle=False)
+         return x
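
For orientation, here is a minimal usage sketch of the new dataset wrapper added above. The `./data` root, the size cap, and the corruption choice are illustrative, not taken from the diff:

```python
from dataeval._internal.datasets import MNIST

# Download mnist_c if needed, then load the 'translate' corruption split,
# scaled to [0, 1] with a leading channel axis for use with torch models.
train_ds = MNIST(
    root="./data",              # hypothetical root directory
    train=True,
    download=True,
    size=1000,                  # balanced subset: 100 images per class
    unit_interval=True,
    channels="channels_first",
    corruption="translate",
)

img, target = train_ds[0]       # img: NDArray of shape (1, 28, 28); target: int
print(len(train_ds), img.shape, target)
```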
dataeval/_internal/metrics/diversity.py CHANGED
@@ -130,7 +130,10 @@ def diversity_simpson(
         p_i = cnts / cnts.sum()
         # inverse Simpson index normalized by (number of bins)
         s_0 = 1 / np.sum(p_i**2) / num_bins[col]
-        ev_index[col] = (s_0 * num_bins[col] - 1) / (num_bins[col] - 1)
+        if num_bins[col] == 1:
+            ev_index[col] = 0
+        else:
+            ev_index[col] = (s_0 * num_bins[col] - 1) / (num_bins[col] - 1)
     return ev_index
 
 
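The guard above fixes a zero division: with a single bin, p_i = [1.0], so the normalized inverse Simpson index s_0 is 1 and the rescaling (s_0 * B - 1) / (B - 1) divides by zero. A standalone sketch of the per-column computation (the function name is ours, not the library's):

```python
import numpy as np

def simpson_evenness(cnts: np.ndarray) -> float:
    """Inverse Simpson index, rescaled from [1/num_bins, 1] onto [0, 1]."""
    num_bins = len(cnts)
    p_i = cnts / cnts.sum()
    s_0 = 1 / np.sum(p_i**2) / num_bins   # normalized inverse Simpson index
    if num_bins == 1:
        return 0.0                        # the 0.69.4 guard: one bin has no evenness
    return (s_0 * num_bins - 1) / (num_bins - 1)

print(simpson_evenness(np.array([25, 25, 25, 25])))  # 1.0  (perfectly even)
print(simpson_evenness(np.array([97, 1, 1, 1])))     # ~0.02 (highly uneven)
print(simpson_evenness(np.array([100])))             # 0.0  (previously ZeroDivisionError)
```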
dataeval/_internal/metrics/parity.py CHANGED
@@ -348,6 +348,7 @@ def parity(
     chi_scores = np.zeros(len(factors))
     p_values = np.zeros(len(factors))
     n_cls = len(np.unique(labels))
+    not_enough_data = {}
     for i, (current_factor_name, factor_values) in enumerate(factors.items()):
         unique_factor_values = np.unique(factor_values)
         contingency_matrix = np.zeros((len(unique_factor_values), n_cls))
@@ -361,13 +362,12 @@ def parity(
                 with_both = np.bitwise_and((labels == label), factor_values == factor_value)
                 contingency_matrix[fi, label] = np.sum(with_both)
                 if 0 < contingency_matrix[fi, label] < 5:
-                    warnings.warn(
-                        f"Factor {current_factor_name} value {factor_value} co-occurs "
-                        f"only {contingency_matrix[fi, label]} times with label {label}. "
-                        "This can cause inaccurate chi_square calculation. Recommend"
-                        "ensuring each label occurs either 0 times or at least 5 times. "
-                        "Alternatively, digitize any continuous-valued factors "
-                        "into fewer bins."
+                    if current_factor_name not in not_enough_data:
+                        not_enough_data[current_factor_name] = {}
+                    if factor_value not in not_enough_data[current_factor_name]:
+                        not_enough_data[current_factor_name][factor_value] = []
+                    not_enough_data[current_factor_name][factor_value].append(
+                        (label, int(contingency_matrix[fi, label]))
                     )
 
         # This deletes rows containing only zeros,
@@ -381,4 +381,23 @@ def parity(
         chi_scores[i] = chi2
         p_values[i] = p
 
+    if not_enough_data:
+        factor_msg = []
+        for factor, fact_dict in not_enough_data.items():
+            stacked_msg = []
+            for key, value in fact_dict.items():
+                msg = []
+                for item in value:
+                    msg.append(f"label {item[0]}: {item[1]} occurrences")
+                flat_msg = "\n\t\t".join(msg)
+                stacked_msg.append(f"value {key} - {flat_msg}\n\t")
+            factor_msg.append(factor + " - " + "".join(stacked_msg))
+
+        message = "\n".join(factor_msg)
+
+        warnings.warn(
+            f"The following factors did not meet the recommended 5 occurrences for each value-label combination. \nRecommend rerunning parity after adjusting the following factor-value-label combinations: \n{message}",  # noqa: E501
+            UserWarning,
+        )
+
     return ParityOutput(chi_scores, p_values)
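
The change above collects sparse cells and emits one aggregated warning instead of a separate warning per cell. For context, parity scores each factor with a chi-square test of independence on its factor-value × class contingency matrix; cells with 1-4 co-occurrences make that test unreliable, which is what the collected warning reports. A sketch of that kind of test, assuming scipy's `chi2_contingency` (which matches the `chi2`/`p` pair stored by the loop) and made-up counts:

```python
import numpy as np
from scipy.stats import chi2_contingency

# Hypothetical contingency matrix for one factor: rows are factor values,
# columns are class labels. The cell (value "b", class 2) co-occurs only
# twice, below the recommended 5, so 0.69.4 would list it in the warning.
contingency = np.array(
    [
        [30, 25, 28],  # factor value "a"
        [12, 9, 2],    # factor value "b" -- sparse cell
    ]
)

chi2, p, dof, _ = chi2_contingency(contingency)
print(f"chi2={chi2:.2f}, p={p:.3f}, dof={dof}")
```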
{dataeval-0.69.2.dist-info → dataeval-0.69.4.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.69.2
+ Version: 0.69.4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -33,6 +33,7 @@ Requires-Dist: tensorflow (>=2.14.1,<2.16) ; extra == "tensorflow" or extra == "
  Requires-Dist: tensorflow-io-gcs-filesystem (>=0.35.0,<0.37) ; extra == "tensorflow" or extra == "all"
  Requires-Dist: tensorflow_probability (>=0.22.1,<0.24) ; extra == "tensorflow" or extra == "all"
  Requires-Dist: torch (>=2.0.1,!=2.2.0) ; extra == "torch" or extra == "all"
+ Requires-Dist: torchvision (>=0.16.0) ; extra == "torch" or extra == "all"
  Requires-Dist: xxhash (>=3.3)
  Project-URL: Documentation, https://dataeval.readthedocs.io/
  Project-URL: Repository, https://github.com/aria-ml/dataeval/
{dataeval-0.69.2.dist-info → dataeval-0.69.4.dist-info}/RECORD RENAMED
@@ -1,4 +1,5 @@
- dataeval/__init__.py,sha256=NUQixSNyEc0GiI7YgbfY9IL0OEkIN9kdbrOGAB041Ig,590
+ dataeval/__init__.py,sha256=KOZnb9SovSSuD2UrqV-NS_b5vpfWdQlsweB55fned58,590
+ dataeval/_internal/datasets.py,sha256=MwN6xgZW1cA5yIxXZ05qBBz4aO3bjKzIEbZZfa1HkQo,9790
  dataeval/_internal/detectors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  dataeval/_internal/detectors/clusterer.py,sha256=hJwELUeAdZZ3OVLIfwalw2P7Zz13q2ZqrV6gx90s44E,20695
  dataeval/_internal/detectors/drift/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -25,8 +26,8 @@ dataeval/_internal/metrics/balance.py,sha256=eAHvgjiGCH893XSQLqh9j9wgvAECoNPVT8k
  dataeval/_internal/metrics/ber.py,sha256=Onsi47AbT9rMvng-Pbu8LIrYRfLpI13En1FxkFoMKQs,4668
  dataeval/_internal/metrics/coverage.py,sha256=EZVES1rbZW2j_CtQv1VFfSO-UmWcrt5nmqxDErtrG14,3473
  dataeval/_internal/metrics/divergence.py,sha256=nmMUfr9FGnH798eb6xzEiMj4C42rQVthh5HeexiY6EE,4119
- dataeval/_internal/metrics/diversity.py,sha256=nGjYQ-NLjb8mPt1PAYnvkWH4D58kjM39IPs2FULfis4,7503
- dataeval/_internal/metrics/parity.py,sha256=suv1Pf7gPj0_NxsS0_M6ewfUndsFJyEhbt5NPp6ktMI,15457
+ dataeval/_internal/metrics/diversity.py,sha256=_oT0FHsgfLOoe_TLD2Aax4r4jmH6WnOPVIkcl_YjaoY,7582
+ dataeval/_internal/metrics/parity.py,sha256=VszQNbHWjct2bCqrIXUZC_qFi4ZIq2Lm-vs-DiarBFo,16244
  dataeval/_internal/metrics/stats.py,sha256=ILKteVMGjrp1s2CECPL_hbLsijIKR2d6II2-8w9oxW8,18105
  dataeval/_internal/metrics/uap.py,sha256=w-wvXXnX16kUq-weaZD2SrJi22LJ8EjOFbOhPxeGejI,2043
  dataeval/_internal/metrics/utils.py,sha256=mSYa-3cHGcsQwPr7zbdpzrnK_8jIXCiAcu2HCcvrtaY,13007
@@ -67,7 +68,7 @@ dataeval/torch/models/__init__.py,sha256=YnDnePYpRIKHyYn3F5qR1OObMSb-g0FGvI8X-uT
  dataeval/torch/trainer/__init__.py,sha256=Te-qElt8h-Zv8NN0r-VJOEdCPHTQ2yO3rd2MhRiZGZs,93
  dataeval/utils/__init__.py,sha256=ExQ1xj62MjcM9uIu1-g1P2fW0EPJpcIofnvxjQ908c4,172
  dataeval/workflows/__init__.py,sha256=gkU2B6yUiefexcYrBwqfZKNl8BvX8abUjfeNvVBXF4E,186
- dataeval-0.69.2.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
- dataeval-0.69.2.dist-info/METADATA,sha256=_9rVrbIh4EPYStZtOUYnB-Xo3cZ5JMUAf06TqDKvrZs,4217
- dataeval-0.69.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- dataeval-0.69.2.dist-info/RECORD,,
+ dataeval-0.69.4.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
+ dataeval-0.69.4.dist-info/METADATA,sha256=R_YlthIsAkOizGWkgXiOCEsD_6F5wJm8qjU4hjhL_c8,4292
+ dataeval-0.69.4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ dataeval-0.69.4.dist-info/RECORD,,