dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/_version.py +2 -2
  4. dataeval/config.py +4 -19
  5. dataeval/data/_embeddings.py +78 -35
  6. dataeval/data/_images.py +41 -8
  7. dataeval/data/_metadata.py +348 -66
  8. dataeval/data/_selection.py +22 -7
  9. dataeval/data/_split.py +3 -2
  10. dataeval/data/selections/_classbalance.py +4 -3
  11. dataeval/data/selections/_classfilter.py +9 -8
  12. dataeval/data/selections/_indices.py +4 -3
  13. dataeval/data/selections/_prioritize.py +249 -29
  14. dataeval/data/selections/_reverse.py +1 -1
  15. dataeval/data/selections/_shuffle.py +5 -4
  16. dataeval/detectors/drift/_base.py +2 -1
  17. dataeval/detectors/drift/_mmd.py +2 -1
  18. dataeval/detectors/drift/_nml/_base.py +1 -1
  19. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  20. dataeval/detectors/drift/_nml/_result.py +3 -2
  21. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  22. dataeval/detectors/drift/_uncertainty.py +2 -1
  23. dataeval/detectors/linters/duplicates.py +2 -1
  24. dataeval/detectors/linters/outliers.py +4 -3
  25. dataeval/detectors/ood/__init__.py +2 -1
  26. dataeval/detectors/ood/ae.py +1 -1
  27. dataeval/detectors/ood/base.py +39 -1
  28. dataeval/detectors/ood/knn.py +95 -0
  29. dataeval/detectors/ood/mixin.py +2 -1
  30. dataeval/metadata/_utils.py +1 -1
  31. dataeval/metrics/bias/_balance.py +29 -22
  32. dataeval/metrics/bias/_diversity.py +4 -4
  33. dataeval/metrics/bias/_parity.py +2 -2
  34. dataeval/metrics/stats/_base.py +3 -29
  35. dataeval/metrics/stats/_boxratiostats.py +2 -1
  36. dataeval/metrics/stats/_dimensionstats.py +2 -1
  37. dataeval/metrics/stats/_hashstats.py +21 -3
  38. dataeval/metrics/stats/_pixelstats.py +2 -1
  39. dataeval/metrics/stats/_visualstats.py +2 -1
  40. dataeval/outputs/_base.py +2 -3
  41. dataeval/outputs/_bias.py +2 -1
  42. dataeval/outputs/_estimators.py +1 -1
  43. dataeval/outputs/_linters.py +3 -3
  44. dataeval/outputs/_stats.py +3 -3
  45. dataeval/outputs/_utils.py +1 -1
  46. dataeval/outputs/_workflows.py +49 -31
  47. dataeval/typing.py +23 -9
  48. dataeval/utils/__init__.py +2 -2
  49. dataeval/utils/_array.py +3 -2
  50. dataeval/utils/_bin.py +9 -7
  51. dataeval/utils/_method.py +2 -3
  52. dataeval/utils/_multiprocessing.py +34 -0
  53. dataeval/utils/_plot.py +2 -1
  54. dataeval/utils/data/__init__.py +6 -5
  55. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  56. dataeval/utils/data/_validate.py +170 -0
  57. dataeval/utils/data/collate.py +2 -1
  58. dataeval/utils/torch/_internal.py +2 -1
  59. dataeval/utils/torch/trainer.py +1 -1
  60. dataeval/workflows/sufficiency.py +13 -9
  61. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
  62. dataeval-0.88.0.dist-info/RECORD +105 -0
  63. dataeval/utils/data/_dataset.py +0 -246
  64. dataeval/utils/datasets/__init__.py +0 -21
  65. dataeval/utils/datasets/_antiuav.py +0 -189
  66. dataeval/utils/datasets/_base.py +0 -266
  67. dataeval/utils/datasets/_cifar10.py +0 -201
  68. dataeval/utils/datasets/_fileio.py +0 -142
  69. dataeval/utils/datasets/_milco.py +0 -197
  70. dataeval/utils/datasets/_mixin.py +0 -54
  71. dataeval/utils/datasets/_mnist.py +0 -202
  72. dataeval/utils/datasets/_seadrone.py +0 -512
  73. dataeval/utils/datasets/_ships.py +0 -144
  74. dataeval/utils/datasets/_types.py +0 -48
  75. dataeval/utils/datasets/_voc.py +0 -583
  76. dataeval-0.86.9.dist-info/RECORD +0 -115
  77. {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
  78. /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
dataeval/utils/datasets/_antiuav.py
@@ -1,189 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Literal, Sequence
-
- from defusedxml.ElementTree import parse
- from numpy.typing import NDArray
-
- from dataeval.utils.datasets._base import BaseODDataset, DataLocation
- from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
- if TYPE_CHECKING:
-     from dataeval.typing import Transform
-
-
- class AntiUAVDetection(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
-     """
-     A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
-
-     The dataset comes from the paper
-     `Vision-based Anti-UAV Detection and Tracking <https://ieeexplore.ieee.org/document/9785379>`_
-     by Jie Zhao et. al. (2022).
-
-     The dataset is approximately 1.3 GB and can be found `here <https://github.com/wangdongdut/DUT-Anti-UAV>`_.
-     Images are collected against a variety of different backgrounds with a variety in the number and type of UAV.
-     Ground truth labels are provided for the train, validation and test set.
-     There are 35 different types of drones along with a variety in lighting conditions and weather conditions.
-
-     There are 10,000 images: 5200 images in the training set, 2200 images in the validation set,
-     and 2600 images in the test set.
-     The dataset only has a single UAV class with the focus being on identifying object location in the image.
-     Ground-truth bounding boxes are provided in (x0, y0, x1, y1) format.
-     The images come in a variety of sizes from 3744 x 5616 to 160 x 240.
-
-     Parameters
-     ----------
-     root : str or pathlib.Path
-         Root directory where the data should be downloaded to or
-         the ``antiuavdetection`` folder of the already downloaded data.
-     image_set: "train", "val", "test", or "base", default "train"
-         If "base", then the full dataset is selected (train, val and test).
-     transforms : Transform, Sequence[Transform] or None, default None
-         Transform(s) to apply to the data.
-     download : bool, default False
-         If True, downloads the dataset from the internet and puts it in root directory.
-         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-     verbose : bool, default False
-         If True, outputs print statements.
-
-     Attributes
-     ----------
-     path : pathlib.Path
-         Location of the folder containing the data.
-     image_set : "train", "val", "test", or "base"
-         The selected image set from the dataset.
-     index2label : dict[int, str]
-         Dictionary which translates from class integers to the associated class strings.
-     label2index : dict[str, int]
-         Dictionary which translates from class strings to the associated class integers.
-     metadata : DatasetMetadata
-         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-     transforms : Sequence[Transform]
-         The transforms to be applied to the data.
-     size : int
-         The size of the dataset.
-
-     Note
-     ----
-     Data License: `Apache 2.0 <https://www.apache.org/licenses/LICENSE-2.0.txt>`_
-     """
-
-     # Need to run the sha256 on the files and then store that
-     _resources = [
-         DataLocation(
-             url="https://drive.usercontent.google.com/download?id=1RVsSGPUKTdmoyoPTBTWwroyulLek1eTj&export=download&authuser=0&confirm=t&uuid=6bca4f94-a242-4bc2-9663-fb03cd94ef2c&at=APcmpox0--NroQ_3bqeTFaJxP7Pw%3A1746552902927",
-             filename="train.zip",
-             md5=False,
-             checksum="14f927290556df60e23cedfa80dffc10dc21e4a3b6843e150cfc49644376eece",
-         ),
-         DataLocation(
-             url="https://drive.usercontent.google.com/download?id=1333uEQfGuqTKslRkkeLSCxylh6AQ0X6n&export=download&authuser=0&confirm=t&uuid=c2ad2f01-aca8-4a85-96bb-b8ef6e40feea&at=APcmpozY-8bhk3nZSFaYbE8rq1Fi%3A1746551543297",
-             filename="val.zip",
-             md5=False,
-             checksum="238be0ceb3e7c5be6711ee3247e49df2750d52f91f54f5366c68bebac112ebf8",
-         ),
-         DataLocation(
-             url="https://drive.usercontent.google.com/download?id=1L1zeW1EMDLlXHClSDcCjl3rs_A6sVai0&export=download&authuser=0&confirm=t&uuid=5a1d7650-d8cd-4461-8354-7daf7292f06c&at=APcmpozLQC1CuP-n5_UX2JnP53Zo%3A1746551676177",
-             filename="test.zip",
-             md5=False,
-             checksum="a671989a01cff98c684aeb084e59b86f4152c50499d86152eb970a9fc7fb1cbe",
-         ),
-     ]
-
-     index2label: dict[int, str] = {
-         0: "unknown",
-         1: "UAV",
-     }
-
-     def __init__(
-         self,
-         root: str | Path,
-         image_set: Literal["train", "val", "test", "base"] = "train",
-         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-         download: bool = False,
-         verbose: bool = False,
-     ) -> None:
-         super().__init__(
-             root,
-             image_set,
-             transforms,
-             download,
-             verbose,
-         )
-
-     def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
-         filepaths: list[str] = []
-         targets: list[str] = []
-         datum_metadata: dict[str, list[Any]] = {}
-
-         # If base, load all resources
-         if self.image_set == "base":
-             metadata_list: list[dict[str, Any]] = []
-
-             for resource in self._resources:
-                 self._resource = resource
-                 resource_filepaths, resource_targets, resource_metadata = super()._load_data()
-                 filepaths.extend(resource_filepaths)
-                 targets.extend(resource_targets)
-                 metadata_list.append(resource_metadata)
-
-             # Combine metadata
-             for data_dict in metadata_list:
-                 for key, val in data_dict.items():
-                     str_key = str(key)  # Ensure key is string
-                     if str_key not in datum_metadata:
-                         datum_metadata[str_key] = []
-                     datum_metadata[str_key].extend(val)
-
-         else:
-             # Grab only the desired data
-             for resource in self._resources:
-                 if self.image_set in resource.filename:
-                     self._resource = resource
-                     resource_filepaths, resource_targets, resource_metadata = super()._load_data()
-                     filepaths.extend(resource_filepaths)
-                     targets.extend(resource_targets)
-                     datum_metadata.update(resource_metadata)
-
-         return filepaths, targets, datum_metadata
-
-     def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
-         resource_name = self._resource.filename[:-4]
-         base_dir = self.path / resource_name
-         data_folder = sorted((base_dir / "img").glob("*.jpg"))
-         if not data_folder:
-             raise FileNotFoundError
-
-         file_data = {"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]}
-         data = [str(entry) for entry in data_folder]
-         annotations = sorted(str(entry) for entry in (base_dir / "xml").glob("*.xml"))
-
-         return data, annotations, file_data
-
-     def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
-         """Function for extracting the info for the label and boxes"""
-         boxes: list[list[float]] = []
-         labels = []
-         root = parse(annotation).getroot()
-         if root is None:
-             raise ValueError(f"Unable to parse {annotation}")
-         additional_meta: dict[str, Any] = {
-             "image_width": int(root.findtext("size/width", default="-1")),
-             "image_height": int(root.findtext("size/height", default="-1")),
-             "image_depth": int(root.findtext("size/depth", default="-1")),
-         }
-         for obj in root.findall("object"):
-             labels.append(1 if obj.findtext("name", default="") == "UAV" else 0)
-             boxes.append(
-                 [
-                     float(obj.findtext("bndbox/xmin", default="0")),
-                     float(obj.findtext("bndbox/ymin", default="0")),
-                     float(obj.findtext("bndbox/xmax", default="0")),
-                     float(obj.findtext("bndbox/ymax", default="0")),
-                 ]
-             )
-
-         return boxes, labels, additional_meta
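For orientation, here is a minimal usage sketch of the class deleted above, written against the 0.86.9 API that this diff removes. The ``./data`` root is a placeholder, and the top-level import path assumes the old ``dataeval.utils.datasets`` package; in 0.88.0 the class no longer ships with the wheel.

from dataeval.utils.datasets import AntiUAVDetection  # removed in 0.88.0

# Placeholder root; download=True fetches train.zip from the DataLocation
# resources above and verifies the stored checksum before loading.
dataset = AntiUAVDetection(root="./data", image_set="train", download=True)
image, target, datum_metadata = dataset[0]
print(len(dataset))               # number of images in the selected split
print(target.boxes[:1])           # boxes in (x0, y0, x1, y1) format
print(datum_metadata["image_width"], datum_metadata["image_height"])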
dataeval/utils/datasets/_base.py
@@ -1,266 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from abc import abstractmethod
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
-
- import numpy as np
-
- from dataeval.utils.datasets._fileio import _ensure_exists
- from dataeval.utils.datasets._mixin import BaseDatasetMixin
- from dataeval.utils.datasets._types import (
-     AnnotatedDataset,
-     DatasetMetadata,
-     ImageClassificationDataset,
-     ObjectDetectionDataset,
-     ObjectDetectionTarget,
-     SegmentationDataset,
-     SegmentationTarget,
- )
-
- if TYPE_CHECKING:
-     from dataeval.typing import Array, Transform
-
-     _TArray = TypeVar("_TArray", bound=Array)
- else:
-     _TArray = TypeVar("_TArray")
-
- _TTarget = TypeVar("_TTarget")
- _TRawTarget = TypeVar("_TRawTarget", Sequence[int], Sequence[str], Sequence[tuple[list[int], list[list[float]]]])
- _TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
-
-
- class DataLocation(NamedTuple):
-     url: str
-     filename: str
-     md5: bool
-     checksum: str
-
-
- class BaseDataset(
-     AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation]
- ):
-     """
-     Base class for internet downloaded datasets.
-     """
-
-     # Each subclass should override the attributes below.
-     # Each resource tuple must contain:
-     # 'url': str, the URL to download from
-     # 'filename': str, the name of the file once downloaded
-     # 'md5': boolean, True if it's the checksum value is md5
-     # 'checksum': str, the associated checksum for the downloaded file
-     _resources: list[DataLocation]
-     _resource_index: int = 0
-     index2label: dict[int, str]
-
-     def __init__(
-         self,
-         root: str | Path,
-         image_set: Literal["train", "val", "test", "operational", "base"] = "train",
-         transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
-         download: bool = False,
-         verbose: bool = False,
-     ) -> None:
-         self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
-         transforms = transforms if transforms is not None else []
-         self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
-         self.image_set = image_set
-         self._verbose = verbose
-
-         # Internal Attributes
-         self._download = download
-         self._filepaths: list[str]
-         self._targets: _TRawTarget
-         self._datum_metadata: dict[str, list[Any]]
-         self._resource: DataLocation = self._resources[self._resource_index]
-         self._label2index = {v: k for k, v in self.index2label.items()}
-
-         self.metadata: DatasetMetadata = DatasetMetadata(
-             id=self._unique_id(),
-             index2label=self.index2label,
-             split=self.image_set,
-         )
-
-         # Load the data
-         self.path: Path = self._get_dataset_dir()
-         self._filepaths, self._targets, self._datum_metadata = self._load_data()
-         self.size: int = len(self._filepaths)
-
-     def __str__(self) -> str:
-         nt = "\n "
-         title = f"{self.__class__.__name__} Dataset"
-         sep = "-" * len(title)
-         attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
-         return f"{title}\n{sep}{nt}{nt.join(attrs)}"
-
-     @property
-     def label2index(self) -> dict[str, int]:
-         return self._label2index
-
-     def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, dict[str, Any]]]:
-         for i in range(len(self)):
-             yield self[i]
-
-     def _get_dataset_dir(self) -> Path:
-         # Create a designated folder for this dataset (named after the class)
-         if self._root.stem.lower() == self.__class__.__name__.lower():
-             dataset_dir: Path = self._root
-         else:
-             dataset_dir: Path = self._root / self.__class__.__name__.lower()
-         if not dataset_dir.exists():
-             dataset_dir.mkdir(parents=True, exist_ok=True)
-         return dataset_dir
-
-     def _unique_id(self) -> str:
-         return f"{self.__class__.__name__}_{self.image_set}"
-
-     def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
-         """
-         Function to determine if data can be accessed or if it needs to be downloaded and/or extracted.
-         """
-         if self._verbose:
-             print(f"Determining if {self._resource.filename} needs to be downloaded.")
-
-         try:
-             result = self._load_data_inner()
-             if self._verbose:
-                 print("No download needed, loaded data successfully.")
-         except FileNotFoundError:
-             _ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
-             result = self._load_data_inner()
-         return result
-
-     @abstractmethod
-     def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
-
-     def _transform(self, image: _TArray) -> _TArray:
-         """Function to transform the image prior to returning based on parameters passed in."""
-         for transform in self.transforms:
-             image = transform(image)
-         return image
-
-     def __len__(self) -> int:
-         return self.size
-
-
- class BaseICDataset(
-     BaseDataset[_TArray, _TArray, list[int], int],
-     BaseDatasetMixin[_TArray],
-     ImageClassificationDataset[_TArray],
- ):
-     """
-     Base class for image classification datasets.
-     """
-
-     def __getitem__(self, index: int) -> tuple[_TArray, _TArray, dict[str, Any]]:
-         """
-         Args
-         ----
-         index : int
-             Value of the desired data point
-
-         Returns
-         -------
-         tuple[TArray, TArray, dict[str, Any]]
-             Image, target, datum_metadata - where target is one-hot encoding of class.
-         """
-         # Get the associated label and score
-         label = self._targets[index]
-         score = self._one_hot_encode(label)
-         # Get the image
-         img = self._read_file(self._filepaths[index])
-         img = self._transform(img)
-
-         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
-
-         return img, score, img_metadata
-
-
- class BaseODDataset(
-     BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
-     BaseDatasetMixin[_TArray],
-     ObjectDetectionDataset[_TArray],
- ):
-     """
-     Base class for object detection datasets.
-     """
-
-     _bboxes_per_size: bool = False
-
-     def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
-         """
-         Args
-         ----
-         index : int
-             Value of the desired data point
-
-         Returns
-         -------
-         tuple[TArray, ObjectDetectionTarget[TArray], dict[str, Any]]
-             Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
-         """
-         # Grab the bounding boxes and labels from the annotations
-         annotation = cast(_TAnnotation, self._targets[index])
-         boxes, labels, additional_metadata = self._read_annotations(annotation)
-         # Get the image
-         img = self._read_file(self._filepaths[index])
-         img_size = img.shape
-         img = self._transform(img)
-         # Adjust labels if necessary
-         if self._bboxes_per_size and boxes:
-             boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
-         # Create the Object Detection Target
-         target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
-
-         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
-         img_metadata = img_metadata | additional_metadata
-
-         return img, target, img_metadata
-
-     @abstractmethod
-     def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
-
-
- class BaseSegDataset(
-     BaseDataset[_TArray, SegmentationTarget[_TArray], list[str], str],
-     BaseDatasetMixin[_TArray],
-     SegmentationDataset[_TArray],
- ):
-     """
-     Base class for segmentation datasets.
-     """
-
-     _masks: Sequence[str]
-
-     def __getitem__(self, index: int) -> tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]:
-         """
-         Args
-         ----
-         index : int
-             Value of the desired data point
-
-         Returns
-         -------
-         tuple[TArray, SegmentationTarget[TArray], dict[str, Any]]
-             Image, target, datum_metadata - target.mask returns the ground truth mask
-         """
-         # Grab the labels from the annotations
-         _, labels, additional_metadata = self._read_annotations(self._targets[index])
-         # Grab the ground truth masks
-         mask = self._read_file(self._masks[index])
-         # Get the image
-         img = self._read_file(self._filepaths[index])
-         img = self._transform(img)
-
-         target = SegmentationTarget(mask, self._as_array(labels), self._one_hot_encode(labels))
-
-         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
-         img_metadata = img_metadata | additional_metadata
-
-         return img, target, img_metadata
-
-     @abstractmethod
-     def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
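The base classes above define the whole extension surface: a subclass lists its ``DataLocation`` resources, supplies ``index2label``, and implements ``_load_data_inner`` plus ``_read_annotations``. The sketch below is a hypothetical object-detection subclass illustrating that contract under the 0.86.9 layout; the class name, URL, checksum, and folder names are placeholders, not part of the package.

from pathlib import Path
from typing import Any

from numpy.typing import NDArray

from dataeval.utils.datasets._base import BaseODDataset, DataLocation
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin


class ExampleDetectionDataset(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
    _resources = [
        DataLocation(
            url="https://example.com/example.zip",  # placeholder URL
            filename="example.zip",
            md5=False,                              # checksum below would be a sha256
            checksum="<sha256-of-example.zip>",
        ),
    ]
    index2label: dict[int, str] = {0: "background", 1: "object"}

    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
        # Return image paths, raw annotation handles, and per-datum metadata;
        # raising FileNotFoundError lets BaseDataset._load_data trigger the download.
        images = sorted(str(p) for p in (self.path / "img").glob("*.jpg"))
        if not images:
            raise FileNotFoundError
        annotations = sorted(str(p) for p in (self.path / "ann").glob("*.xml"))
        return images, annotations, {"image_id": [Path(p).name for p in images]}

    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
        # Parse one annotation into (x0, y0, x1, y1) boxes, integer labels, and extra metadata.
        return [[0.0, 0.0, 1.0, 1.0]], [1], {}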
dataeval/utils/datasets/_cifar10.py
@@ -1,201 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
-
- import numpy as np
- from numpy.typing import NDArray
-
- from dataeval.utils.datasets._base import BaseICDataset, DataLocation
- from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
- if TYPE_CHECKING:
-     from dataeval.typing import Transform
-
- CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
- TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
-
-
- class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
-     """
-     `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
-
-     Parameters
-     ----------
-     root : str or pathlib.Path
-         Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
-     image_set : "train", "test" or "base", default "train"
-         If "base", returns all of the data to allow the user to create their own splits.
-     transforms : Transform, Sequence[Transform] or None, default None
-         Transform(s) to apply to the data.
-     download : bool, default False
-         If True, downloads the dataset from the internet and puts it in root directory.
-         Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-     verbose : bool, default False
-         If True, outputs print statements.
-
-     Attributes
-     ----------
-     path : pathlib.Path
-         Location of the folder containing the data.
-     image_set : "train", "test" or "base"
-         The selected image set from the dataset.
-     transforms : Sequence[Transform]
-         The transforms to be applied to the data.
-     size : int
-         The size of the dataset.
-     index2label : dict[int, str]
-         Dictionary which translates from class integers to the associated class strings.
-     label2index : dict[str, int]
-         Dictionary which translates from class strings to the associated class integers.
-     metadata : DatasetMetadata
-         Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-     """
-
-     _resources = [
-         DataLocation(
-             url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
-             filename="cifar-10-binary.tar.gz",
-             md5=True,
-             checksum="c32a1d4ab5d03f1284b67883e8d87530",
-         ),
-     ]
-
-     index2label: dict[int, str] = {
-         0: "airplane",
-         1: "automobile",
-         2: "bird",
-         3: "cat",
-         4: "deer",
-         5: "dog",
-         6: "frog",
-         7: "horse",
-         8: "ship",
-         9: "truck",
-     }
-
-     def __init__(
-         self,
-         root: str | Path,
-         image_set: Literal["train", "test", "base"] = "train",
-         transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-         download: bool = False,
-         verbose: bool = False,
-     ) -> None:
-         super().__init__(
-             root,
-             image_set,
-             transforms,
-             download,
-             verbose,
-         )
-
-     def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
-         batch_nums = np.zeros(60000, dtype=np.uint8)
-         all_labels = np.zeros(60000, dtype=np.uint8)
-         all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
-         # Process each batch file, skipping .meta and .html files
-         for batch_file in data_folder:
-             # Get batch parameters
-             batch_type = "test" if "test" in batch_file.stem else "train"
-             batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
-
-             # Load data
-             batch_images, batch_labels = self._unpack_batch_files(batch_file)
-
-             # Stack data
-             num_images = batch_images.shape[0]
-             batch_start = batch_num * num_images
-             all_images[batch_start : batch_start + num_images] = batch_images
-             all_labels[batch_start : batch_start + num_images] = batch_labels
-             batch_nums[batch_start : batch_start + num_images] = batch_num
-
-         # Save data
-         self._loaded_data = all_images
-         np.savez(self.path / "cifar10", images=self._loaded_data, labels=all_labels, batches=batch_nums)
-
-         # Select data
-         image_list = np.arange(all_labels.shape[0]).astype(str)
-         if self.image_set == "train":
-             return (
-                 image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
-                 all_labels[batch_nums != 5].tolist(),
-                 {"batch_num": batch_nums[batch_nums != 5].tolist()},
-             )
-         if self.image_set == "test":
-             return (
-                 image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
-                 all_labels[batch_nums == 5].tolist(),
-                 {"batch_num": batch_nums[batch_nums == 5].tolist()},
-             )
-         return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
-
-     def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
-         """Function to load in the file paths for the data and labels and retrieve metadata"""
-         data_file = self.path / "cifar10.npz"
-         if not data_file.exists():
-             data_folder = sorted((self.path / "cifar-10-batches-bin").glob("*.bin"))
-             if not data_folder:
-                 raise FileNotFoundError
-             return self._load_bin_data(data_folder)
-
-         # Load data
-         data = np.load(data_file)
-         self._loaded_data = data["images"]
-         all_labels = data["labels"]
-         batch_nums = data["batches"]
-
-         # Select data
-         image_list = np.arange(all_labels.shape[0]).astype(str)
-         if self.image_set == "train":
-             return (
-                 image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
-                 all_labels[batch_nums != 5].tolist(),
-                 {"batch_num": batch_nums[batch_nums != 5].tolist()},
-             )
-         if self.image_set == "test":
-             return (
-                 image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
-                 all_labels[batch_nums == 5].tolist(),
-                 {"batch_num": batch_nums[batch_nums == 5].tolist()},
-             )
-         return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
-
-     def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
-         # Load pickle data with latin1 encoding
-         with file_path.open("rb") as f:
-             buffer = np.frombuffer(f.read(), dtype=np.uint8)
-             # Each entry is 1 byte for label + 3072 bytes for image (3*32*32)
-             entry_size = 1 + 3072
-             num_entries = buffer.size // entry_size
-             # Extract labels (first byte of each entry)
-             labels = buffer[::entry_size]
-
-         # Extract image data and reshape to (N, 3, 32, 32)
-         images = np.zeros((num_entries, 3, 32, 32), dtype=np.uint8)
-         for i in range(num_entries):
-             # Skip the label byte and get image data for this entry
-             start_idx = i * entry_size + 1  # +1 to skip label
-             img_flat = buffer[start_idx : start_idx + 3072]
-
-             # The CIFAR format stores channels in blocks (all R, then all G, then all B)
-             # Each channel block is 1024 bytes (32x32)
-             red_channel = img_flat[0:1024].reshape(32, 32)
-             green_channel = img_flat[1024:2048].reshape(32, 32)
-             blue_channel = img_flat[2048:3072].reshape(32, 32)
-
-             # Stack the channels in the proper C×H×W format
-             images[i, 0] = red_channel  # Red channel
-             images[i, 1] = green_channel  # Green channel
-             images[i, 2] = blue_channel  # Blue channel
-         return images, labels
-
-     def _read_file(self, path: str) -> NDArray[Any]:
-         """
-         Function to grab the correct image from the loaded data.
-         Overwrite of the base `_read_file` because data is an all or nothing load.
-         """
-         index = int(path)
-         return self._loaded_data[index]
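The record layout that ``_unpack_batch_files`` walks through entry by entry can also be decoded with a single reshape, since every record in a CIFAR-10 ``.bin`` file is one label byte followed by 3072 channel-planar image bytes (1024 each for R, G, B). A standalone sketch of that layout is below; the file path is a placeholder, and the vectorized reshape is an equivalent of the per-entry loop in the removed code, not how the released module did it.

import numpy as np

# Placeholder path to one CIFAR-10 binary batch file (10,000 records per file).
buffer = np.fromfile("cifar-10-batches-bin/data_batch_1.bin", dtype=np.uint8)
records = buffer.reshape(-1, 1 + 3072)            # one row per record
labels = records[:, 0]                            # first byte is the class label
images = records[:, 1:].reshape(-1, 3, 32, 32)    # planar R/G/B bytes -> C x H x W
assert images.shape == (10000, 3, 32, 32) and labels.max() <= 9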