maite-datasets 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,224 @@
1
+ """
2
+ Common type protocols used for interoperability with MAITE.
3
+ """
4
+
5
+ import sys
6
+ from typing import (
7
+ Any,
8
+ Generic,
9
+ Iterator,
10
+ Mapping,
11
+ Protocol,
12
+ TypedDict,
13
+ TypeVar,
14
+ runtime_checkable,
15
+ )
16
+
17
+ import numpy.typing
18
+ from typing_extensions import NotRequired, ReadOnly, Required
19
+
20
+ if sys.version_info >= (3, 10):
21
+ from typing import TypeAlias
22
+ else:
23
+ from typing_extensions import TypeAlias
24
+
25
+
26
+ ArrayLike: TypeAlias = numpy.typing.ArrayLike
27
+ """
28
+ Type alias for a `Union` representing objects that can be coerced into an array.
29
+
30
+ See Also
31
+ --------
32
+ `NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
33
+ """
34
+
35
+
36
@runtime_checkable
class Array(Protocol):
    """
    Protocol for array objects providing interoperability with MAITE.

    Supports common array representations with popular libraries like
    PyTorch, Tensorflow and JAX, as well as NumPy arrays.

    Members
    -------
    shape : tuple[int, ...]
        Read-only property giving the array dimensions.
    __array__()
        Conversion hook to a NumPy array.
    __getitem__(key, /), __iter__(), __len__()
        Indexing, iteration and length.

    Example
    -------
    >>> import numpy as np
    >>> import torch
    >>> from maite_datasets._typing import Array

    Create array objects

    >>> ndarray = np.random.random((10, 10))
    >>> tensor = torch.tensor([1, 2, 3])

    Check type at runtime

    >>> isinstance(ndarray, Array)
    True

    >>> isinstance(tensor, Array)
    True
    """

    @property
    def shape(self) -> tuple[int, ...]: ...

    def __array__(self) -> Any: ...

    def __getitem__(self, key: Any, /) -> Any: ...

    def __iter__(self) -> Iterator[Any]: ...

    def __len__(self) -> int: ...
70
+
71
+
72
# Invariant type variable: the same type flows in and out (see `Transform.__call__`).
_T = TypeVar("_T")
# Covariant type variable: only ever produced, never consumed (dataset item type).
_T_co = TypeVar("_T_co", covariant=True)
74
+
75
+
76
class DatasetMetadata(TypedDict, total=False):
    """
    Dataset-level metadata that every `AnnotatedDataset` class must provide.

    Attributes
    ----------
    id : Required[str]
        Unique identifier of the dataset (read-only).
    index2label : NotRequired[dict[int, str]]
        Optional lookup from integer label value to class name (read-only).
    """

    id: Required[ReadOnly[str]]
    index2label: NotRequired[ReadOnly[dict[int, str]]]
90
+
91
+
92
@runtime_checkable
class Dataset(Generic[_T_co], Protocol):
    """
    Structural protocol for a generic, integer-indexable `Dataset`.

    Methods
    -------
    __getitem__(index: int)
        Fetch the datum stored at position `index`.
    __len__()
        Report the number of datums in the dataset.
    """

    def __getitem__(self, index: int, /) -> _T_co: ...

    def __len__(self) -> int: ...
107
+
108
+
109
@runtime_checkable
class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
    """
    Structural protocol for a `Dataset` that also exposes dataset-level metadata.

    Attributes
    ----------
    metadata : :class:`.DatasetMetadata` or derivatives.

    Methods
    -------
    __getitem__(index: int)
        Fetch the datum stored at position `index`.
    __len__()
        Report the number of datums in the dataset.

    Notes
    -----
    Extends :class:`.Dataset` with the read-only `metadata` property.
    """

    @property
    def metadata(self) -> DatasetMetadata: ...
132
+
133
+
134
# ========== IMAGE CLASSIFICATION DATASETS ==========


ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, Mapping[str, Any]]
"""
Type alias for an image classification datum tuple.

- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
- Mapping[str, Any] - Datum level metadata.
"""


ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
"""
Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
"""
151
+
152
# ========== OBJECT DETECTION DATASETS ==========


@runtime_checkable
class ObjectDetectionTarget(Protocol):
    """
    Structural protocol for the target of a single object detection datum.

    Attributes
    ----------
    boxes : :class:`ArrayLike` of shape (N, 4)
    labels : :class:`ArrayLike` of shape (N,)
    scores : :class:`ArrayLike` of shape (N, M)
    """

    @property
    def boxes(self) -> ArrayLike: ...

    @property
    def labels(self) -> ArrayLike: ...

    @property
    def scores(self) -> ArrayLike: ...
175
+
176
+
177
ObjectDetectionDatum: TypeAlias = tuple[
    ArrayLike, ObjectDetectionTarget, Mapping[str, Any]
]
"""
Type alias for an object detection datum tuple.

- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
- :class:`ObjectDetectionTarget` - Object detection target information for the image.
- Mapping[str, Any] - Datum level metadata.
"""


ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
"""
Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
"""
193
+
194
+
195
# ========== TRANSFORM ==========


@runtime_checkable
class Transform(Generic[_T], Protocol):
    """
    Protocol defining a transform function.

    Requires a `__call__` method that takes the data as a single positional
    argument and returns transformed data of the same type.

    Example
    -------
    >>> from typing import Any
    >>> import numpy as np
    >>> from numpy.typing import NDArray

    >>> class MyTransform:
    ...     def __init__(self, divisor: float) -> None:
    ...         self.divisor = divisor
    ...
    ...     def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
    ...         return data / self.divisor

    >>> my_transform = MyTransform(divisor=255.0)
    >>> isinstance(my_transform, Transform)
    True
    >>> my_transform(np.array([255, 510, 765]))
    array([1., 2., 3.])
    """

    def __call__(self, data: _T, /) -> _T: ...
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
# Empty on purpose: `from <module> import *` exports nothing from this module.
__all__ = []
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Generic, TypedDict, TypeVar
7
+
8
+ from typing_extensions import NotRequired, Required
9
+
10
# Covariant item type for `Dataset`: it only appears in return position.
_T_co = TypeVar("_T_co", covariant=True)
11
+
12
+
13
class Dataset(Generic[_T_co]):
    """Abstract generic base class for PyTorch style Dataset.

    Only the interface is declared; both methods are stubs returning ``None``.
    """

    def __getitem__(self, index: int) -> _T_co:
        """Return the datum at position ``index`` (stub)."""

    def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]:
        """Support PyTorch-style ``dataset + other`` (stub)."""
18
+
19
+
20
class DatasetMetadata(TypedDict):
    """Dataset-level metadata: ``id`` is required, ``index2label`` and ``split`` are optional."""

    id: Required[str]
    index2label: NotRequired[dict[int, str]]
    split: NotRequired[str]
24
+
25
+
26
class DatumMetadata(TypedDict, total=False):
    """Per-datum metadata; only the ``id`` key is marked required."""

    id: Required[str]
28
+
29
+
30
# Element type of an `AnnotatedDataset`.
_TDatum = TypeVar("_TDatum")
# Array type used by the image-classification and object-detection shapes.
_TArray = TypeVar("_TArray")
32
+
33
+
34
class AnnotatedDataset(Dataset[_TDatum]):
    """`Dataset` that additionally declares dataset-level ``metadata`` and a length."""

    # Declared only; no default value is assigned at class level.
    metadata: DatasetMetadata

    def __len__(self) -> int:
        """Return the number of datums (stub)."""
38
+
39
+
40
class ImageClassificationDataset(
    AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]
):
    """Annotated dataset of ``(array, array, DatumMetadata)`` tuples (presumably image, label)."""
43
+
44
+
45
@dataclass
class ObjectDetectionTarget(Generic[_TArray]):
    """Concrete object detection target holding ``boxes``, ``labels`` and ``scores`` arrays."""

    boxes: _TArray
    labels: _TArray
    scores: _TArray
50
+
51
+
52
class ObjectDetectionDataset(
    AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]
):
    """Annotated dataset of ``(array, ObjectDetectionTarget, DatumMetadata)`` tuples."""
@@ -0,0 +1,11 @@
1
+ """Module for MAITE compliant Image Classification datasets."""
2
+
3
+ from maite_datasets.image_classification._cifar10 import CIFAR10
4
+ from maite_datasets.image_classification._mnist import MNIST
5
+ from maite_datasets.image_classification._ships import Ships
6
+
7
# Public dataset classes re-exported by this subpackage.
__all__ = ["CIFAR10", "MNIST", "Ships"]
@@ -0,0 +1,233 @@
1
+ from __future__ import annotations
2
+
3
# Empty on purpose: `from <module> import *` exports nothing from this module.
__all__ = []
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Sequence, TypeVar
7
+
8
+ import numpy as np
9
+ from numpy.typing import NDArray
10
+
11
+ from maite_datasets._base import BaseICDataset, DataLocation
12
+ from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
13
+ from maite_datasets._protocols import Transform
14
+
15
# The ten CIFAR-10 class names, listed in the same order as `CIFAR10.index2label`.
CIFARClassStringMap = Literal[
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# Constrained type variable: one class (by name or integer index) or a list of them.
TCIFARClassMap = TypeVar(
    "TCIFARClassMap",
    CIFARClassStringMap,
    int,
    list[CIFARClassStringMap],
    list[int],
)
30
+
31
+
32
class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
    """
    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
    image_set : "train", "test" or "base", default "train"
        If "base", returns all of the data to allow the user to create their own splits.
    transforms : Transform, Sequence[Transform] or None, default None
        Transform(s) to apply to the data.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in root directory.
        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    path : pathlib.Path
        Location of the folder containing the data.
    image_set : "train", "test" or "base"
        The selected image set from the dataset.
    transforms : Sequence[Transform]
        The transforms to be applied to the data.
    size : int
        The size of the dataset.
    index2label : dict[int, str]
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict[str, int]
        Dictionary which translates from class strings to the associated class integers.
    metadata : DatasetMetadata
        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
    """

    _resources = [
        DataLocation(
            url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
            filename="cifar-10-binary.tar.gz",
            md5=True,
            checksum="c32a1d4ab5d03f1284b67883e8d87530",
        ),
    ]

    index2label: dict[int, str] = {
        0: "airplane",
        1: "automobile",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "horse",
        8: "ship",
        9: "truck",
    }

    # Batch number recorded for the test batch file; training files named
    # ``<name>_<k>.bin`` are recorded as ``k - 1``.
    _TEST_BATCH = 5
    # Capacity of the combined arrays holding every batch (train + test).
    _TOTAL_IMAGES = 60000
    # Bytes per image: three 32x32 single-byte channel planes.
    _IMAGE_BYTES = 3 * 32 * 32

    def __init__(
        self,
        root: str | Path,
        image_set: Literal["train", "test", "base"] = "train",
        transforms: Transform[NDArray[np.number[Any]]]
        | Sequence[Transform[NDArray[np.number[Any]]]]
        | None = None,
        download: bool = False,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            root,
            image_set,
            transforms,
            download,
            verbose,
        )

    def _select_split(
        self, all_labels: NDArray[np.uint8], batch_nums: NDArray[np.uint8]
    ) -> tuple[list[str], list[int], dict[str, Any]]:
        """Return ``(image ids, labels, datum metadata)`` for ``self.image_set``.

        Image ids are the string form of each image's position in the combined
        arrays; "train" keeps every non-test batch, "test" keeps only the test
        batch and any other value ("base") keeps everything.
        """
        image_list = np.arange(all_labels.shape[0]).astype(str)
        if self.image_set == "train":
            mask = batch_nums != self._TEST_BATCH
        elif self.image_set == "test":
            mask = batch_nums == self._TEST_BATCH
        else:
            mask = np.ones(all_labels.shape[0], dtype=bool)
        return (
            image_list[mask].tolist(),
            all_labels[mask].tolist(),
            {"batch_num": batch_nums[mask].tolist()},
        )

    def _load_bin_data(
        self, data_folder: list[Path]
    ) -> tuple[list[str], list[int], dict[str, Any]]:
        """Decode the raw ``.bin`` batches, cache them as ``cifar10.npz`` and select the split.

        Parameters
        ----------
        data_folder : list[pathlib.Path]
            Paths of the binary batch files to combine.

        Returns
        -------
        tuple[list[str], list[int], dict[str, Any]]
            Image ids, labels and ``{"batch_num": [...]}`` metadata for the selected split.
        """
        batch_nums = np.zeros(self._TOTAL_IMAGES, dtype=np.uint8)
        all_labels = np.zeros(self._TOTAL_IMAGES, dtype=np.uint8)
        all_images = np.zeros((self._TOTAL_IMAGES, 3, 32, 32), dtype=np.uint8)

        for batch_file in data_folder:
            if "test" in batch_file.stem:
                batch_num = self._TEST_BATCH
            else:
                batch_num = int(batch_file.stem.split("_")[-1]) - 1

            batch_images, batch_labels = self._unpack_batch_files(batch_file)

            # Assumes every batch holds the same number of images, so the batch
            # number alone determines where this batch lands in the combined arrays.
            num_images = batch_images.shape[0]
            batch_start = batch_num * num_images
            batch_slice = slice(batch_start, batch_start + num_images)
            all_images[batch_slice] = batch_images
            all_labels[batch_slice] = batch_labels
            batch_nums[batch_slice] = batch_num

        # Cache the decoded arrays so later loads can skip the binary parsing.
        self._loaded_data = all_images
        np.savez(
            self.path / "cifar10",
            images=self._loaded_data,
            labels=all_labels,
            batches=batch_nums,
        )

        return self._select_split(all_labels, batch_nums)

    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        """Function to load in the file paths for the data and labels and retrieve metadata.

        Uses the cached ``cifar10.npz`` when present; otherwise decodes the
        binary batches found in ``cifar-10-batches-bin``.

        Raises
        ------
        FileNotFoundError
            If neither the cache nor any ``*.bin`` batch file exists.
        """
        data_file = self.path / "cifar10.npz"
        if not data_file.exists():
            bin_folder = self.path / "cifar-10-batches-bin"
            data_folder = sorted(bin_folder.glob("*.bin"))
            if not data_folder:
                raise FileNotFoundError(
                    f"No CIFAR-10 binary batch files (*.bin) found in {bin_folder}"
                )
            return self._load_bin_data(data_folder)

        # np.load on an .npz archive returns an NpzFile that keeps the file open;
        # the context manager closes it once the arrays have been read into memory.
        with np.load(data_file) as data:
            self._loaded_data = data["images"]
            all_labels = data["labels"]
            batch_nums = data["batches"]

        return self._select_split(all_labels, batch_nums)

    def _unpack_batch_files(
        self, file_path: Path
    ) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
        """Decode one CIFAR-10 binary batch file into ``(images, labels)``.

        Each record is 1 label byte followed by 3072 image bytes stored as three
        consecutive 1024-byte (32x32) planes: red, then green, then blue.

        Returns
        -------
        tuple[NDArray[np.uint8], NDArray[np.uint8]]
            Images of shape (N, 3, 32, 32) and labels of shape (N,).
        """
        entry_size = 1 + self._IMAGE_BYTES
        buffer = np.frombuffer(file_path.read_bytes(), dtype=np.uint8)
        num_entries = buffer.size // entry_size
        # Ignore any trailing partial record so labels and images stay the same length.
        records = buffer[: num_entries * entry_size].reshape(num_entries, entry_size)
        labels = records[:, 0]
        # The planes are contiguous and row-major, so one reshape yields C x H x W;
        # ascontiguousarray makes a writable copy detached from the read-only buffer.
        images = np.ascontiguousarray(records[:, 1:].reshape(num_entries, 3, 32, 32))
        return images, labels

    def _read_file(self, path: str) -> NDArray[np.number[Any]]:
        """
        Function to grab the correct image from the loaded data.
        Overwrite of the base `_read_file` because data is an all or nothing load.
        """
        index = int(path)
        return self._loaded_data[index]