dataeval 0.86.9__py3-none-any.whl → 0.87.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_version.py +2 -2
  3. dataeval/config.py +4 -19
  4. dataeval/data/_metadata.py +56 -27
  5. dataeval/data/_split.py +1 -1
  6. dataeval/data/selections/_classbalance.py +4 -3
  7. dataeval/data/selections/_classfilter.py +5 -5
  8. dataeval/data/selections/_indices.py +2 -2
  9. dataeval/data/selections/_prioritize.py +249 -29
  10. dataeval/data/selections/_reverse.py +1 -1
  11. dataeval/data/selections/_shuffle.py +2 -2
  12. dataeval/detectors/ood/__init__.py +2 -1
  13. dataeval/detectors/ood/base.py +38 -1
  14. dataeval/detectors/ood/knn.py +95 -0
  15. dataeval/metrics/bias/_balance.py +28 -21
  16. dataeval/metrics/bias/_diversity.py +4 -4
  17. dataeval/metrics/bias/_parity.py +2 -2
  18. dataeval/metrics/stats/_hashstats.py +19 -2
  19. dataeval/outputs/_workflows.py +20 -7
  20. dataeval/typing.py +14 -2
  21. dataeval/utils/__init__.py +2 -2
  22. dataeval/utils/_bin.py +7 -6
  23. dataeval/utils/data/__init__.py +2 -0
  24. dataeval/utils/data/_dataset.py +13 -6
  25. dataeval/utils/data/_validate.py +169 -0
  26. {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/METADATA +5 -17
  27. {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/RECORD +29 -39
  28. dataeval/utils/datasets/__init__.py +0 -21
  29. dataeval/utils/datasets/_antiuav.py +0 -189
  30. dataeval/utils/datasets/_base.py +0 -266
  31. dataeval/utils/datasets/_cifar10.py +0 -201
  32. dataeval/utils/datasets/_fileio.py +0 -142
  33. dataeval/utils/datasets/_milco.py +0 -197
  34. dataeval/utils/datasets/_mixin.py +0 -54
  35. dataeval/utils/datasets/_mnist.py +0 -202
  36. dataeval/utils/datasets/_seadrone.py +0 -512
  37. dataeval/utils/datasets/_ships.py +0 -144
  38. dataeval/utils/datasets/_types.py +0 -48
  39. dataeval/utils/datasets/_voc.py +0 -583
  40. {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/WHEEL +0 -0
  41. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.87.0.dist-info/licenses/LICENSE +0 -0
@@ -1,202 +0,0 @@
1
- from __future__ import annotations
2
-
3
- __all__ = []
4
-
5
- from pathlib import Path
6
- from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
7
-
8
- import numpy as np
9
- from numpy.typing import NDArray
10
-
11
- from dataeval.utils.datasets._base import BaseICDataset, DataLocation
12
- from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
13
-
14
- if TYPE_CHECKING:
15
- from dataeval.typing import Transform
16
-
17
- MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
18
- TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
19
- CorruptionStringMap = Literal[
20
- "identity",
21
- "shot_noise",
22
- "impulse_noise",
23
- "glass_blur",
24
- "motion_blur",
25
- "shear",
26
- "scale",
27
- "rotate",
28
- "brightness",
29
- "translate",
30
- "stripe",
31
- "fog",
32
- "spatter",
33
- "dotted_line",
34
- "zigzag",
35
- "canny_edges",
36
- ]
37
-
38
-
39
- class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
40
- """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.
41
-
42
- There are 15 different styles of corruptions. This class downloads differently depending on if you
43
- need just the original dataset or if you need corruptions. If you need both a corrupt version and the
44
- original version then choose `corruption="identity"` as this downloads all of the corrupt datasets and
45
- provides the original as `identity`. If you just need the original, then using `corruption=None` will
46
- download only the original dataset to save time and space.
47
-
48
- Parameters
49
- ----------
50
- root : str or pathlib.Path
51
- Root directory where the data should be downloaded to or the ``mnist`` folder of the already downloaded data.
52
- image_set : "train", "test" or "base", default "train"
53
- If "base", returns all of the data to allow the user to create their own splits.
54
- corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
55
- "shear", "scale", "rotate", "brightness", "translate", "stripe", "fog", "spatter", \
56
- "dotted_line", "zigzag", "canny_edges" or None, default None
57
- Corruption to apply to the data.
58
- transforms : Transform, Sequence[Transform] or None, default None
59
- Transform(s) to apply to the data.
60
- download : bool, default False
61
- If True, downloads the dataset from the internet and puts it in root directory.
62
- Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
63
- verbose : bool, default False
64
- If True, outputs print statements.
65
-
66
- Attributes
67
- ----------
68
- path : pathlib.Path
69
- Location of the folder containing the data.
70
- image_set : "train", "test" or "base"
71
- The selected image set from the dataset.
72
- index2label : dict[int, str]
73
- Dictionary which translates from class integers to the associated class strings.
74
- label2index : dict[str, int]
75
- Dictionary which translates from class strings to the associated class integers.
76
- metadata : DatasetMetadata
77
- Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
78
- corruption : str or None
79
- Corruption applied to the data.
80
- transforms : Sequence[Transform]
81
- The transforms to be applied to the data.
82
- size : int
83
- The size of the dataset.
84
-
85
- Note
86
- ----
87
- Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_ for corruption dataset
88
- """
89
-
90
- _resources = [
91
- DataLocation(
92
- url="https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
93
- filename="mnist.npz",
94
- md5=False,
95
- checksum="731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1",
96
- ),
97
- DataLocation(
98
- url="https://zenodo.org/record/3239543/files/mnist_c.zip",
99
- filename="mnist_c.zip",
100
- md5=True,
101
- checksum="4b34b33045869ee6d424616cd3a65da3",
102
- ),
103
- ]
104
-
105
- index2label: dict[int, str] = {
106
- 0: "zero",
107
- 1: "one",
108
- 2: "two",
109
- 3: "three",
110
- 4: "four",
111
- 5: "five",
112
- 6: "six",
113
- 7: "seven",
114
- 8: "eight",
115
- 9: "nine",
116
- }
117
-
118
- def __init__(
119
- self,
120
- root: str | Path,
121
- image_set: Literal["train", "test", "base"] = "train",
122
- corruption: CorruptionStringMap | None = None,
123
- transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
124
- download: bool = False,
125
- verbose: bool = False,
126
- ) -> None:
127
- self.corruption = corruption
128
- if self.corruption == "identity" and verbose:
129
- print("Identity is not a corrupted dataset but the original MNIST dataset.")
130
- self._resource_index = 0 if self.corruption is None else 1
131
-
132
- super().__init__(
133
- root,
134
- image_set,
135
- transforms,
136
- download,
137
- verbose,
138
- )
139
-
140
- def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
141
- """Function to load in the file paths for the data and labels from the correct data format"""
142
- if self.corruption is None:
143
- try:
144
- file_path = self.path / self._resource.filename
145
- self._loaded_data, labels = self._grab_data(file_path)
146
- except FileNotFoundError:
147
- self._loaded_data, labels = self._load_corruption()
148
- else:
149
- self._loaded_data, labels = self._load_corruption()
150
-
151
- index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
152
- return index_strings, labels.tolist(), {}
153
-
154
- def _load_corruption(self) -> tuple[NDArray[Any], NDArray[np.uintp]]:
155
- """Function to load in the data and label arrays for the chosen corrupt data format"""
156
- corruption = self.corruption if self.corruption is not None else "identity"
157
- base_path = self.path / "mnist_c" / corruption
158
- if self.image_set == "base":
159
- raw_data = []
160
- raw_labels = []
161
- for group in ["train", "test"]:
162
- file_path = base_path / f"{group}_images.npy"
163
- raw_data.append(self._grab_corruption_data(file_path))
164
-
165
- label_path = base_path / f"{group}_labels.npy"
166
- raw_labels.append(self._grab_corruption_data(label_path))
167
-
168
- data = np.concatenate(raw_data, axis=0).transpose(0, 3, 1, 2)
169
- labels = np.concatenate(raw_labels).astype(np.uintp)
170
- else:
171
- file_path = base_path / f"{self.image_set}_images.npy"
172
- data = self._grab_corruption_data(file_path)
173
- data = data.astype(np.float64).transpose(0, 3, 1, 2)
174
-
175
- label_path = base_path / f"{self.image_set}_labels.npy"
176
- labels = self._grab_corruption_data(label_path)
177
- labels = labels.astype(np.uintp)
178
-
179
- return data, labels
180
-
181
- def _grab_data(self, path: Path) -> tuple[NDArray[Any], NDArray[np.uintp]]:
182
- """Function to load in the data numpy array"""
183
- with np.load(path, allow_pickle=True) as data_array:
184
- if self.image_set == "base":
185
- data = np.concatenate([data_array["x_train"], data_array["x_test"]], axis=0)
186
- labels = np.concatenate([data_array["y_train"], data_array["y_test"]], axis=0).astype(np.uintp)
187
- else:
188
- data, labels = data_array[f"x_{self.image_set}"], data_array[f"y_{self.image_set}"].astype(np.uintp)
189
- data = np.expand_dims(data, axis=1)
190
- return data, labels
191
-
192
- def _grab_corruption_data(self, path: Path) -> NDArray[Any]:
193
- """Function to load in the data numpy array for the previously chosen corrupt format"""
194
- return np.load(path, allow_pickle=False)
195
-
196
- def _read_file(self, path: str) -> NDArray[Any]:
197
- """
198
- Function to grab the correct image from the loaded data.
199
- Overwrite of the base `_read_file` because data is an all or nothing load.
200
- """
201
- index = int(path)
202
- return self._loaded_data[index]