maite-datasets 0.0.5-py3-none-any.whl → 0.0.7-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (35)
  1. maite_datasets/__init__.py +2 -6
  2. maite_datasets/_base.py +169 -51
  3. maite_datasets/_builder.py +46 -55
  4. maite_datasets/_collate.py +2 -3
  5. maite_datasets/{_reader/_base.py → _reader.py} +62 -36
  6. maite_datasets/_validate.py +4 -2
  7. maite_datasets/adapters/__init__.py +3 -0
  8. maite_datasets/adapters/_huggingface.py +391 -0
  9. maite_datasets/image_classification/_cifar10.py +12 -7
  10. maite_datasets/image_classification/_mnist.py +15 -10
  11. maite_datasets/image_classification/_ships.py +12 -8
  12. maite_datasets/object_detection/__init__.py +4 -7
  13. maite_datasets/object_detection/_antiuav.py +11 -8
  14. maite_datasets/{_reader → object_detection}/_coco.py +29 -27
  15. maite_datasets/object_detection/_milco.py +11 -9
  16. maite_datasets/object_detection/_seadrone.py +11 -9
  17. maite_datasets/object_detection/_voc.py +11 -13
  18. maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
  19. maite_datasets/protocols.py +94 -0
  20. maite_datasets/wrappers/__init__.py +8 -0
  21. maite_datasets/wrappers/_torch.py +109 -0
  22. maite_datasets-0.0.7.dist-info/METADATA +181 -0
  23. maite_datasets-0.0.7.dist-info/RECORD +28 -0
  24. maite_datasets/_mixin/__init__.py +0 -0
  25. maite_datasets/_mixin/_numpy.py +0 -28
  26. maite_datasets/_mixin/_torch.py +0 -28
  27. maite_datasets/_protocols.py +0 -217
  28. maite_datasets/_reader/__init__.py +0 -6
  29. maite_datasets/_reader/_factory.py +0 -64
  30. maite_datasets/_types.py +0 -50
  31. maite_datasets/object_detection/_voc_torch.py +0 -65
  32. maite_datasets-0.0.5.dist-info/METADATA +0 -91
  33. maite_datasets-0.0.5.dist-info/RECORD +0 -31
  34. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/WHEEL +0 -0
  35. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/licenses/LICENSE +0 -0
maite_datasets/adapters/__init__.py
@@ -0,0 +1,3 @@
+ from ._huggingface import HFImageClassificationDataset, HFObjectDetectionDataset, from_huggingface
+
+ __all__ = ["HFImageClassificationDataset", "HFObjectDetectionDataset", "from_huggingface"]
maite_datasets/adapters/_huggingface.py
@@ -0,0 +1,391 @@
+ from __future__ import annotations
+
+ from collections.abc import Mapping
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import Any, Literal, TypeAlias, overload
+
+ import maite.protocols.image_classification as ic
+ import maite.protocols.object_detection as od
+ import numpy as np
+ from maite.protocols import DatasetMetadata, DatumMetadata
+
+ from maite_datasets._base import BaseDataset, NumpyArray, ObjectDetectionTarget
+ from maite_datasets.protocols import HFArray, HFClassLabel, HFDataset, HFImage, HFList, HFValue
+ from maite_datasets.wrappers._torch import TTarget
+
+ # Constants for image processing
+ MAX_VALID_CHANNELS = 10
+
+ FeatureDict: TypeAlias = Mapping[str, Any]
+
+
+ @dataclass
+ class HFDatasetInfo:
+     image_key: str
+
+
+ @dataclass
+ class HFImageClassificationDatasetInfo(HFDatasetInfo):
+     label_key: str
+
+
+ @dataclass
+ class HFObjectDetectionDatasetInfo(HFDatasetInfo):
+     objects_key: str
+     bbox_key: str
+     label_key: str
+
+
+ class HFBaseDataset(BaseDataset[NumpyArray, TTarget]):
+     """Base wrapper for Hugging Face datasets, handling common logic."""
+
+     def __init__(self, hf_dataset: HFDataset, image_key: str, known_keys: set[str]) -> None:
+         self.source = hf_dataset
+         self._image_key = image_key
+
+         # Add dataset metadata
+         dataset_info_dict = hf_dataset.info.__dict__
+         if "id" in dataset_info_dict:
+             dataset_info_dict["datasetinfo_id"] = dataset_info_dict.pop("id")
+         self._metadata_id = dataset_info_dict["dataset_name"]
+         self._metadata_dict = dataset_info_dict
+
+         # Pre-validate features and cache metadata keys
+         self._validate_features(hf_dataset.features)
+         self._scalar_meta_keys = self._extract_scalar_meta_keys(hf_dataset.features, known_keys)
+
+         # Cache for image conversions
+         self._image_cache: dict[int, np.ndarray] = {}
+
+     def _validate_features(self, features: FeatureDict) -> None:
+         """Pre-validate all features during initialization."""
+         if self._image_key not in features:
+             raise ValueError(f"Image key '{self._image_key}' not found in dataset features.")
+
+         if not isinstance(features[self._image_key], (HFImage, HFArray)):
+             raise TypeError(f"Image feature '{self._image_key}' must be HFImage or HFArray.")
+
+     def _extract_scalar_meta_keys(self, features: FeatureDict, known_keys: set[str]) -> list[str]:
+         """Extract scalar metadata keys during initialization."""
+         return [key for key, feature in features.items() if key not in known_keys and isinstance(feature, HFValue)]
+
+     def __len__(self) -> int:
+         return len(self.source)
+
+     def _get_base_metadata(self, index: int) -> DatumMetadata:
+         """Extract base metadata for a datum."""
+         item = self.source[index]
+         datum_metadata: DatumMetadata = {"id": index}
+         for key in self._scalar_meta_keys:
+             datum_metadata[key] = item[key]
+         return datum_metadata
+
+     @lru_cache(maxsize=64)  # Cache image conversions
+     def _get_image(self, index: int) -> np.ndarray:
+         """Get and process image with caching and optimized conversions."""
+         # Convert to numpy array only once
+         raw_image = self.source[index][self._image_key]
+         image = np.asarray(raw_image)
+
+         # Handle different image formats efficiently
+         if image.ndim == 2:
+             # Grayscale: HW -> CHW
+             image = image[np.newaxis, ...]  # More efficient than expand_dims
+         elif image.ndim == 3:
+             # Check if we need to transpose from HWC to CHW
+             if image.shape[-1] < image.shape[-3] and image.shape[-1] <= MAX_VALID_CHANNELS:
+                 # HWC -> CHW using optimized transpose
+                 image = np.transpose(image, (2, 0, 1))
+             elif image.shape[0] > MAX_VALID_CHANNELS:
+                 raise ValueError(
+                     f"Image at index {index} has invalid channel configuration. "
+                     f"Expected channels to be less than {MAX_VALID_CHANNELS}, got shape {image.shape}"
+                 )
+         else:
+             raise ValueError(
+                 f"Image at index {index} has unsupported dimensionality. "
+                 f"Expected 2D or 3D, got {image.ndim}D with shape {image.shape}"
+             )
+
+         if image.ndim != 3:
+             raise ValueError(f"Image processing failed for index {index}. Final shape: {image.shape}")
+
+         return image
+
+
+ class HFImageClassificationDataset(HFBaseDataset[NumpyArray], ic.Dataset):
+     """Wraps a Hugging Face dataset to comply with the ImageClassificationDataset protocol."""
+
+     def __init__(self, hf_dataset: HFDataset, image_key: str, label_key: str) -> None:
+         super().__init__(hf_dataset, image_key, known_keys={image_key, label_key})
+         self._label_key = label_key
+
+         # Pre-validate label feature
+         label_feature = hf_dataset.features[self._label_key]
+         if not isinstance(label_feature, HFClassLabel):
+             raise TypeError(
+                 f"Label feature '{self._label_key}' must be a datasets.ClassLabel, got {type(label_feature).__name__}."
+             )
+
+         self._num_classes: int = label_feature.num_classes
+
+         # Pre-compute one-hot identity matrix for efficient encoding
+         self._one_hot_matrix = np.eye(self._num_classes, dtype=np.float32)
+
+         # Enhanced metadata with validation
+         self.metadata: DatasetMetadata = DatasetMetadata(
+             id=self._metadata_id, index2label=dict(enumerate(label_feature.names), **self._metadata_dict)
+         )
+
+     def __getitem__(self, index: int) -> tuple[NumpyArray, NumpyArray, DatumMetadata]:
+         if not 0 <= index < len(self.source):
+             raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")
+
+         # Process image
+         image = self._get_image(index)
+         label_int = self.source[index][self._label_key]
+
+         # Process target
+         if not 0 <= label_int < self._num_classes:
+             raise ValueError(f"Label {label_int} at index {index} is out of range [0, {self._num_classes})")
+         one_hot_label = self._one_hot_matrix[label_int]
+
+         # Process metadata
+         datum_metadata = self._get_base_metadata(index)
+
+         return image, one_hot_label, datum_metadata
+
+
+ class HFObjectDetectionDataset(HFBaseDataset[ObjectDetectionTarget], od.Dataset):
+     """Wraps a Hugging Face dataset to comply with the ObjectDetectionDataset protocol."""
+
+     def __init__(self, hf_dataset: HFDataset, image_key: str, objects_key: str, bbox_key: str, label_key: str) -> None:
+         super().__init__(hf_dataset, image_key, known_keys={image_key, objects_key})
+         self._objects_key = objects_key
+         self._bbox_key = bbox_key
+         self._label_key = label_key
+
+         # Pre-validate and extract object features
+         self._object_meta_keys = self._validate_and_extract_object_features(hf_dataset.features)
+
+         # Validate and extract label information
+         label_feature = self._extract_label_feature(hf_dataset.features)
+         self.metadata: DatasetMetadata = DatasetMetadata(
+             id=self._metadata_id, index2label=dict(enumerate(label_feature.names)), **self._metadata_dict
+         )
+
+     def _validate_and_extract_object_features(self, features: FeatureDict) -> list[str]:
+         """Validate objects feature and extract metadata keys."""
+         objects_feature = features[self._objects_key]
+
+         # Determine the structure and get inner features
+         if isinstance(objects_feature, HFList):  # list(dict) case
+             if not isinstance(objects_feature.feature, dict):
+                 raise TypeError(f"Objects feature '{self._objects_key}' with list type must contain dict features.")
+             inner_feature_dict = objects_feature.feature
+         elif isinstance(objects_feature, dict):  # dict(list) case
+             inner_feature_dict = objects_feature
+         else:
+             raise TypeError(
+                 f"Objects feature '{self._objects_key}' must be a list or dict, got {type(objects_feature).__name__}."
+             )
+
+         # Validate required keys exist
+         required_keys = {self._bbox_key, self._label_key}
+         missing_keys = required_keys - set(inner_feature_dict.keys())
+         if missing_keys:
+             raise ValueError(f"Objects feature '{self._objects_key}' missing required keys: {missing_keys}")
+
+         # Extract object metadata keys
+         known_inner_keys = {self._bbox_key, self._label_key}
+         return [
+             key
+             for key, feature in inner_feature_dict.items()
+             if key not in known_inner_keys and isinstance(feature, (HFValue, HFList))
+         ]
+
+     def _extract_label_feature(self, features: FeatureDict) -> HFClassLabel:
+         """Extract and validate the label feature."""
+         objects_feature = features[self._objects_key]
+
+         inner_features = objects_feature.feature if isinstance(objects_feature, HFList) else objects_feature
+         label_feature_container = inner_features[self._label_key]
+         label_feature = (
+             label_feature_container.feature
+             if isinstance(label_feature_container.feature, HFClassLabel)
+             else label_feature_container
+         )
+
+         if not isinstance(label_feature, HFClassLabel):
+             raise TypeError(
+                 f"Label '{self._label_key}' in '{self._objects_key}' must be a ClassLabel, "
+                 f"got {type(label_feature).__name__}."
+             )
+
+         return label_feature
+
+     def __getitem__(self, index: int) -> tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]:
+         if not 0 <= index < len(self.source):
+             raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")
+
+         # Process image
+         image = self._get_image(index)
+         objects = self.source[index][self._objects_key]
+
+         # Process target
+         boxes = objects[self._bbox_key]
+         labels = objects[self._label_key]
+         scores = np.zeros_like(labels, dtype=np.float32)
+         target = ObjectDetectionTarget(boxes, labels, scores)
+
+         # Process metadata
+         datum_metadata = self._get_base_metadata(index)
+         self._add_object_metadata(objects, datum_metadata)
+
+         return image, target, datum_metadata
+
+     def _add_object_metadata(self, objects: dict[str, Any], datum_metadata: DatumMetadata) -> None:
+         """Efficiently add object metadata to datum metadata."""
+         if not objects[self._bbox_key]:  # No objects
+             return
+
+         num_objects = len(objects[self._bbox_key])
+
+         for key in self._object_meta_keys:
+             value = objects[key]
+             if isinstance(value, list):
+                 if len(value) == num_objects:
+                     datum_metadata[key] = value
+                 else:
+                     raise ValueError(
+                         f"Object metadata '{key}' length {len(value)} doesn't match number of objects {num_objects}"
+                     )
+             else:
+                 datum_metadata[key] = [value] * num_objects
+
+
+ def is_bbox(feature: Any) -> bool:
+     """Check if feature represents bounding box data with proper type validation."""
+     if not isinstance(feature, HFList):
+         return False
+
+     # Handle nested list structure
+     bbox_candidate = feature.feature if isinstance(feature.feature, HFList) else feature
+
+     return (
+         isinstance(bbox_candidate, HFList)
+         and bbox_candidate.length == 4
+         and isinstance(bbox_candidate.feature, HFValue)
+         and any(dtype in bbox_candidate.feature.dtype for dtype in ["float", "int"])
+     )
+
+
+ def is_label(feature: Any) -> bool:
+     """Check if feature represents label data with proper type validation."""
+     target_feature = feature.feature if isinstance(feature, HFList) else feature
+     return isinstance(target_feature, HFClassLabel)
+
+
+ def find_od_keys(feature: Any) -> tuple[str | None, str | None]:
+     """Helper to find bbox and label keys for object detection with improved logic."""
+     if not ((isinstance(feature, HFList) and isinstance(feature.feature, dict)) or isinstance(feature, dict)):
+         return None, None
+
+     inner_features: FeatureDict = feature.feature if isinstance(feature, HFList) else feature
+
+     bbox_key = label_key = None
+
+     for inner_name, inner_feature in inner_features.items():
+         if bbox_key is None and is_bbox(inner_feature):
+             bbox_key = inner_name
+         if label_key is None and is_label(inner_feature):
+             label_key = inner_name
+
+         # Early exit if both found
+         if bbox_key and label_key:
+             break
+
+     return bbox_key, label_key
+
+
+ def get_dataset_info(dataset: HFDataset) -> HFDatasetInfo:
+     """Extract dataset information with improved validation and error messages."""
+     features = dataset.features
+     image_key = label_key = objects_key = bbox_key = None
+
+     # More efficient feature detection
+     for name, feature in features.items():
+         if image_key is None and isinstance(feature, (HFImage, HFArray)):
+             image_key = name
+         elif label_key is None and isinstance(feature, HFClassLabel):
+             label_key = name
+         elif objects_key is None:
+             temp_bbox, temp_label = find_od_keys(feature)
+             if temp_bbox and temp_label:
+                 objects_key, bbox_key, label_key = name, temp_bbox, temp_label
+
+     if not image_key:
+         available_features = list(features.keys())
+         raise ValueError(
+             f"No image key found in dataset. Available features: {available_features}. "
+             f"Expected HFImage or HFArray type."
+         )
+
+     # Return appropriate dataset info based on detected features
+     if objects_key and bbox_key and label_key:
+         return HFObjectDetectionDatasetInfo(image_key, objects_key, bbox_key, label_key)
+     if label_key:
+         return HFImageClassificationDatasetInfo(image_key, label_key)
+     return HFDatasetInfo(image_key)
+
+
+ @overload
+ def from_huggingface(dataset: HFDataset, task: Literal["image_classification"]) -> HFImageClassificationDataset: ...
+
+
+ @overload
+ def from_huggingface(dataset: HFDataset, task: Literal["object_detection"]) -> HFObjectDetectionDataset: ...
+
+
+ @overload
+ def from_huggingface(
+     dataset: HFDataset, task: Literal["auto"] = "auto"
+ ) -> HFObjectDetectionDataset | HFImageClassificationDataset: ...
+
+
+ def from_huggingface(
+     dataset: HFDataset, task: Literal["image_classification", "object_detection", "auto"] = "auto"
+ ) -> HFObjectDetectionDataset | HFImageClassificationDataset:
+     """Create appropriate dataset wrapper with enhanced error handling."""
+     info = get_dataset_info(dataset)
+
+     if isinstance(info, HFImageClassificationDatasetInfo):
+         if task in ("image_classification", "auto"):
+             return HFImageClassificationDataset(dataset, info.image_key, info.label_key)
+         if task == "object_detection":
+             raise ValueError(
+                 f"Task mismatch: requested 'object_detection' but dataset appears to be "
+                 f"image classification. Detected features: image='{info.image_key}', "
+                 f"label='{info.label_key}'"
+             )
+
+     elif isinstance(info, HFObjectDetectionDatasetInfo):
+         if task in ("object_detection", "auto"):
+             return HFObjectDetectionDataset(dataset, info.image_key, info.objects_key, info.bbox_key, info.label_key)
+         if task == "image_classification":
+             raise ValueError(
+                 f"Task mismatch: requested 'image_classification' but dataset appears to be "
+                 f"object detection. Detected features: image='{info.image_key}', "
+                 f"objects='{info.objects_key}'"
+             )
+
+     # Enhanced error message for auto-detection failure
+     available_features = list(dataset.features.keys())
+     feature_types = {k: type(v).__name__ for k, v in dataset.features.items()}
+
+     raise ValueError(
+         f"Could not automatically determine task for requested type '{task}'. "
+         f"Detected info: {info}. Available features: {available_features}. "
+         f"Feature types: {feature_types}. Ensure dataset has proper image and label/objects features."
+     )
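
The new adapters module exposes from_huggingface, which inspects a Hugging Face dataset's features and returns either an HFImageClassificationDataset or an HFObjectDetectionDataset. The sketch below is illustrative only and not part of the diff; it assumes the `datasets` package is installed and uses "cifar10" purely as an example name for an image-classification dataset with a ClassLabel feature.

# Illustrative sketch: wrap a Hugging Face dataset with the 0.0.7 adapter
# and read back one MAITE-style datum.
from datasets import load_dataset  # assumes the `datasets` package is available

from maite_datasets.adapters import from_huggingface

hf_train = load_dataset("cifar10", split="train")  # example dataset; any IC dataset works
dataset = from_huggingface(hf_train, task="auto")  # task auto-detected from the features

image, target, metadata = dataset[0]
print(image.shape)   # CHW NumPy array
print(target.shape)  # one-hot label vector, length = number of classes
print(metadata)      # DatumMetadata with "id" plus scalar per-datum fields
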
maite_datasets/image_classification/_cifar10.py
@@ -2,15 +2,20 @@ from __future__ import annotations

  __all__ = []

+ from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, Literal, Sequence, TypeVar
+ from typing import Any, Literal, TypeVar

  import numpy as np
  from numpy.typing import NDArray

- from maite_datasets._base import BaseICDataset, DataLocation
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
- from maite_datasets._protocols import Transform
+ from maite_datasets._base import (
+     BaseDatasetNumpyMixin,
+     BaseICDataset,
+     DataLocation,
+     NumpyArray,
+     NumpyImageClassificationTransform,
+ )

  CIFARClassStringMap = Literal[
      "airplane",
@@ -27,7 +32,7 @@ CIFARClassStringMap = Literal[
  TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])


- class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
+ class CIFAR10(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
      """
      `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.

@@ -89,7 +94,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
          self,
          root: str | Path,
          image_set: Literal["train", "test", "base"] = "train",
-         transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+         transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
          download: bool = False,
          verbose: bool = False,
      ) -> None:
@@ -214,7 +219,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
              images[i, 2] = blue_channel  # Blue channel
          return images, labels

-     def _read_file(self, path: str) -> NDArray[np.number[Any]]:
+     def _read_file(self, path: str) -> NumpyArray:
          """
          Function to grab the correct image from the loaded data.
          Overwrite of the base `_read_file` because data is an all or nothing load.
maite_datasets/image_classification/_mnist.py
@@ -2,15 +2,20 @@ from __future__ import annotations

  __all__ = []

+ from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, Literal, Sequence, TypeVar
+ from typing import Any, Literal, TypeVar

  import numpy as np
  from numpy.typing import NDArray

- from maite_datasets._base import BaseICDataset, DataLocation
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
- from maite_datasets._protocols import Transform
+ from maite_datasets._base import (
+     BaseDatasetNumpyMixin,
+     BaseICDataset,
+     DataLocation,
+     NumpyArray,
+     NumpyImageClassificationTransform,
+ )

  MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
  TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
@@ -34,7 +39,7 @@ CorruptionStringMap = Literal[
  ]


- class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
+ class MNIST(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
      """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.

      There are 15 different styles of corruptions. This class downloads differently depending on if you
@@ -118,7 +123,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
          root: str | Path,
          image_set: Literal["train", "test", "base"] = "train",
          corruption: CorruptionStringMap | None = None,
-         transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+         transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
          download: bool = False,
          verbose: bool = False,
      ) -> None:
@@ -149,7 +154,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
          index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
          return index_strings, labels.tolist(), {}

-     def _load_corruption(self) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
+     def _load_corruption(self) -> tuple[NumpyArray, NDArray[np.uintp]]:
          """Function to load in the file paths for the data and labels for the different corrupt data formats"""
          corruption = self.corruption if self.corruption is not None else "identity"
          base_path = self.path / "mnist_c" / corruption
@@ -176,7 +181,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):

          return data, labels

-     def _grab_data(self, path: Path) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
+     def _grab_data(self, path: Path) -> tuple[NumpyArray, NDArray[np.uintp]]:
          """Function to load in the data numpy array"""
          with np.load(path, allow_pickle=True) as data_array:
              if self.image_set == "base":
@@ -190,11 +195,11 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
              data = np.expand_dims(data, axis=1)
          return data, labels

-     def _grab_corruption_data(self, path: Path) -> NDArray[np.number[Any]]:
+     def _grab_corruption_data(self, path: Path) -> NumpyArray:
          """Function to load in the data numpy array for the previously chosen corrupt format"""
          return np.load(path, allow_pickle=False)

-     def _read_file(self, path: str) -> NDArray[np.number[Any]]:
+     def _read_file(self, path: str) -> NumpyArray:
          """
          Function to grab the correct image from the loaded data.
          Overwrite of the base `_read_file` because data is an all or nothing load.
maite_datasets/image_classification/_ships.py
@@ -2,18 +2,22 @@ from __future__ import annotations

  __all__ = []

+ from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, Sequence
+ from typing import Any

  import numpy as np
- from numpy.typing import NDArray

- from maite_datasets._base import BaseICDataset, DataLocation
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
- from maite_datasets._protocols import Transform
+ from maite_datasets._base import (
+     BaseDatasetNumpyMixin,
+     BaseICDataset,
+     DataLocation,
+     NumpyArray,
+     NumpyImageClassificationTransform,
+ )


- class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
+ class Ships(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
      """
      A dataset that focuses on identifying ships from satellite images.

@@ -76,7 +80,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
      def __init__(
          self,
          root: str | Path,
-         transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+         transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
          download: bool = False,
          verbose: bool = False,
      ) -> None:
@@ -125,7 +129,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
          """Function to load in the file paths for the scene images"""
          return sorted(str(entry) for entry in (self.path / "scenes").glob("*.png"))

-     def get_scene(self, index: int) -> NDArray[np.number[Any]]:
+     def get_scene(self, index: int) -> NumpyArray:
          """
          Get the desired satellite image (scene) by passing in the index of the desired file.

maite_datasets/object_detection/__init__.py
@@ -1,20 +1,17 @@
  """Module for MAITE compliant Object Detection datasets."""

  from maite_datasets.object_detection._antiuav import AntiUAVDetection
+ from maite_datasets.object_detection._coco import COCODatasetReader
  from maite_datasets.object_detection._milco import MILCO
  from maite_datasets.object_detection._seadrone import SeaDrone
  from maite_datasets.object_detection._voc import VOCDetection
+ from maite_datasets.object_detection._yolo import YOLODatasetReader

  __all__ = [
      "AntiUAVDetection",
      "MILCO",
      "SeaDrone",
      "VOCDetection",
+     "COCODatasetReader",
+     "YOLODatasetReader",
  ]
-
- import importlib.util
-
- if importlib.util.find_spec("torch") is not None:
-     from maite_datasets.object_detection._voc_torch import VOCDetectionTorch
-
-     __all__ += ["VOCDetectionTorch"]
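
With the private _reader package removed, the COCO and YOLO readers are now re-exported from the public object_detection package, and the torch-gated VOCDetectionTorch export is dropped (the file list above shows torch support moving to maite_datasets/wrappers/_torch.py). Import-location sketch only; the reader constructors are defined in _coco.py and _yolo.py, which are not shown in this excerpt.

# 0.0.7 import locations for the relocated dataset readers.
from maite_datasets.object_detection import COCODatasetReader, YOLODatasetReader

# In 0.0.5 the equivalent readers lived under the private maite_datasets._reader package.
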
maite_datasets/object_detection/_antiuav.py
@@ -2,19 +2,22 @@ from __future__ import annotations

  __all__ = []

+ from collections.abc import Sequence
  from pathlib import Path
- from typing import Any, Literal, Sequence
+ from typing import Any, Literal

- import numpy as np
  from defusedxml.ElementTree import parse
- from numpy.typing import NDArray

- from maite_datasets._base import BaseODDataset, DataLocation
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
- from maite_datasets._protocols import Transform
+ from maite_datasets._base import (
+     BaseDatasetNumpyMixin,
+     BaseODDataset,
+     DataLocation,
+     NumpyArray,
+     NumpyObjectDetectionTransform,
+ )


- class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
+ class AntiUAVDetection(BaseODDataset[NumpyArray, list[str], str], BaseDatasetNumpyMixin):
      """
      A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.

@@ -101,7 +104,7 @@ class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], B
          self,
          root: str | Path,
          image_set: Literal["train", "val", "test", "base"] = "train",
-         transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+         transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
          download: bool = False,
          verbose: bool = False,
      ) -> None:
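
The same pattern repeats across CIFAR10, MNIST, Ships, and the object detection datasets: Sequence now comes from collections.abc, the _mixin and _protocols modules are folded into _base, and transform parameters are typed with the NumpyImageClassificationTransform / NumpyObjectDetectionTransform aliases instead of Transform[NDArray[np.number[Any]]]. The sketch below is illustrative only; it assumes the new aliases accept any callable that takes and returns a NumPy image (their definitions live in maite_datasets._base, which is not part of this excerpt), and the root/download values are placeholders.

# Sketch: passing a plain function as a transform under the 0.0.7 type aliases.
from collections.abc import Sequence

import numpy as np

from maite_datasets._base import NumpyArray, NumpyImageClassificationTransform
from maite_datasets.image_classification._cifar10 import CIFAR10


def scale_to_unit(image: NumpyArray) -> NumpyArray:
    """Example transform: rescale pixel values from [0, 255] to [0.0, 1.0]."""
    return np.asarray(image, dtype=np.float32) / 255.0


transforms: Sequence[NumpyImageClassificationTransform] = [scale_to_unit]
cifar = CIFAR10(root="./data", image_set="train", transforms=transforms, download=True)
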