dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/_version.py +2 -2
  4. dataeval/config.py +4 -19
  5. dataeval/data/_embeddings.py +78 -35
  6. dataeval/data/_images.py +41 -8
  7. dataeval/data/_metadata.py +348 -66
  8. dataeval/data/_selection.py +22 -7
  9. dataeval/data/_split.py +3 -2
  10. dataeval/data/selections/_classbalance.py +4 -3
  11. dataeval/data/selections/_classfilter.py +9 -8
  12. dataeval/data/selections/_indices.py +4 -3
  13. dataeval/data/selections/_prioritize.py +249 -29
  14. dataeval/data/selections/_reverse.py +1 -1
  15. dataeval/data/selections/_shuffle.py +5 -4
  16. dataeval/detectors/drift/_base.py +2 -1
  17. dataeval/detectors/drift/_mmd.py +2 -1
  18. dataeval/detectors/drift/_nml/_base.py +1 -1
  19. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  20. dataeval/detectors/drift/_nml/_result.py +3 -2
  21. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  22. dataeval/detectors/drift/_uncertainty.py +2 -1
  23. dataeval/detectors/linters/duplicates.py +2 -1
  24. dataeval/detectors/linters/outliers.py +4 -3
  25. dataeval/detectors/ood/__init__.py +2 -1
  26. dataeval/detectors/ood/ae.py +1 -1
  27. dataeval/detectors/ood/base.py +39 -1
  28. dataeval/detectors/ood/knn.py +95 -0
  29. dataeval/detectors/ood/mixin.py +2 -1
  30. dataeval/metadata/_utils.py +1 -1
  31. dataeval/metrics/bias/_balance.py +29 -22
  32. dataeval/metrics/bias/_diversity.py +4 -4
  33. dataeval/metrics/bias/_parity.py +2 -2
  34. dataeval/metrics/stats/_base.py +3 -29
  35. dataeval/metrics/stats/_boxratiostats.py +2 -1
  36. dataeval/metrics/stats/_dimensionstats.py +2 -1
  37. dataeval/metrics/stats/_hashstats.py +21 -3
  38. dataeval/metrics/stats/_pixelstats.py +2 -1
  39. dataeval/metrics/stats/_visualstats.py +2 -1
  40. dataeval/outputs/_base.py +2 -3
  41. dataeval/outputs/_bias.py +2 -1
  42. dataeval/outputs/_estimators.py +1 -1
  43. dataeval/outputs/_linters.py +3 -3
  44. dataeval/outputs/_stats.py +3 -3
  45. dataeval/outputs/_utils.py +1 -1
  46. dataeval/outputs/_workflows.py +49 -31
  47. dataeval/typing.py +23 -9
  48. dataeval/utils/__init__.py +2 -2
  49. dataeval/utils/_array.py +3 -2
  50. dataeval/utils/_bin.py +9 -7
  51. dataeval/utils/_method.py +2 -3
  52. dataeval/utils/_multiprocessing.py +34 -0
  53. dataeval/utils/_plot.py +2 -1
  54. dataeval/utils/data/__init__.py +6 -5
  55. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  56. dataeval/utils/data/_validate.py +170 -0
  57. dataeval/utils/data/collate.py +2 -1
  58. dataeval/utils/torch/_internal.py +2 -1
  59. dataeval/utils/torch/trainer.py +1 -1
  60. dataeval/workflows/sufficiency.py +13 -9
  61. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
  62. dataeval-0.88.0.dist-info/RECORD +105 -0
  63. dataeval/utils/data/_dataset.py +0 -246
  64. dataeval/utils/datasets/__init__.py +0 -21
  65. dataeval/utils/datasets/_antiuav.py +0 -189
  66. dataeval/utils/datasets/_base.py +0 -266
  67. dataeval/utils/datasets/_cifar10.py +0 -201
  68. dataeval/utils/datasets/_fileio.py +0 -142
  69. dataeval/utils/datasets/_milco.py +0 -197
  70. dataeval/utils/datasets/_mixin.py +0 -54
  71. dataeval/utils/datasets/_mnist.py +0 -202
  72. dataeval/utils/datasets/_seadrone.py +0 -512
  73. dataeval/utils/datasets/_ships.py +0 -144
  74. dataeval/utils/datasets/_types.py +0 -48
  75. dataeval/utils/datasets/_voc.py +0 -583
  76. dataeval-0.86.9.dist-info/RECORD +0 -115
  77. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
  78. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
dataeval/utils/datasets/_fileio.py
@@ -1,142 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- import hashlib
- import tarfile
- import zipfile
- from pathlib import Path
-
- import requests
- from tqdm.auto import tqdm
-
- ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
- COMPRESS_ENDINGS = [".gz", ".bz2"]
-
-
- def _print(text: str, verbose: bool) -> None:
-     if verbose:
-         print(text)
-
-
- def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
-     hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
-     with open(fpath, "rb") as fpath_file:
-         while chunk := fpath_file.read(chunk_size):
-             hasher.update(chunk)
-     return hasher.hexdigest() == file_md5
-
-
- def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
-     """Download a single resource from its URL to the `data_folder`."""
-     error_msg = "URL fetch failure on {}: {} -- {}"
-     try:
-         response = requests.get(url, stream=True, timeout=timeout)
-         response.raise_for_status()
-     except requests.exceptions.HTTPError as e:
-         raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
-     except requests.exceptions.RequestException as e:
-         raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
-
-     total_size = int(response.headers.get("content-length", 0))
-     block_size = 8192 # 8 KB
-     progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
-
-     with open(file_path, "wb") as f:
-         for chunk in response.iter_content(block_size):
-             f.write(chunk)
-             progress_bar.update(len(chunk))
-     progress_bar.close()
-
-
- def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
-     """Extracts the zip file to the given directory."""
-     try:
-         with zipfile.ZipFile(file_path, "r") as zip_ref:
-             zip_ref.extractall(extract_to) # noqa: S202
-         file_path.unlink()
-     except zipfile.BadZipFile:
-         raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
-
-
- def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
-     """Extracts a tar file (or compressed tar) to the specified directory."""
-     try:
-         with tarfile.open(file_path, "r:*") as tar_ref:
-             tar_ref.extractall(extract_to) # noqa: S202
-         file_path.unlink()
-     except tarfile.TarError:
-         raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
-
-
- def _extract_archive(
-     file_ext: str, file_path: Path, directory: Path, compression: bool = False, verbose: bool = False
- ) -> None:
-     """
-     Single function to extract and then flatten if necessary.
-     Recursively extracts nested zip files as well.
-     Extracts and flattens all folders to the base directory.
-     """
-     if file_ext != ".zip" or compression:
-         _extract_tar_archive(file_path, directory)
-     else:
-         _extract_zip_archive(file_path, directory)
-         # Look for nested zip files in the extraction directory and extract them recursively.
-         # Does NOT extract in place - extracts everything to directory
-         for child in directory.iterdir():
-             if child.suffix == ".zip":
-                 _print(f"Extracting nested zip: {child} to {directory}", verbose)
-                 _extract_zip_archive(child, directory)
-
-
- def _ensure_exists(
-     url: str,
-     filename: str,
-     md5: bool,
-     checksum: str,
-     directory: Path,
-     root: Path,
-     download: bool = True,
-     verbose: bool = False,
- ) -> None:
-     """
-     For each resource, download it if it doesn't exist in the dataset_dir.
-     If the resource is a zip file, extract it (including recursively extracting nested zips).
-     """
-     file_path = directory / str(filename)
-     alternate_path = root / str(filename)
-     _, file_ext = file_path.stem, file_path.suffix
-     compression = False
-     if file_ext in COMPRESS_ENDINGS:
-         file_ext = file_path.suffixes[0]
-         compression = True
-
-     check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path
-
-     # Download file if it doesn't exist.
-     if not check_path.exists() and download:
-         _print(f"Downloading {filename} from {url}", verbose)
-         _download_dataset(url, check_path, verbose=verbose)
-
-         if not _validate_file(check_path, checksum, md5):
-             raise Exception("File checksum mismatch. Remove current file and retry download.")
-
-         # If the file is a zip, tar or tgz extract it into the designated folder.
-         if file_ext in ARCHIVE_ENDINGS:
-             _print(f"Extracting {filename}...", verbose)
-             _extract_archive(file_ext, check_path, directory, compression, verbose)
-
-     elif not check_path.exists() and not download:
-         raise FileNotFoundError(
-             "Data could not be loaded with the provided root directory, "
-             f"the file path to the file {filename} does not exist, "
-             "and the download parameter is set to False."
-         )
-     else:
-         if not _validate_file(check_path, checksum, md5):
-             raise Exception("File checksum mismatch. Remove current file and retry download.")
-         _print(f"{filename} already exists, skipping download.", verbose)
-
-         if file_ext in ARCHIVE_ENDINGS:
-             _print(f"Extracting {filename}...", verbose)
-             _extract_archive(file_ext, check_path, directory, compression, verbose)
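
For context, the removed `_validate_file` helper above verified downloads by hashing them in fixed-size chunks. A minimal standalone sketch of that chunked-checksum pattern (the function name here is illustrative and not part of the DataEval API):

import hashlib
from pathlib import Path

def sha256_matches(path: Path, expected: str, chunk_size: int = 65535) -> bool:
    # Stream the file in chunks so large archives never need to fit in memory.
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest() == expected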
dataeval/utils/datasets/_milco.py
@@ -1,197 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Literal, Sequence
-
- from numpy.typing import NDArray
-
- from dataeval.utils.datasets._base import BaseODDataset, DataLocation
- from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
- if TYPE_CHECKING:
-     from dataeval.typing import Transform
-
-
- class MILCO(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
-     """
-     A side-scan sonar dataset focused on mine-like object detection.
-
-     The dataset comes from the paper
-     `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
-     by N.P. Santos et. al. (2024).
-
-     The full dataset contains 1170 side-scan sonar images collected using a 900-1800 kHz Marine Sonic
-     dual frequency side-scan sonar of a Teledyne Marine Gavia Autonomous Underwater Vehicle.
-     All the images were carefully analyzed and annotated, including the image coordinates of the
-     Bounding Box (BB) of the detected objects divided into NOn-Mine-like BOttom Objects (NOMBO)
-     and MIne-Like COntacts (MILCO) classes.
-
-     This dataset is consists of 345 images from 2010, 120 images from 2015, 93 images from 2017, 564 images from 2018,
-     and 48 images from 2021). In these 1170 images, there are 432 MILCO objects, and 235 NOMBO objects.
-     The class “0” corresponds to a MILCO object and the class “1” corresponds to a NOMBO object.
-     The raw BB coordinates provided in the downloaded text files are (x, y, w, h),
-     given as percentages of the image (x_BB = x/img_width, y_BB = y/img_height, etc.).
-     The images come in 2 sizes, 416 x 416 or 1024 x 1024.
-
-     Parameters
-     ----------
-     root : str or pathlib.Path
-         Root directory where the data should be downloaded to or the ``milco`` folder of the already downloaded data.
-     image_set: "train", "operational", or "base", default "train"
-         If "train", then the images from 2015, 2017 and 2021 are selected,
-         resulting in 315 MILCO objects and 177 NOMBO objects.
-         If "operational", then the images from 2010 and 2018 are selected,
-         resulting in 117 MILCO objects and 58 NOMBO objects.
-         If "base", then the full dataset is selected.
-     transforms : Transform, Sequence[Transform] or None, default None
-         Transform(s) to apply to the data.
-     download : bool, default False
-         If True, downloads the dataset from the internet and puts it in root directory.
-         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-     verbose : bool, default False
-         If True, outputs print statements.
-
-     Attributes
-     ----------
-     path : pathlib.Path
-         Location of the folder containing the data.
-     image_set : "train", "operational" or "base"
-         The selected image set from the dataset.
-     index2label : dict[int, str]
-         Dictionary which translates from class integers to the associated class strings.
-     label2index : dict[str, int]
-         Dictionary which translates from class strings to the associated class integers.
-     metadata : DatasetMetadata
-         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-     transforms : Sequence[Transform]
-         The transforms to be applied to the data.
-     size : int
-         The size of the dataset.
-
-     Note
-     ----
-     Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_
-     """
-
-     _resources = [
-         DataLocation(
-             url="https://figshare.com/ndownloader/files/43169002",
-             filename="2015.zip",
-             md5=True,
-             checksum="93dfbb4fb7987734152c372496b4884c",
-         ),
-         DataLocation(
-             url="https://figshare.com/ndownloader/files/43169005",
-             filename="2017.zip",
-             md5=True,
-             checksum="9c2de230a2bbf654921416bea6fc0f42",
-         ),
-         DataLocation(
-             url="https://figshare.com/ndownloader/files/43168999",
-             filename="2021.zip",
-             md5=True,
-             checksum="b84749b21fa95a4a4c7de3741db78bc7",
-         ),
-         DataLocation(
-             url="https://figshare.com/ndownloader/files/43169008",
-             filename="2010.zip",
-             md5=True,
-             checksum="43347a0cc383c0d3dbe0d24ae56f328d",
-         ),
-         DataLocation(
-             url="https://figshare.com/ndownloader/files/43169011",
-             filename="2018.zip",
-             md5=True,
-             checksum="25d091044a10c78674fedad655023e3b",
-         ),
-     ]
-
-     index2label: dict[int, str] = {
-         0: "MILCO",
-         1: "NOMBO",
-     }
-
-     def __init__(
-         self,
-         root: str | Path,
-         image_set: Literal["train", "operational", "base"] = "train",
-         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-         download: bool = False,
-         verbose: bool = False,
-     ) -> None:
-         super().__init__(
-             root,
-             image_set,
-             transforms,
-             download,
-             verbose,
-         )
-         self._bboxes_per_size = True
-
-     def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
-         filepaths: list[str] = []
-         targets: list[str] = []
-         datum_metadata: dict[str, list[Any]] = {}
-         metadata_list: list[dict[str, Any]] = []
-         image_sets: dict[str, list[int]] = {
-             "base": list(range(len(self._resources))),
-             "train": list(range(3)),
-             "operational": list(range(3, len(self._resources))),
-         }
-
-         # Load the data
-         resource_indices = image_sets[self.image_set]
-         for idx in resource_indices:
-             self._resource = self._resources[idx]
-             filepath, target, metadata = super()._load_data()
-             filepaths.extend(filepath)
-             targets.extend(target)
-             metadata_list.append(metadata)
-
-         # Adjust datum metadata to correct format
-         for data_dict in metadata_list:
-             for key, val in data_dict.items():
-                 if key not in datum_metadata:
-                     datum_metadata[str(key)] = []
-                 datum_metadata[str(key)].extend(val)
-
-         return filepaths, targets, datum_metadata
-
-     def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
-         file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
-         data_folder = sorted((self.path / self._resource.filename[:-4]).glob("*.jpg"))
-         if not data_folder:
-             raise FileNotFoundError
-
-         for entry in data_folder:
-             # Remove file extension and split by "_"
-             parts = entry.stem.split("_")
-             file_data["image_id"].append(parts[0])
-             file_data["year"].append(parts[1])
-             file_data["data_path"].append(str(entry))
-             file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
-         data = file_data.pop("data_path")
-         annotations = file_data.pop("label_path")
-
-         return data, annotations, file_data
-
-     def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
-         """Function for extracting the info out of the text files"""
-         labels: list[int] = []
-         boxes: list[list[float]] = []
-         with open(annotation) as f:
-             for line in f.readlines():
-                 out = line.strip().split()
-                 labels.append(int(out[0]))
-
-                 xcenter, ycenter, width, height = [float(out[1]), float(out[2]), float(out[3]), float(out[4])]
-
-                 x0 = xcenter - width / 2
-                 x1 = x0 + width
-                 y0 = ycenter - height / 2
-                 y1 = y0 + height
-                 boxes.append([x0, y0, x1, y1])
-
-         return boxes, labels, {}
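
The removed `_read_annotations` method above converts the YOLO-style normalized (x_center, y_center, width, height) boxes in the MILCO label files to corner form. A short standalone sketch of the same conversion (function name illustrative, not part of the DataEval API):

def yolo_to_corners(xc: float, yc: float, w: float, h: float) -> list[float]:
    # Convert a normalized center-format box to (x0, y0, x1, y1), still normalized to image size.
    x0 = xc - w / 2
    y0 = yc - h / 2
    return [x0, y0, x0 + w, y0 + h]

# Example: a box centered mid-image covering a quarter of each dimension.
print(yolo_to_corners(0.5, 0.5, 0.25, 0.25))  # [0.375, 0.375, 0.625, 0.625]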
dataeval/utils/datasets/_mixin.py
@@ -1,54 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from typing import Any, Generic, TypeVar
-
- import numpy as np
- import torch
- from numpy.typing import NDArray
- from PIL import Image
-
- _TArray = TypeVar("_TArray")
-
-
- class BaseDatasetMixin(Generic[_TArray]):
-     index2label: dict[int, str]
-
-     def _as_array(self, raw: list[Any]) -> _TArray: ...
-     def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
-     def _read_file(self, path: str) -> _TArray: ...
-
-
- class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
-     def _as_array(self, raw: list[Any]) -> NDArray[Any]:
-         return np.asarray(raw)
-
-     def _one_hot_encode(self, value: int | list[int]) -> NDArray[Any]:
-         if isinstance(value, int):
-             encoded = np.zeros(len(self.index2label))
-             encoded[value] = 1
-         else:
-             encoded = np.zeros((len(value), len(self.index2label)))
-             encoded[np.arange(len(value)), value] = 1
-         return encoded
-
-     def _read_file(self, path: str) -> NDArray[Any]:
-         return np.array(Image.open(path)).transpose(2, 0, 1)
-
-
- class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
-     def _as_array(self, raw: list[Any]) -> torch.Tensor:
-         return torch.as_tensor(raw)
-
-     def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
-         if isinstance(value, int):
-             encoded = torch.zeros(len(self.index2label))
-             encoded[value] = 1
-         else:
-             encoded = torch.zeros((len(value), len(self.index2label)))
-             encoded[torch.arange(len(value)), value] = 1
-         return encoded
-
-     def _read_file(self, path: str) -> torch.Tensor:
-         return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
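
The removed mixins above one-hot encode labels against the size of `index2label` using integer indexing. A minimal NumPy sketch of that indexing trick, assuming integer class labels in range (name illustrative, not part of the DataEval API):

import numpy as np

def one_hot(values: list[int], num_classes: int) -> np.ndarray:
    # Row i gets a 1 in the column named by values[i]; everything else stays 0.
    encoded = np.zeros((len(values), num_classes))
    encoded[np.arange(len(values)), values] = 1
    return encoded

print(one_hot([0, 2, 1], num_classes=3))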
dataeval/utils/datasets/_mnist.py
@@ -1,202 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
-
- import numpy as np
- from numpy.typing import NDArray
-
- from dataeval.utils.datasets._base import BaseICDataset, DataLocation
- from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
- if TYPE_CHECKING:
-     from dataeval.typing import Transform
-
- MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
- TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
- CorruptionStringMap = Literal[
-     "identity",
-     "shot_noise",
-     "impulse_noise",
-     "glass_blur",
-     "motion_blur",
-     "shear",
-     "scale",
-     "rotate",
-     "brightness",
-     "translate",
-     "stripe",
-     "fog",
-     "spatter",
-     "dotted_line",
-     "zigzag",
-     "canny_edges",
- ]
-
-
- class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
-     """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.
-
-     There are 15 different styles of corruptions. This class downloads differently depending on if you
-     need just the original dataset or if you need corruptions. If you need both a corrupt version and the
-     original version then choose `corruption="identity"` as this downloads all of the corrupt datasets and
-     provides the original as `identity`. If you just need the original, then using `corruption=None` will
-     download only the original dataset to save time and space.
-
-     Parameters
-     ----------
-     root : str or pathlib.Path
-         Root directory where the data should be downloaded to or the ``minst`` folder of the already downloaded data.
-     image_set : "train", "test" or "base", default "train"
-         If "base", returns all of the data to allow the user to create their own splits.
-     corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
-         "shear", "scale", "rotate", "brightness", "translate", "stripe", "fog", "spatter", \
-         "dotted_line", "zigzag", "canny_edges" or None, default None
-         Corruption to apply to the data.
-     transforms : Transform, Sequence[Transform] or None, default None
-         Transform(s) to apply to the data.
-     download : bool, default False
-         If True, downloads the dataset from the internet and puts it in root directory.
-         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-     verbose : bool, default False
-         If True, outputs print statements.
-
-     Attributes
-     ----------
-     path : pathlib.Path
-         Location of the folder containing the data.
-     image_set : "train", "test" or "base"
-         The selected image set from the dataset.
-     index2label : dict[int, str]
-         Dictionary which translates from class integers to the associated class strings.
-     label2index : dict[str, int]
-         Dictionary which translates from class strings to the associated class integers.
-     metadata : DatasetMetadata
-         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-     corruption : str or None
-         Corruption applied to the data.
-     transforms : Sequence[Transform]
-         The transforms to be applied to the data.
-     size : int
-         The size of the dataset.
-
-     Note
-     ----
-     Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_ for corruption dataset
-     """
-
-     _resources = [
-         DataLocation(
-             url="https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
-             filename="mnist.npz",
-             md5=False,
-             checksum="731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1",
-         ),
-         DataLocation(
-             url="https://zenodo.org/record/3239543/files/mnist_c.zip",
-             filename="mnist_c.zip",
-             md5=True,
-             checksum="4b34b33045869ee6d424616cd3a65da3",
-         ),
-     ]
-
-     index2label: dict[int, str] = {
-         0: "zero",
-         1: "one",
-         2: "two",
-         3: "three",
-         4: "four",
-         5: "five",
-         6: "six",
-         7: "seven",
-         8: "eight",
-         9: "nine",
-     }
-
-     def __init__(
-         self,
-         root: str | Path,
-         image_set: Literal["train", "test", "base"] = "train",
-         corruption: CorruptionStringMap | None = None,
-         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-         download: bool = False,
-         verbose: bool = False,
-     ) -> None:
-         self.corruption = corruption
-         if self.corruption == "identity" and verbose:
-             print("Identity is not a corrupted dataset but the original MNIST dataset.")
-         self._resource_index = 0 if self.corruption is None else 1
-
-         super().__init__(
-             root,
-             image_set,
-             transforms,
-             download,
-             verbose,
-         )
-
-     def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
-         """Function to load in the file paths for the data and labels from the correct data format"""
-         if self.corruption is None:
-             try:
-                 file_path = self.path / self._resource.filename
-                 self._loaded_data, labels = self._grab_data(file_path)
-             except FileNotFoundError:
-                 self._loaded_data, labels = self._load_corruption()
-         else:
-             self._loaded_data, labels = self._load_corruption()
-
-         index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
-         return index_strings, labels.tolist(), {}
-
-     def _load_corruption(self) -> tuple[NDArray[Any], NDArray[np.uintp]]:
-         """Function to load in the file paths for the data and labels for the different corrupt data formats"""
-         corruption = self.corruption if self.corruption is not None else "identity"
-         base_path = self.path / "mnist_c" / corruption
-         if self.image_set == "base":
-             raw_data = []
-             raw_labels = []
-             for group in ["train", "test"]:
-                 file_path = base_path / f"{group}_images.npy"
-                 raw_data.append(self._grab_corruption_data(file_path))
-
-                 label_path = base_path / f"{group}_labels.npy"
-                 raw_labels.append(self._grab_corruption_data(label_path))
-
-             data = np.concatenate(raw_data, axis=0).transpose(0, 3, 1, 2)
-             labels = np.concatenate(raw_labels).astype(np.uintp)
-         else:
-             file_path = base_path / f"{self.image_set}_images.npy"
-             data = self._grab_corruption_data(file_path)
-             data = data.astype(np.float64).transpose(0, 3, 1, 2)
-
-             label_path = base_path / f"{self.image_set}_labels.npy"
-             labels = self._grab_corruption_data(label_path)
-             labels = labels.astype(np.uintp)
-
-         return data, labels
-
-     def _grab_data(self, path: Path) -> tuple[NDArray[Any], NDArray[np.uintp]]:
-         """Function to load in the data numpy array"""
-         with np.load(path, allow_pickle=True) as data_array:
-             if self.image_set == "base":
-                 data = np.concatenate([data_array["x_train"], data_array["x_test"]], axis=0)
-                 labels = np.concatenate([data_array["y_train"], data_array["y_test"]], axis=0).astype(np.uintp)
-             else:
-                 data, labels = data_array[f"x_{self.image_set}"], data_array[f"y_{self.image_set}"].astype(np.uintp)
-         data = np.expand_dims(data, axis=1)
-         return data, labels
-
-     def _grab_corruption_data(self, path: Path) -> NDArray[Any]:
-         """Function to load in the data numpy array for the previously chosen corrupt format"""
-         return np.load(path, allow_pickle=False)
-
-     def _read_file(self, path: str) -> NDArray[Any]:
-         """
-         Function to grab the correct image from the loaded data.
-         Overwrite of the base `_read_file` because data is an all or nothing load.
-         """
-         index = int(path)
-         return self._loaded_data[index]
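
The removed loader's `_grab_data` builds its "base" split by concatenating the train and test arrays from the Keras-hosted mnist.npz archive. A hedged standalone sketch of that load, assuming the archive has already been downloaded to a local path (the function name is illustrative, not part of the DataEval API):

import numpy as np

def load_mnist_base(npz_path: str) -> tuple[np.ndarray, np.ndarray]:
    # mnist.npz stores x_train/y_train/x_test/y_test; "base" is simply both splits stacked.
    with np.load(npz_path, allow_pickle=True) as arrays:
        data = np.concatenate([arrays["x_train"], arrays["x_test"]], axis=0)
        labels = np.concatenate([arrays["y_train"], arrays["y_test"]], axis=0)
    return np.expand_dims(data, axis=1), labels  # add a channel axis: (N, 1, 28, 28)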