dataeval 0.76.0__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +52 -43
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -63
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +25 -25
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +198 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/METADATA +44 -15
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.0.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/utils/data/datasets/_cifar10.py
@@ -0,0 +1,134 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from pathlib import Path
+ from typing import Any, Literal, Sequence, TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from PIL import Image
+
+ from dataeval.utils.data._types import Transform
+ from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
+ from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
+
+ CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
+ TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
+
+
+ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+     """
+     `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
+
+     Parameters
+     ----------
+     root : str or pathlib.Path
+         Root directory of dataset where the ``cifar10`` folder exists.
+     download : bool, default False
+         If True, downloads the dataset from the internet and puts it in root directory.
+         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+     image_set : "train", "test" or "base", default "train"
+         If "base", returns all of the data to allow the user to create their own splits.
+     transforms : Transform | Sequence[Transform] | None, default None
+         Transform(s) to apply to the data.
+     verbose : bool, default False
+         If True, outputs print statements.
+
+     Attributes
+     ----------
+     index2label : dict
+         Dictionary which translates from class integers to the associated class strings.
+     label2index : dict
+         Dictionary which translates from class strings to the associated class integers.
+     path : Path
+         Location of the folder containing the data.
+     metadata : dict
+         Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+     """
+
+     _resources = [
+         DataLocation(
+             url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
+             filename="cifar-10-binary.tar.gz",
+             md5=True,
+             checksum="c32a1d4ab5d03f1284b67883e8d87530",
+         ),
+     ]
+
+     index2label: dict[int, str] = {
+         0: "airplane",
+         1: "automobile",
+         2: "bird",
+         3: "cat",
+         4: "deer",
+         5: "dog",
+         6: "frog",
+         7: "horse",
+         8: "ship",
+         9: "truck",
+     }
+
+     def __init__(
+         self,
+         root: str | Path,
+         download: bool = False,
+         image_set: Literal["train", "test", "base"] = "train",
+         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+         verbose: bool = False,
+     ) -> None:
+         super().__init__(
+             root,
+             download,
+             image_set,
+             transforms,
+             verbose,
+         )
+
+     def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
+         """Function to load in the file paths for the data and labels and retrieve metadata"""
+         file_meta = {"batch_num": []}
+         raw_data = []
+         labels = []
+         data_folder = self.path / "cifar-10-batches-bin"
+         save_folder = self.path / "images"
+         image_sets: dict[str, list[str]] = {"base": [], "train": [], "test": []}
+
+         # Process each batch file, skipping .meta and .html files
+         for entry in data_folder.iterdir():
+             if entry.suffix == ".bin":
+                 batch_data, batch_labels = self._unpack_batch_files(entry)
+                 raw_data.append(batch_data)
+                 group = "train" if "test" not in entry.stem else "test"
+                 name_split = entry.stem.split("_")
+                 batch_num = int(name_split[-1]) - 1 if group == "train" else 5
+                 file_names = [
+                     str(save_folder / f"{i + 10000 * batch_num:05d}_{self.index2label[label]}.png")
+                     for i, label in enumerate(batch_labels)
+                 ]
+                 image_sets["base"].extend(file_names)
+                 image_sets[group].extend(file_names)
+
+                 if self.image_set in (group, "base"):
+                     labels.extend(batch_labels)
+                     file_meta["batch_num"].extend([batch_num] * len(labels))
+
+         # Stack and reshape images
+         images = np.vstack(raw_data).reshape(-1, 3, 32, 32)
+
+         # Save the raw data into images if not already there
+         if not save_folder.exists():
+             save_folder.mkdir(exist_ok=True)
+             for i, file in enumerate(image_sets["base"]):
+                 Image.fromarray(images[i].transpose(1, 2, 0).astype(np.uint8)).save(file)
+
+         return image_sets[self.image_set], labels, file_meta
+
+     def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[Any], list[int]]:
+         # Read the raw binary batch file: each 3073-byte record is one label byte followed by 3072 pixel bytes
+         with file_path.open("rb") as f:
+             buffer = np.frombuffer(f.read(), "B")
+             labels = buffer[::3073]
+             pixels = np.delete(buffer, np.arange(0, buffer.size, 3073))
+             images = pixels.reshape(-1, 3072)
+         return images, labels.tolist()
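For orientation, a minimal usage sketch of the class added above. It assumes CIFAR10 is importable from the new dataeval.utils.data.datasets subpackage (whose __init__.py is also added in this release); the root path is illustrative.

from dataeval.utils.data.datasets import CIFAR10  # assumed export from the new datasets subpackage

# Download the binary archive into ./data (skipped if already present) and load the training split
train_ds = CIFAR10(root="./data", download=True, image_set="train", verbose=True)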
dataeval/utils/data/datasets/_fileio.py
@@ -0,0 +1,168 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import hashlib
+ import shutil
+ import tarfile
+ import zipfile
+ from pathlib import Path
+
+ import requests
+ from tqdm import tqdm
+
+ ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
+ COMPRESS_ENDINGS = [".gz", ".bz2"]
+
+
+ def _validate_file(fpath, file_md5, md5: bool = False, chunk_size=65535) -> bool:
+     hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
+     with open(fpath, "rb") as fpath_file:
+         while chunk := fpath_file.read(chunk_size):
+             hasher.update(chunk)
+     return hasher.hexdigest() == file_md5
+
+
+ def _download_dataset(url: str, file_path: Path, timeout: int = 60) -> None:
+     """Download a single resource from its URL to `file_path`."""
+     error_msg = "URL fetch failure on {}: {} -- {}"
+     try:
+         response = requests.get(url, stream=True, timeout=timeout)
+         response.raise_for_status()
+     except requests.exceptions.HTTPError as e:
+         raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
+     except requests.exceptions.RequestException as e:
+         raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
+
+     total_size = int(response.headers.get("content-length", 0))
+     block_size = 8192  # 8 KB
+     progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
+
+     with open(file_path, "wb") as f:
+         for chunk in response.iter_content(block_size):
+             f.write(chunk)
+             progress_bar.update(len(chunk))
+     progress_bar.close()
+
+
+ def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
+     """Extracts the zip file to the given directory."""
+     try:
+         with zipfile.ZipFile(file_path, "r") as zip_ref:
+             zip_ref.extractall(extract_to)
+         file_path.unlink()
+     except zipfile.BadZipFile:
+         raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
+
+
+ def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
+     """Extracts a tar file (or compressed tar) to the specified directory."""
+     try:
+         with tarfile.open(file_path, "r:*") as tar_ref:
+             tar_ref.extractall(extract_to)
+         file_path.unlink()
+     except tarfile.TarError:
+         raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
+
+
+ def _flatten_extraction(base_directory: Path, verbose: bool = False) -> None:
+     """
+     If the extracted folder contains only directories (and no files),
+     move all its subfolders to the dataset_dir and remove the now-empty folder.
+     """
+     for child in base_directory.iterdir():
+         if child.is_dir():
+             inner_list = list(child.iterdir())
+             if all(subchild.is_dir() for subchild in inner_list):
+                 for subchild in child.iterdir():
+                     if verbose:
+                         print(f"Moving {subchild.stem} to {base_directory}")
+                     shutil.move(subchild, base_directory)
+
+                 if verbose:
+                     print(f"Removing empty folder {child.stem}")
+                 child.rmdir()
+
+                 # Checking for additional placeholder folders
+                 if len(inner_list) == 1:
+                     _flatten_extraction(base_directory, verbose)
+
+
+ def _archive_extraction(file_ext, file_path, directory, compression: bool = False, verbose: bool = False):
+     """
+     Single function to extract and then flatten if necessary.
+     Recursively extracts nested zip files as well.
+     Extracts and flattens all folders to the base directory.
+     """
+     if file_ext != ".zip" or compression:
+         _extract_tar_archive(file_path, directory)
+     else:
+         _extract_zip_archive(file_path, directory)
+         # Look for nested zip files in the extraction directory and extract them recursively.
+         # Does NOT extract in place - extracts everything to directory
+         for child in directory.iterdir():
+             if child.suffix == ".zip":
+                 if verbose:
+                     print(f"Extracting nested zip: {child} to {directory}")
+                 _extract_zip_archive(child, directory)
+
+     # Determine if there are nested folders and remove them
+     # Helps ensure that data is at most one folder below main directory
+     _flatten_extraction(directory, verbose)
+
+
+ def _ensure_exists(
+     url: str,
+     filename: str,
+     md5: bool,
+     checksum: str,
+     directory: Path,
+     root: Path,
+     download: bool = True,
+     verbose: bool = False,
+ ) -> None:
+     """
+     For each resource, download it if it doesn't exist in the dataset_dir.
+     If the resource is a zip file, extract it (including recursively extracting nested zips).
+     """
+     file_path = directory / str(filename)
+     alternate_path = root / str(filename)
+     _, file_ext = file_path.stem, file_path.suffix
+     compression = False
+     if file_ext in COMPRESS_ENDINGS:
+         file_ext = file_path.suffixes[0]
+         compression = True
+
+     check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path
+
+     # Download file if it doesn't exist.
+     if not check_path.exists() and download:
+         if verbose:
+             print(f"Downloading {filename} from {url}")
+         _download_dataset(url, check_path)
+
+         if not _validate_file(check_path, checksum, md5):
+             raise Exception("File checksum mismatch. Remove current file and retry download.")
+
+         # If the file is a zip, tar or tgz extract it into the designated folder.
+         if file_ext in ARCHIVE_ENDINGS:
+             if verbose:
+                 print(f"Extracting {filename}...")
+             _archive_extraction(file_ext, check_path, directory, compression, verbose)
+
+     elif not check_path.exists() and not download:
+         raise FileNotFoundError(
+             "Data could not be loaded with the provided root directory, ",
+             f"the file path to the file {filename} does not exist, ",
+             "and the download parameter is set to False.",
+         )
+     else:
+         if not _validate_file(check_path, checksum, md5):
+             raise Exception("File checksum mismatch. Remove current file and retry download.")
+         if verbose:
+             print(f"{filename} already exists, skipping download.")
+
+         if file_ext in ARCHIVE_ENDINGS:
+             if verbose:
+                 print(f"Extracting {filename}...")
+             _archive_extraction(file_ext, check_path, directory, compression, verbose)
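To make the flow above concrete, a rough sketch of how a dataset's DataLocation entry would drive these helpers. The url, filename, md5 and checksum values are taken from the CIFAR10 resource earlier in this diff; the directory layout under ./data is an assumption for illustration, and these are private functions normally invoked by the dataset base class rather than user code.

from pathlib import Path

# Download if missing, verify the MD5 checksum, then extract and flatten the archive (illustrative paths)
_ensure_exists(
    url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
    filename="cifar-10-binary.tar.gz",
    md5=True,
    checksum="c32a1d4ab5d03f1284b67883e8d87530",
    directory=Path("./data/cifar10"),
    root=Path("./data"),
    download=True,
    verbose=True,
)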
dataeval/utils/data/datasets/_milco.py
@@ -0,0 +1,153 @@
+ from __future__ import annotations
+
+ from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
+
+ __all__ = []
+
+ from pathlib import Path
+ from typing import Any, Sequence
+
+ from numpy.typing import NDArray
+
+ from dataeval.utils.data._types import Transform
+ from dataeval.utils.data.datasets._base import BaseODDataset, DataLocation
+
+
+ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+     """
+     A side-scan sonar dataset focused on mine (object) detection.
+
+     The dataset comes from the paper
+     `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
+     by N.P. Santos et al. (2024).
+
+     This class only accesses a portion of the above dataset due to size constraints.
+     The full dataset contains 1170 side-scan sonar images collected using a 900-1800 kHz Marine Sonic
+     dual frequency side-scan sonar of a Teledyne Marine Gavia Autonomous Underwater Vehicle.
+     All the images were carefully analyzed and annotated, including the image coordinates of the
+     Bounding Box (BB) of the detected objects divided into NOn-Mine-like BOttom Objects (NOMBO)
+     and MIne-Like COntacts (MILCO) classes.
+
+     This dataset consists of 261 images (120 images from 2015, 93 images from 2017, and 48 images from 2021).
+     In these 261 images, there are 315 MILCO objects and 175 NOMBO objects.
+     The class “0” corresponds to a MILCO object and the class “1” corresponds to a NOMBO object.
+     The raw BB coordinates provided in the downloaded text files are (x, y, w, h),
+     given as percentages of the image (x_BB = x/img_width, y_BB = y/img_height, etc.).
+     The images come in two sizes, 416 x 416 or 1024 x 1024.
+
+     Parameters
+     ----------
+     root : str or pathlib.Path
+         Root directory of dataset where the ``milco`` folder exists.
+     download : bool, default False
+         If True, downloads the dataset from the internet and puts it in root directory.
+         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+     transforms : Transform | Sequence[Transform] | None, default None
+         Transform(s) to apply to the data.
+     verbose : bool, default False
+         If True, outputs print statements.
+
+     Attributes
+     ----------
+     index2label : dict
+         Dictionary which translates from class integers to the associated class strings.
+     label2index : dict
+         Dictionary which translates from class strings to the associated class integers.
+     path : Path
+         Location of the folder containing the data.
+     metadata : dict
+         Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+     """
+
+     _resources = [
+         DataLocation(
+             url="https://figshare.com/ndownloader/files/43169002",
+             filename="2015.zip",
+             md5=True,
+             checksum="93dfbb4fb7987734152c372496b4884c",
+         ),
+         DataLocation(
+             url="https://figshare.com/ndownloader/files/43169005",
+             filename="2017.zip",
+             md5=True,
+             checksum="9c2de230a2bbf654921416bea6fc0f42",
+         ),
+         DataLocation(
+             url="https://figshare.com/ndownloader/files/43168999",
+             filename="2021.zip",
+             md5=True,
+             checksum="b84749b21fa95a4a4c7de3741db78bc7",
+         ),
+     ]
+
+     index2label: dict[int, str] = {
+         0: "MILCO",
+         1: "NOMBO",
+     }
+
+     def __init__(
+         self,
+         root: str | Path,
+         download: bool = False,
+         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+         verbose: bool = False,
+     ) -> None:
+         super().__init__(
+             root,
+             download,
+             "base",
+             transforms,
+             verbose,
+         )
+
+     def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
+         filepaths: list[str] = []
+         targets: list[str] = []
+         datum_metadata: dict[str, list[Any]] = {}
+         metadata_list: list[dict[str, Any]] = []
+
+         # Load the data
+         for resource in self._resources:
+             self._resource = resource
+             filepath, target, metadata = super()._load_data()
+             filepaths.extend(filepath)
+             targets.extend(target)
+             metadata_list.append(metadata)
+
+         # Adjust datum metadata to correct format
+         for data_dict in metadata_list:
+             for key, val in data_dict.items():
+                 if key not in datum_metadata:
+                     datum_metadata[str(key)] = []
+                 datum_metadata[str(key)].extend(val)
+
+         return filepaths, targets, datum_metadata
+
+     def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
+         file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
+         data_folder = self.path / self._resource.filename[:-4]
+         for entry in data_folder.iterdir():
+             if entry.is_file() and entry.suffix == ".jpg":
+                 # Remove file extension and split by "_"
+                 parts = entry.stem.split("_")
+                 file_data["image_id"].append(parts[0])
+                 file_data["year"].append(parts[1])
+                 file_data["data_path"].append(str(entry))
+                 file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
+         data = file_data.pop("data_path")
+         annotations = file_data.pop("label_path")
+
+         return data, annotations, file_data
+
+     def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
+         """Function for extracting the info out of the text files"""
+         labels: list[int] = []
+         boxes: list[list[float]] = []
+         with open(annotation) as f:
+             for line in f.readlines():
+                 out = line.strip().split(" ")
+                 labels.append(int(out[0]))
+                 boxes.append([float(out[1]), float(out[2]), float(out[3]), float(out[4])])
+
+         return boxes, labels, {}
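A minimal usage sketch for the class above, assuming MILCO is importable from the new dataeval.utils.data.datasets subpackage; the root path is illustrative.

from dataeval.utils.data.datasets import MILCO  # assumed export from the new datasets subpackage

# Fetch the 2015/2017/2021 archives into ./data (if not already present) and load the full dataset
sonar_ds = MILCO(root="./data", download=True, verbose=True)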
dataeval/utils/data/datasets/_mixin.py
@@ -0,0 +1,56 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from typing import Any, Generic, TypeVar
+
+ import numpy as np
+ import torch
+ from numpy.typing import NDArray
+ from PIL import Image
+
+ _TArray = TypeVar("_TArray")
+
+
+ class BaseDatasetMixin(Generic[_TArray]):
+     index2label: dict[int, str]
+
+     def _as_array(self, raw: list[Any]) -> _TArray: ...
+     def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
+     def _read_file(self, path: str) -> _TArray: ...
+
+
+ class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
+     def _as_array(self, raw: list[Any]) -> NDArray[Any]:
+         return np.asarray(raw)
+
+     def _one_hot_encode(self, value: int | list[int]) -> NDArray[Any]:
+         if isinstance(value, int):
+             encoded = np.zeros(len(self.index2label))
+             encoded[value] = 1
+         else:
+             encoded = np.zeros((len(value), len(self.index2label)))
+             encoded[np.arange(len(value)), value] = 1
+         return encoded
+
+     def _read_file(self, path: str) -> NDArray[Any]:
+         x = np.array(Image.open(path)).transpose(2, 0, 1)
+         return x
+
+
+ class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
+     def _as_array(self, raw: list[Any]) -> torch.Tensor:
+         return torch.as_tensor(raw)
+
+     def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
+         if isinstance(value, int):
+             encoded = torch.zeros(len(self.index2label))
+             encoded[value] = 1
+         else:
+             encoded = torch.zeros((len(value), len(self.index2label)))
+             encoded[torch.arange(len(value)), value] = 1
+         return encoded
+
+     def _read_file(self, path: str) -> torch.Tensor:
+         x = torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
+         return x
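To illustrate the mixin contract above, a tiny sketch of the NumPy variant attached to a toy host class; the class name and labels are made up for demonstration. The host only needs to provide index2label, and the mixin supplies array conversion, one-hot encoding, and image reading.

import numpy as np

class _ToyHost(BaseDatasetNumpyMixin):  # hypothetical host class supplying index2label
    index2label = {0: "cat", 1: "dog", 2: "bird"}

toy = _ToyHost()
print(toy._one_hot_encode(1))        # [0. 1. 0.]
print(toy._one_hot_encode([2, 0]))   # [[0. 0. 1.]
                                     #  [1. 0. 0.]]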
dataeval/utils/data/datasets/_mnist.py
@@ -0,0 +1,183 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from pathlib import Path
+ from typing import Any, Literal, Sequence, TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from dataeval.utils.data._types import Transform
+ from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
+ from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
+
+ MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+ TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
+ CorruptionStringMap = Literal[
+     "identity",
+     "shot_noise",
+     "impulse_noise",
+     "glass_blur",
+     "motion_blur",
+     "shear",
+     "scale",
+     "rotate",
+     "brightness",
+     "translate",
+     "stripe",
+     "fog",
+     "spatter",
+     "dotted_line",
+     "zigzag",
+     "canny_edges",
+ ]
+
+
+ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+     """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.
+
+     There are 15 different styles of corruptions. This class downloads differently depending on whether you
+     need just the original dataset or the corruptions. If you need both a corrupt version and the
+     original version then choose `corruption="identity"` as this downloads all of the corrupt datasets and
+     provides the original as `identity`. If you just need the original, then using `corruption=None` will
+     download only the original dataset to save time and space.
+
+     Parameters
+     ----------
+     root : str or pathlib.Path
+         Root directory of dataset where the ``mnist`` folder exists.
+     download : bool, default False
+         If True, downloads the dataset from the internet and puts it in root directory.
+         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+     image_set : "train", "test" or "base", default "train"
+         If "base", returns all of the data to allow the user to create their own splits.
+     corruption : one of the styles in ``CorruptionStringMap`` or None, default None
+         Corruption style to load; if None, only the original MNIST data is downloaded and used.
+     transforms : Transform | Sequence[Transform] | None, default None
+         Transform(s) to apply to the data.
+     verbose : bool, default False
+         If True, outputs print statements.
+
+     Attributes
+     ----------
+     index2label : dict
+         Dictionary which translates from class integers to the associated class strings.
+     label2index : dict
+         Dictionary which translates from class strings to the associated class integers.
+     path : Path
+         Location of the folder containing the data.
+     metadata : dict
+         Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
+     """
+
+     _resources = [
+         DataLocation(
+             url="https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
+             filename="mnist.npz",
+             md5=False,
+             checksum="731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1",
+         ),
+         DataLocation(
+             url="https://zenodo.org/record/3239543/files/mnist_c.zip",
+             filename="mnist_c.zip",
+             md5=True,
+             checksum="4b34b33045869ee6d424616cd3a65da3",
+         ),
+     ]
+
+     index2label: dict[int, str] = {
+         0: "zero",
+         1: "one",
+         2: "two",
+         3: "three",
+         4: "four",
+         5: "five",
+         6: "six",
+         7: "seven",
+         8: "eight",
+         9: "nine",
+     }
+
+     def __init__(
+         self,
+         root: str | Path,
+         download: bool = False,
+         image_set: Literal["train", "test", "base"] = "train",
+         corruption: CorruptionStringMap | None = None,
+         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+         verbose: bool = False,
+     ) -> None:
+         self.corruption = corruption
+         if self.corruption == "identity" and verbose:
+             print("Identity is not a corrupted dataset but the original MNIST dataset.")
+         self._resource_index = 0 if self.corruption is None else 1
+
+         super().__init__(
+             root,
+             download,
+             image_set,
+             transforms,
+             verbose,
+         )
+
+     def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
+         """Function to load in the file paths for the data and labels from the correct data format"""
+         if self.corruption is None:
+             try:
+                 file_path = self.path / self._resource.filename
+                 self._loaded_data, labels = self._grab_data(file_path)
+             except FileNotFoundError:
+                 self._loaded_data, labels = self._load_corruption()
+         else:
+             self._loaded_data, labels = self._load_corruption()
+
+         index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
+         return index_strings, labels.tolist(), {}
+
+     def _load_corruption(self) -> tuple[NDArray[Any], NDArray[np.uintp]]:
+         """Function to load in the file paths for the data and labels for the different corrupt data formats"""
+         corruption = self.corruption if self.corruption is not None else "identity"
+         base_path = self.path / corruption
+         if self.image_set == "base":
+             raw_data = []
+             raw_labels = []
+             for group in ["train", "test"]:
+                 file_path = base_path / f"{group}_images.npy"
+                 raw_data.append(self._grab_corruption_data(file_path))
+
+                 label_path = base_path / f"{group}_labels.npy"
+                 raw_labels.append(self._grab_corruption_data(label_path))
+
+             data = np.concatenate(raw_data, axis=0).transpose(0, 3, 1, 2)
+             labels = np.concatenate(raw_labels).astype(np.uintp)
+         else:
+             file_path = base_path / f"{self.image_set}_images.npy"
+             data = self._grab_corruption_data(file_path)
+             data = data.astype(np.float64).transpose(0, 3, 1, 2)
+
+             label_path = base_path / f"{self.image_set}_labels.npy"
+             labels = self._grab_corruption_data(label_path)
+             labels = labels.astype(np.uintp)
+
+         return data, labels
+
+     def _grab_data(self, path: Path) -> tuple[NDArray[Any], NDArray[np.uintp]]:
+         """Function to load in the data numpy array"""
+         with np.load(path, allow_pickle=True) as data_array:
+             if self.image_set == "base":
+                 data = np.concatenate([data_array["x_train"], data_array["x_test"]], axis=0)
+                 labels = np.concatenate([data_array["y_train"], data_array["y_test"]], axis=0).astype(np.uintp)
+             else:
+                 data, labels = data_array[f"x_{self.image_set}"], data_array[f"y_{self.image_set}"].astype(np.uintp)
+             data = np.expand_dims(data, axis=1)
+         return data, labels
+
+     def _grab_corruption_data(self, path: Path) -> NDArray[Any]:
+         """Function to load in the data numpy array for the previously chosen corrupt format"""
+         x = np.load(path, allow_pickle=False)
+         return x
+
+     def _read_file(self, path: str) -> NDArray[Any]:
+         """
+         Function to grab the correct image from the loaded data.
+         Overwrite of the base `_read_file` because data is an all or nothing load.
+         """
+         index = int(path)
+         return self._loaded_data[index]
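A minimal usage sketch for the class above, assuming MNIST is importable from the new dataeval.utils.data.datasets subpackage; paths are illustrative.

from dataeval.utils.data.datasets import MNIST  # assumed export from the new datasets subpackage

# Original MNIST only (smaller download, pulls mnist.npz)
clean_train = MNIST(root="./data", download=True, image_set="train")

# Corrupted variant: downloads the full mnist_c archive and loads the "fog" corruption's test split
foggy_test = MNIST(root="./data", download=True, image_set="test", corruption="fog")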