maite-datasets 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +1 -0
- maite_datasets/_base.py +254 -0
- maite_datasets/_fileio.py +174 -0
- maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets/_mixin/_numpy.py +28 -0
- maite_datasets/_mixin/_torch.py +28 -0
- maite_datasets/_protocols.py +224 -0
- maite_datasets/_types.py +54 -0
- maite_datasets/image_classification/__init__.py +11 -0
- maite_datasets/image_classification/_cifar10.py +233 -0
- maite_datasets/image_classification/_mnist.py +215 -0
- maite_datasets/image_classification/_ships.py +150 -0
- maite_datasets/object_detection/__init__.py +20 -0
- maite_datasets/object_detection/_antiuav.py +200 -0
- maite_datasets/object_detection/_milco.py +207 -0
- maite_datasets/object_detection/_seadrone.py +551 -0
- maite_datasets/object_detection/_voc.py +510 -0
- maite_datasets/object_detection/_voc_torch.py +65 -0
- maite_datasets/py.typed +0 -0
- maite_datasets-0.0.1.dist-info/METADATA +91 -0
- maite_datasets-0.0.1.dist-info/RECORD +23 -0
- maite_datasets-0.0.1.dist-info/WHEEL +4 -0
- maite_datasets-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,224 @@
|
|
1
|
+
"""
|
2
|
+
Common type protocols used for interoperability with MAITE.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import sys
|
6
|
+
from typing import (
|
7
|
+
Any,
|
8
|
+
Generic,
|
9
|
+
Iterator,
|
10
|
+
Mapping,
|
11
|
+
Protocol,
|
12
|
+
TypedDict,
|
13
|
+
TypeVar,
|
14
|
+
runtime_checkable,
|
15
|
+
)
|
16
|
+
|
17
|
+
import numpy.typing
|
18
|
+
from typing_extensions import NotRequired, ReadOnly, Required
|
19
|
+
|
20
|
+
if sys.version_info >= (3, 10):
|
21
|
+
from typing import TypeAlias
|
22
|
+
else:
|
23
|
+
from typing_extensions import TypeAlias
|
24
|
+
|
25
|
+
|
26
|
+
ArrayLike: TypeAlias = numpy.typing.ArrayLike
|
27
|
+
"""
|
28
|
+
Type alias for a `Union` representing objects that can be coerced into an array.
|
29
|
+
|
30
|
+
See Also
|
31
|
+
--------
|
32
|
+
`NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
|
33
|
+
"""
|
34
|
+
|
35
|
+
|
36
|
+
@runtime_checkable
|
37
|
+
class Array(Protocol):
|
38
|
+
"""
|
39
|
+
Protocol for array objects providing interoperability with DataEval.
|
40
|
+
|
41
|
+
Supports common array representations with popular libraries like
|
42
|
+
PyTorch, Tensorflow and JAX, as well as NumPy arrays.
|
43
|
+
|
44
|
+
Example
|
45
|
+
-------
|
46
|
+
>>> import numpy as np
|
47
|
+
>>> import torch
|
48
|
+
>>> from maite_datasets._typing import Array
|
49
|
+
|
50
|
+
Create array objects
|
51
|
+
|
52
|
+
>>> ndarray = np.random.random((10, 10))
|
53
|
+
>>> tensor = torch.tensor([1, 2, 3])
|
54
|
+
|
55
|
+
Check type at runtime
|
56
|
+
|
57
|
+
>>> isinstance(ndarray, Array)
|
58
|
+
True
|
59
|
+
|
60
|
+
>>> isinstance(tensor, Array)
|
61
|
+
True
|
62
|
+
"""
|
63
|
+
|
64
|
+
@property
|
65
|
+
def shape(self) -> tuple[int, ...]: ...
|
66
|
+
def __array__(self) -> Any: ...
|
67
|
+
def __getitem__(self, key: Any, /) -> Any: ...
|
68
|
+
def __iter__(self) -> Iterator[Any]: ...
|
69
|
+
def __len__(self) -> int: ...
|
70
|
+
|
71
|
+
|
72
|
+
_T_co = TypeVar("_T_co", covariant=True)
_T = TypeVar("_T")


class DatasetMetadata(TypedDict, total=False):
    """
    Metadata describing a whole dataset, exposed by every `AnnotatedDataset`.

    Attributes
    ----------
    id : Required[str]
        Identifier that uniquely names the dataset.
    index2label : NotRequired[dict[int, str]]
        Lookup from an integer label value to its class name.
    """

    # Both keys are read-only for type checkers; only ``id`` must be present.
    id: Required[ReadOnly[str]]
    index2label: NotRequired[ReadOnly[dict[int, str]]]


@runtime_checkable
class Dataset(Generic[_T_co], Protocol):
    """
    Structural protocol for a generic, indexable and sized `Dataset`.

    Methods
    -------
    __getitem__(index: int)
        Fetch the datum stored at ``index``.
    __len__()
        Report how many datums the dataset holds.
    """

    def __getitem__(self, index: int, /) -> _T_co: ...

    def __len__(self) -> int: ...


@runtime_checkable
class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
    """
    Structural protocol for a `Dataset` that also carries dataset-level metadata.

    Attributes
    ----------
    metadata : :class:`.DatasetMetadata` or derivatives.

    Methods
    -------
    __getitem__(index: int)
        Fetch the datum stored at ``index``.
    __len__()
        Report how many datums the dataset holds.

    Notes
    -----
    Extends :class:`.Dataset` with the read-only ``metadata`` property.
    """

    @property
    def metadata(self) -> DatasetMetadata: ...
|
132
|
+
|
133
|
+
|
134
|
+
# ========== IMAGE CLASSIFICATION DATASETS ==========
|
135
|
+
|
136
|
+
|
137
|
+
ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, Mapping[str, Any]]
"""
Type alias for an image classification datum tuple.

- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
- Mapping[str, Any] - Datum level metadata (any read-only mapping, e.g. a ``dict``).
"""


ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
"""
Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
"""
|
151
|
+
|
152
|
+
# ========== OBJECT DETECTION DATASETS ==========
|
153
|
+
|
154
|
+
|
155
|
+
@runtime_checkable
class ObjectDetectionTarget(Protocol):
    """
    Structural protocol for the target element of an object detection datum.

    Any object exposing ``boxes``, ``labels`` and ``scores`` attributes passes
    an ``isinstance`` check against this protocol.

    Attributes
    ----------
    boxes : :class:`ArrayLike` of shape (N, 4)
    labels : :class:`ArrayLike` of shape (N,)
    scores : :class:`ArrayLike` of shape (N, M)
    """

    @property
    def boxes(self) -> ArrayLike:
        """Bounding boxes, shape (N, 4)."""
        ...

    @property
    def labels(self) -> ArrayLike:
        """Class labels, shape (N,)."""
        ...

    @property
    def scores(self) -> ArrayLike:
        """Scores, shape (N, M)."""
        ...
|
175
|
+
|
176
|
+
|
177
|
+
ObjectDetectionDatum: TypeAlias = tuple[
    ArrayLike, ObjectDetectionTarget, Mapping[str, Any]
]
"""
Type alias for an object detection datum tuple.

- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
- :class:`ObjectDetectionTarget` - Object detection target information for the image.
- Mapping[str, Any] - Datum level metadata (any read-only mapping, e.g. a ``dict``).
"""


ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
"""
Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
"""
|
193
|
+
|
194
|
+
|
195
|
+
# ========== TRANSFORM ==========
|
196
|
+
|
197
|
+
|
198
|
+
@runtime_checkable
class Transform(Generic[_T], Protocol):
    """
    Protocol defining a transform function.

    Requires a `__call__` method that takes one positional argument and
    returns transformed data of the same type.

    Example
    -------
    >>> import numpy as np
    >>> from typing import Any
    >>> from numpy.typing import NDArray

    >>> class MyTransform:
    ...     def __init__(self, divisor: float) -> None:
    ...         self.divisor = divisor
    ...
    ...     def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
    ...         return data / self.divisor

    >>> my_transform = MyTransform(divisor=255.0)
    >>> isinstance(my_transform, Transform)
    True
    >>> my_transform(np.array([1, 2, 3]))
    array([0.00392157, 0.00784314, 0.01176471])
    """

    def __call__(self, data: _T, /) -> _T: ...
|
maite_datasets/_types.py
ADDED
@@ -0,0 +1,54 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from typing import Generic, TypedDict, TypeVar
|
7
|
+
|
8
|
+
from typing_extensions import NotRequired, Required
|
9
|
+
|
10
|
+
_T_co = TypeVar("_T_co", covariant=True)


class Dataset(Generic[_T_co]):
    """
    Generic base class mirroring the PyTorch-style ``Dataset`` interface.

    The methods here are placeholders whose bodies do nothing and return
    ``None``; concrete datasets are expected to override them.
    """

    def __getitem__(self, index: int) -> _T_co:
        """Return the datum at ``index`` (placeholder, returns ``None``)."""

    def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]:
        """Combine this dataset with ``other`` (placeholder, returns ``None``)."""
|
18
|
+
|
19
|
+
|
20
|
+
class DatasetMetadata(TypedDict):
|
21
|
+
id: Required[str]
|
22
|
+
index2label: NotRequired[dict[int, str]]
|
23
|
+
split: NotRequired[str]
|
24
|
+
|
25
|
+
|
26
|
+
class DatumMetadata(TypedDict, total=False):
|
27
|
+
id: Required[str]
|
28
|
+
|
29
|
+
|
30
|
+
_TDatum = TypeVar("_TDatum")
_TArray = TypeVar("_TArray")


class AnnotatedDataset(Dataset[_TDatum]):
    """`Dataset` that additionally exposes dataset-level ``metadata`` and a length."""

    metadata: DatasetMetadata

    def __len__(self) -> int:
        """Return the number of datums (placeholder, returns ``None``)."""


class ImageClassificationDataset(
    AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]
):
    """Annotated dataset whose datums are ``(_TArray, _TArray, DatumMetadata)`` tuples."""


@dataclass
class ObjectDetectionTarget(Generic[_TArray]):
    """Detection target holding ``boxes``, ``labels`` and ``scores`` arrays."""

    boxes: _TArray
    labels: _TArray
    scores: _TArray


class ObjectDetectionDataset(
    AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]
):
    """Annotated dataset whose datums are ``(_TArray, ObjectDetectionTarget, DatumMetadata)`` tuples."""
|
@@ -0,0 +1,11 @@
|
|
1
|
+
"""Module for MAITE compliant Image Classification datasets."""
|
2
|
+
|
3
|
+
from maite_datasets.image_classification._cifar10 import CIFAR10
|
4
|
+
from maite_datasets.image_classification._mnist import MNIST
|
5
|
+
from maite_datasets.image_classification._ships import Ships
|
6
|
+
|
7
|
+
__all__ = [
|
8
|
+
"CIFAR10",
|
9
|
+
"MNIST",
|
10
|
+
"Ships",
|
11
|
+
]
|
@@ -0,0 +1,233 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Literal, Sequence, TypeVar
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
from numpy.typing import NDArray
|
10
|
+
|
11
|
+
from maite_datasets._base import BaseICDataset, DataLocation
|
12
|
+
from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
13
|
+
from maite_datasets._protocols import Transform
|
14
|
+
|
15
|
+
# CIFAR-10 class names, listed in the same order as ``CIFAR10.index2label``.
CIFARClassStringMap = Literal[
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Constrained type variable: one class given by name or by index,
# or a list of names, or a list of indices.
TCIFARClassMap = TypeVar(
    "TCIFARClassMap",
    CIFARClassStringMap,
    int,
    list[CIFARClassStringMap],
    list[int],
)
|
30
|
+
|
31
|
+
|
32
|
+
class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
    """
    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
    image_set : "train", "test" or "base", default "train"
        If "base", returns all of the data to allow the user to create their own splits.
    transforms : Transform, Sequence[Transform] or None, default None
        Transform(s) to apply to the data.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in root directory.
        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    path : pathlib.Path
        Location of the folder containing the data.
    image_set : "train", "test" or "base"
        The selected image set from the dataset.
    transforms : Sequence[Transform]
        The transforms to be applied to the data.
    size : int
        The size of the dataset.
    index2label : dict[int, str]
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict[str, int]
        Dictionary which translates from class strings to the associated class integers.
    metadata : DatasetMetadata
        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.

    Notes
    -----
    The first load decodes the ``*.bin`` batch files in ``cifar-10-batches-bin``
    and caches the result as ``cifar10.npz`` inside ``path``. Later loads read
    that cache instead.
    """

    _resources = [
        DataLocation(
            url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
            filename="cifar-10-binary.tar.gz",
            md5=True,
            checksum="c32a1d4ab5d03f1284b67883e8d87530",
        ),
    ]

    index2label: dict[int, str] = {
        0: "airplane",
        1: "automobile",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "horse",
        8: "ship",
        9: "truck",
    }

    def __init__(
        self,
        root: str | Path,
        image_set: Literal["train", "test", "base"] = "train",
        transforms: Transform[NDArray[np.number[Any]]]
        | Sequence[Transform[NDArray[np.number[Any]]]]
        | None = None,
        download: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            root,
            image_set,
            transforms,
            download,
            verbose,
        )

    def _select_image_set(
        self, all_labels: NDArray[np.uint8], batch_nums: NDArray[np.uint8]
    ) -> tuple[list[str], list[int], dict[str, Any]]:
        """
        Select the entries belonging to ``self.image_set``.

        Batch number 5 marks the test batch; every other batch number is
        training data. Any image set other than "train" or "test" selects
        everything.

        Parameters
        ----------
        all_labels : NDArray[np.uint8]
            Label of every image in the loaded image array.
        batch_nums : NDArray[np.uint8]
            Batch number of every image in the loaded image array.

        Returns
        -------
        tuple[list[str], list[int], dict[str, Any]]
            The string form of each selected image's index into the loaded
            image array, the matching integer labels, and a metadata dict
            holding each selected image's ``batch_num``.
        """
        if self.image_set == "train":
            mask = batch_nums != 5
        elif self.image_set == "test":
            mask = batch_nums == 5
        else:
            mask = np.ones(batch_nums.shape, dtype=bool)

        image_list = np.arange(all_labels.shape[0]).astype(str)
        return (
            image_list[mask].tolist(),
            all_labels[mask].tolist(),
            {"batch_num": batch_nums[mask].tolist()},
        )

    def _load_bin_data(
        self, data_folder: list[Path]
    ) -> tuple[list[str], list[int], dict[str, Any]]:
        """
        Decode the raw ``*.bin`` batch files, cache them, and select the image set.

        The decoded images are kept in ``self._loaded_data`` for ``_read_file``.
        Images, labels and batch numbers are written to ``<path>/cifar10.npz``.

        Parameters
        ----------
        data_folder : list[pathlib.Path]
            Paths of the ``*.bin`` batch files. A file whose stem contains
            ``test`` is treated as the test batch. Other stems are expected
            to end in ``_<k>`` with ``k`` in 1..5.

        Returns
        -------
        tuple[list[str], list[int], dict[str, Any]]
            Image identifiers, labels and ``batch_num`` metadata for the
            selected image set.
        """
        batch_nums = np.zeros(60000, dtype=np.uint8)
        all_labels = np.zeros(60000, dtype=np.uint8)
        all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)

        for batch_file in data_folder:
            # Training batch ``..._<k>`` maps to batch number k - 1 (0..4);
            # the test batch is stored as batch number 5.
            if "test" in batch_file.stem:
                batch_num = 5
            else:
                batch_num = int(batch_file.stem.split("_")[-1]) - 1

            batch_images, batch_labels = self._unpack_batch_files(batch_file)

            # All batches hold the same number of images, so the batch number
            # fixes where this batch lands in the combined arrays.
            num_images = batch_images.shape[0]
            batch_start = batch_num * num_images
            batch_slice = slice(batch_start, batch_start + num_images)
            all_images[batch_slice] = batch_images
            all_labels[batch_slice] = batch_labels
            batch_nums[batch_slice] = batch_num

        # Cache the decoded data; np.savez appends the ".npz" suffix, which is
        # the file _load_data_inner looks for.
        self._loaded_data = all_images
        np.savez(
            self.path / "cifar10",
            images=self._loaded_data,
            labels=all_labels,
            batches=batch_nums,
        )

        return self._select_image_set(all_labels, batch_nums)

    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        """
        Load image identifiers, labels and metadata for the selected image set.

        Reads the ``cifar10.npz`` cache when present. Otherwise decodes the raw
        batch files, which also creates the cache.

        Raises
        ------
        FileNotFoundError
            If neither the cache nor any ``*.bin`` batch file is found.
        """
        data_file = self.path / "cifar10.npz"
        if not data_file.exists():
            bin_dir = self.path / "cifar-10-batches-bin"
            data_folder = sorted(bin_dir.glob("*.bin"))
            if not data_folder:
                raise FileNotFoundError(
                    f"No CIFAR-10 data found: {data_file} does not exist and "
                    f"no '*.bin' batch files were found in {bin_dir}."
                )
            return self._load_bin_data(data_folder)

        # np.load on an .npz returns an NpzFile holding an open file handle.
        # Use it as a context manager so the handle is released. Indexing
        # reads each array fully into memory, so the arrays stay valid after
        # the file is closed.
        with np.load(data_file) as data:
            self._loaded_data = data["images"]
            all_labels = data["labels"]
            batch_nums = data["batches"]

        return self._select_image_set(all_labels, batch_nums)

    def _unpack_batch_files(
        self, file_path: Path
    ) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
        """
        Decode one CIFAR-10 binary batch file.

        Each record is 1 label byte followed by 3072 image bytes: the red,
        green and blue 32x32 planes, in that order. Any trailing bytes that do
        not form a complete record are ignored.

        Returns
        -------
        tuple[NDArray[np.uint8], NDArray[np.uint8]]
            Images of shape (N, 3, 32, 32) and labels of shape (N,).
        """
        entry_size = 1 + 3 * 32 * 32
        buffer = np.frombuffer(file_path.read_bytes(), dtype=np.uint8)
        num_entries = buffer.size // entry_size

        # One row per record; truncate so labels and images always agree in length.
        records = buffer[: num_entries * entry_size].reshape(num_entries, entry_size)
        labels = records[:, 0]
        # In C order, splitting each 3072-byte image into (3, 32, 32) puts the
        # R, G and B planes on the first axis, i.e. C x H x W. Copy to return
        # an owned, contiguous array instead of a view into the file buffer.
        images = records[:, 1:].reshape(num_entries, 3, 32, 32).copy()
        return images, labels

    def _read_file(self, path: str) -> NDArray[np.number[Any]]:
        """
        Return the image whose identifier is ``path``.

        Overrides the base `_read_file` because the whole dataset is loaded at
        once. Identifiers are the string form of an index into
        ``self._loaded_data``.
        """
        index = int(path)
        return self._loaded_data[index]
|