maite-datasets 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +2 -6
- maite_datasets/_base.py +169 -51
- maite_datasets/_builder.py +46 -55
- maite_datasets/_collate.py +2 -3
- maite_datasets/{_reader/_base.py → _reader.py} +62 -36
- maite_datasets/_validate.py +4 -2
- maite_datasets/adapters/__init__.py +3 -0
- maite_datasets/adapters/_huggingface.py +391 -0
- maite_datasets/image_classification/_cifar10.py +12 -7
- maite_datasets/image_classification/_mnist.py +15 -10
- maite_datasets/image_classification/_ships.py +12 -8
- maite_datasets/object_detection/__init__.py +4 -7
- maite_datasets/object_detection/_antiuav.py +11 -8
- maite_datasets/{_reader → object_detection}/_coco.py +29 -27
- maite_datasets/object_detection/_milco.py +11 -9
- maite_datasets/object_detection/_seadrone.py +11 -9
- maite_datasets/object_detection/_voc.py +11 -13
- maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
- maite_datasets/protocols.py +94 -0
- maite_datasets/wrappers/__init__.py +8 -0
- maite_datasets/wrappers/_torch.py +109 -0
- maite_datasets-0.0.7.dist-info/METADATA +181 -0
- maite_datasets-0.0.7.dist-info/RECORD +28 -0
- maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets/_mixin/_numpy.py +0 -28
- maite_datasets/_mixin/_torch.py +0 -28
- maite_datasets/_protocols.py +0 -217
- maite_datasets/_reader/__init__.py +0 -6
- maite_datasets/_reader/_factory.py +0 -64
- maite_datasets/_types.py +0 -50
- maite_datasets/object_detection/_voc_torch.py +0 -65
- maite_datasets-0.0.5.dist-info/METADATA +0 -91
- maite_datasets-0.0.5.dist-info/RECORD +0 -31
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/WHEEL +0 -0
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/licenses/LICENSE +0 -0
maite_datasets/adapters/_huggingface.py (new file)
@@ -0,0 +1,391 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, Literal, TypeAlias, overload
+
+import maite.protocols.image_classification as ic
+import maite.protocols.object_detection as od
+import numpy as np
+from maite.protocols import DatasetMetadata, DatumMetadata
+
+from maite_datasets._base import BaseDataset, NumpyArray, ObjectDetectionTarget
+from maite_datasets.protocols import HFArray, HFClassLabel, HFDataset, HFImage, HFList, HFValue
+from maite_datasets.wrappers._torch import TTarget
+
+# Constants for image processing
+MAX_VALID_CHANNELS = 10
+
+FeatureDict: TypeAlias = Mapping[str, Any]
+
+
+@dataclass
+class HFDatasetInfo:
+    image_key: str
+
+
+@dataclass
+class HFImageClassificationDatasetInfo(HFDatasetInfo):
+    label_key: str
+
+
+@dataclass
+class HFObjectDetectionDatasetInfo(HFDatasetInfo):
+    objects_key: str
+    bbox_key: str
+    label_key: str
+
+
+class HFBaseDataset(BaseDataset[NumpyArray, TTarget]):
+    """Base wrapper for Hugging Face datasets, handling common logic."""
+
+    def __init__(self, hf_dataset: HFDataset, image_key: str, known_keys: set[str]) -> None:
+        self.source = hf_dataset
+        self._image_key = image_key
+
+        # Add dataset metadata
+        dataset_info_dict = hf_dataset.info.__dict__
+        if "id" in dataset_info_dict:
+            dataset_info_dict["datasetinfo_id"] = dataset_info_dict.pop("id")
+        self._metadata_id = dataset_info_dict["dataset_name"]
+        self._metadata_dict = dataset_info_dict
+
+        # Pre-validate features and cache metadata keys
+        self._validate_features(hf_dataset.features)
+        self._scalar_meta_keys = self._extract_scalar_meta_keys(hf_dataset.features, known_keys)
+
+        # Cache for image conversions
+        self._image_cache: dict[int, np.ndarray] = {}
+
+    def _validate_features(self, features: FeatureDict) -> None:
+        """Pre-validate all features during initialization."""
+        if self._image_key not in features:
+            raise ValueError(f"Image key '{self._image_key}' not found in dataset features.")
+
+        if not isinstance(features[self._image_key], (HFImage, HFArray)):
+            raise TypeError(f"Image feature '{self._image_key}' must be HFImage or HFArray.")
+
+    def _extract_scalar_meta_keys(self, features: FeatureDict, known_keys: set[str]) -> list[str]:
+        """Extract scalar metadata keys during initialization."""
+        return [key for key, feature in features.items() if key not in known_keys and isinstance(feature, HFValue)]
+
+    def __len__(self) -> int:
+        return len(self.source)
+
+    def _get_base_metadata(self, index: int) -> DatumMetadata:
+        """Extract base metadata for a datum."""
+        item = self.source[index]
+        datum_metadata: DatumMetadata = {"id": index}
+        for key in self._scalar_meta_keys:
+            datum_metadata[key] = item[key]
+        return datum_metadata
+
+    @lru_cache(maxsize=64)  # Cache image conversions
+    def _get_image(self, index: int) -> np.ndarray:
+        """Get and process image with caching and optimized conversions."""
+        # Convert to numpy array only once
+        raw_image = self.source[index][self._image_key]
+        image = np.asarray(raw_image)
+
+        # Handle different image formats efficiently
+        if image.ndim == 2:
+            # Grayscale: HW -> CHW
+            image = image[np.newaxis, ...]  # More efficient than expand_dims
+        elif image.ndim == 3:
+            # Check if we need to transpose from HWC to CHW
+            if image.shape[-1] < image.shape[-3] and image.shape[-1] <= MAX_VALID_CHANNELS:
+                # HWC -> CHW using optimized transpose
+                image = np.transpose(image, (2, 0, 1))
+            elif image.shape[0] > MAX_VALID_CHANNELS:
+                raise ValueError(
+                    f"Image at index {index} has invalid channel configuration. "
+                    f"Expected channels to be less than {MAX_VALID_CHANNELS}, got shape {image.shape}"
+                )
+        else:
+            raise ValueError(
+                f"Image at index {index} has unsupported dimensionality. "
+                f"Expected 2D or 3D, got {image.ndim}D with shape {image.shape}"
+            )
+
+        if image.ndim != 3:
+            raise ValueError(f"Image processing failed for index {index}. Final shape: {image.shape}")
+
+        return image
+
+
+class HFImageClassificationDataset(HFBaseDataset[NumpyArray], ic.Dataset):
+    """Wraps a Hugging Face dataset to comply with the ImageClassificationDataset protocol."""
+
+    def __init__(self, hf_dataset: HFDataset, image_key: str, label_key: str) -> None:
+        super().__init__(hf_dataset, image_key, known_keys={image_key, label_key})
+        self._label_key = label_key
+
+        # Pre-validate label feature
+        label_feature = hf_dataset.features[self._label_key]
+        if not isinstance(label_feature, HFClassLabel):
+            raise TypeError(
+                f"Label feature '{self._label_key}' must be a datasets.ClassLabel, got {type(label_feature).__name__}."
+            )
+
+        self._num_classes: int = label_feature.num_classes
+
+        # Pre-compute one-hot identity matrix for efficient encoding
+        self._one_hot_matrix = np.eye(self._num_classes, dtype=np.float32)
+
+        # Enhanced metadata with validation
+        self.metadata: DatasetMetadata = DatasetMetadata(
+            id=self._metadata_id, index2label=dict(enumerate(label_feature.names), **self._metadata_dict)
+        )
+
+    def __getitem__(self, index: int) -> tuple[NumpyArray, NumpyArray, DatumMetadata]:
+        if not 0 <= index < len(self.source):
+            raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")
+
+        # Process image
+        image = self._get_image(index)
+        label_int = self.source[index][self._label_key]
+
+        # Process target
+        if not 0 <= label_int < self._num_classes:
+            raise ValueError(f"Label {label_int} at index {index} is out of range [0, {self._num_classes})")
+        one_hot_label = self._one_hot_matrix[label_int]
+
+        # Process metadata
+        datum_metadata = self._get_base_metadata(index)
+
+        return image, one_hot_label, datum_metadata
+
+
+class HFObjectDetectionDataset(HFBaseDataset[ObjectDetectionTarget], od.Dataset):
+    """Wraps a Hugging Face dataset to comply with the ObjectDetectionDataset protocol."""
+
+    def __init__(self, hf_dataset: HFDataset, image_key: str, objects_key: str, bbox_key: str, label_key: str) -> None:
+        super().__init__(hf_dataset, image_key, known_keys={image_key, objects_key})
+        self._objects_key = objects_key
+        self._bbox_key = bbox_key
+        self._label_key = label_key
+
+        # Pre-validate and extract object features
+        self._object_meta_keys = self._validate_and_extract_object_features(hf_dataset.features)
+
+        # Validate and extract label information
+        label_feature = self._extract_label_feature(hf_dataset.features)
+        self.metadata: DatasetMetadata = DatasetMetadata(
+            id=self._metadata_id, index2label=dict(enumerate(label_feature.names)), **self._metadata_dict
+        )
+
+    def _validate_and_extract_object_features(self, features: FeatureDict) -> list[str]:
+        """Validate objects feature and extract metadata keys."""
+        objects_feature = features[self._objects_key]
+
+        # Determine the structure and get inner features
+        if isinstance(objects_feature, HFList):  # list(dict) case
+            if not isinstance(objects_feature.feature, dict):
+                raise TypeError(f"Objects feature '{self._objects_key}' with list type must contain dict features.")
+            inner_feature_dict = objects_feature.feature
+        elif isinstance(objects_feature, dict):  # dict(list) case
+            inner_feature_dict = objects_feature
+        else:
+            raise TypeError(
+                f"Objects feature '{self._objects_key}' must be a list or dict, got {type(objects_feature).__name__}."
+            )
+
+        # Validate required keys exist
+        required_keys = {self._bbox_key, self._label_key}
+        missing_keys = required_keys - set(inner_feature_dict.keys())
+        if missing_keys:
+            raise ValueError(f"Objects feature '{self._objects_key}' missing required keys: {missing_keys}")
+
+        # Extract object metadata keys
+        known_inner_keys = {self._bbox_key, self._label_key}
+        return [
+            key
+            for key, feature in inner_feature_dict.items()
+            if key not in known_inner_keys and isinstance(feature, (HFValue, HFList))
+        ]
+
+    def _extract_label_feature(self, features: FeatureDict) -> HFClassLabel:
+        """Extract and validate the label feature."""
+        objects_feature = features[self._objects_key]
+
+        inner_features = objects_feature.feature if isinstance(objects_feature, HFList) else objects_feature
+        label_feature_container = inner_features[self._label_key]
+        label_feature = (
+            label_feature_container.feature
+            if isinstance(label_feature_container.feature, HFClassLabel)
+            else label_feature_container
+        )
+
+        if not isinstance(label_feature, HFClassLabel):
+            raise TypeError(
+                f"Label '{self._label_key}' in '{self._objects_key}' must be a ClassLabel, "
+                f"got {type(label_feature).__name__}."
+            )
+
+        return label_feature
+
+    def __getitem__(self, index: int) -> tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]:
+        if not 0 <= index < len(self.source):
+            raise IndexError(f"Index {index} out of range for dataset of size {len(self.source)}")
+
+        # Process image
+        image = self._get_image(index)
+        objects = self.source[index][self._objects_key]
+
+        # Process target
+        boxes = objects[self._bbox_key]
+        labels = objects[self._label_key]
+        scores = np.zeros_like(labels, dtype=np.float32)
+        target = ObjectDetectionTarget(boxes, labels, scores)
+
+        # Process metadata
+        datum_metadata = self._get_base_metadata(index)
+        self._add_object_metadata(objects, datum_metadata)
+
+        return image, target, datum_metadata
+
+    def _add_object_metadata(self, objects: dict[str, Any], datum_metadata: DatumMetadata) -> None:
+        """Efficiently add object metadata to datum metadata."""
+        if not objects[self._bbox_key]:  # No objects
+            return
+
+        num_objects = len(objects[self._bbox_key])
+
+        for key in self._object_meta_keys:
+            value = objects[key]
+            if isinstance(value, list):
+                if len(value) == num_objects:
+                    datum_metadata[key] = value
+                else:
+                    raise ValueError(
+                        f"Object metadata '{key}' length {len(value)} doesn't match number of objects {num_objects}"
+                    )
+            else:
+                datum_metadata[key] = [value] * num_objects
+
+
+def is_bbox(feature: Any) -> bool:
+    """Check if feature represents bounding box data with proper type validation."""
+    if not isinstance(feature, HFList):
+        return False
+
+    # Handle nested list structure
+    bbox_candidate = feature.feature if isinstance(feature.feature, HFList) else feature
+
+    return (
+        isinstance(bbox_candidate, HFList)
+        and bbox_candidate.length == 4
+        and isinstance(bbox_candidate.feature, HFValue)
+        and any(dtype in bbox_candidate.feature.dtype for dtype in ["float", "int"])
+    )
+
+
+def is_label(feature: Any) -> bool:
+    """Check if feature represents label data with proper type validation."""
+    target_feature = feature.feature if isinstance(feature, HFList) else feature
+    return isinstance(target_feature, HFClassLabel)
+
+
+def find_od_keys(feature: Any) -> tuple[str | None, str | None]:
+    """Helper to find bbox and label keys for object detection with improved logic."""
+    if not ((isinstance(feature, HFList) and isinstance(feature.feature, dict)) or isinstance(feature, dict)):
+        return None, None
+
+    inner_features: FeatureDict = feature.feature if isinstance(feature, HFList) else feature
+
+    bbox_key = label_key = None
+
+    for inner_name, inner_feature in inner_features.items():
+        if bbox_key is None and is_bbox(inner_feature):
+            bbox_key = inner_name
+        if label_key is None and is_label(inner_feature):
+            label_key = inner_name
+
+        # Early exit if both found
+        if bbox_key and label_key:
+            break
+
+    return bbox_key, label_key
+
+
+def get_dataset_info(dataset: HFDataset) -> HFDatasetInfo:
+    """Extract dataset information with improved validation and error messages."""
+    features = dataset.features
+    image_key = label_key = objects_key = bbox_key = None
+
+    # More efficient feature detection
+    for name, feature in features.items():
+        if image_key is None and isinstance(feature, (HFImage, HFArray)):
+            image_key = name
+        elif label_key is None and isinstance(feature, HFClassLabel):
+            label_key = name
+        elif objects_key is None:
+            temp_bbox, temp_label = find_od_keys(feature)
+            if temp_bbox and temp_label:
+                objects_key, bbox_key, label_key = name, temp_bbox, temp_label
+
+    if not image_key:
+        available_features = list(features.keys())
+        raise ValueError(
+            f"No image key found in dataset. Available features: {available_features}. "
+            f"Expected HFImage or HFArray type."
+        )
+
+    # Return appropriate dataset info based on detected features
+    if objects_key and bbox_key and label_key:
+        return HFObjectDetectionDatasetInfo(image_key, objects_key, bbox_key, label_key)
+    if label_key:
+        return HFImageClassificationDatasetInfo(image_key, label_key)
+    return HFDatasetInfo(image_key)
+
+
+@overload
+def from_huggingface(dataset: HFDataset, task: Literal["image_classification"]) -> HFImageClassificationDataset: ...
+
+
+@overload
+def from_huggingface(dataset: HFDataset, task: Literal["object_detection"]) -> HFObjectDetectionDataset: ...
+
+
+@overload
+def from_huggingface(
+    dataset: HFDataset, task: Literal["auto"] = "auto"
+) -> HFObjectDetectionDataset | HFImageClassificationDataset: ...
+
+
+def from_huggingface(
+    dataset: HFDataset, task: Literal["image_classification", "object_detection", "auto"] = "auto"
+) -> HFObjectDetectionDataset | HFImageClassificationDataset:
+    """Create appropriate dataset wrapper with enhanced error handling."""
+    info = get_dataset_info(dataset)
+
+    if isinstance(info, HFImageClassificationDatasetInfo):
+        if task in ("image_classification", "auto"):
+            return HFImageClassificationDataset(dataset, info.image_key, info.label_key)
+        if task == "object_detection":
+            raise ValueError(
+                f"Task mismatch: requested 'object_detection' but dataset appears to be "
+                f"image classification. Detected features: image='{info.image_key}', "
+                f"label='{info.label_key}'"
+            )
+
+    elif isinstance(info, HFObjectDetectionDatasetInfo):
+        if task in ("object_detection", "auto"):
+            return HFObjectDetectionDataset(dataset, info.image_key, info.objects_key, info.bbox_key, info.label_key)
+        if task == "image_classification":
+            raise ValueError(
+                f"Task mismatch: requested 'image_classification' but dataset appears to be "
+                f"object detection. Detected features: image='{info.image_key}', "
+                f"objects='{info.objects_key}'"
+            )
+
+    # Enhanced error message for auto-detection failure
+    available_features = list(dataset.features.keys())
+    feature_types = {k: type(v).__name__ for k, v in dataset.features.items()}
+
+    raise ValueError(
+        f"Could not automatically determine task for requested type '{task}'. "
+        f"Detected info: {info}. Available features: {available_features}. "
+        f"Feature types: {feature_types}. Ensure dataset has proper image and label/objects features."
+    )
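The new adapter exposes `from_huggingface` as its entry point. A minimal usage sketch (assuming the Hugging Face `datasets` package is installed, that `from_huggingface` is re-exported by the new `maite_datasets/adapters/__init__.py`, and using an illustrative Hub dataset id):

    from datasets import load_dataset
    from maite_datasets.adapters import from_huggingface

    # Any Hub dataset with an image feature plus ClassLabel or objects features will do;
    # "cifar10" is only an example id.
    hf_ds = load_dataset("cifar10", split="test")

    # task="auto" inspects the features and returns either an
    # HFImageClassificationDataset or an HFObjectDetectionDataset.
    ds = from_huggingface(hf_ds, task="auto")
    image, target, datum_metadata = ds[0]  # CHW NumPy image, target, per-datum metadata

Because `from_huggingface` raises a ValueError when the requested task does not match the detected features, passing an explicit task also acts as a schema check.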
maite_datasets/image_classification/_cifar10.py
@@ -2,15 +2,20 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal,
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
 
-from maite_datasets._base import
-
-
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseICDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyImageClassificationTransform,
+)
 
 CIFARClassStringMap = Literal[
     "airplane",
@@ -27,7 +32,7 @@ CIFARClassStringMap = Literal[
 TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
 
 
-class CIFAR10(BaseICDataset[
+class CIFAR10(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
     """
     `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
 
@@ -89,7 +94,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         self,
         root: str | Path,
         image_set: Literal["train", "test", "base"] = "train",
-        transforms:
+        transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -214,7 +219,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
             images[i, 2] = blue_channel  # Blue channel
         return images, labels
 
-    def _read_file(self, path: str) ->
+    def _read_file(self, path: str) -> NumpyArray:
         """
         Function to grab the correct image from the loaded data.
         Overwrite of the base `_read_file` because data is an all or nothing load.
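The widened `transforms` parameter now accepts either a single transform or a sequence of transforms. A short sketch (assuming `CIFAR10` is re-exported from `maite_datasets.image_classification` and that a transform is a callable taking and returning a CHW NumPy image; `scale` below is a hypothetical example, not part of the package):

    import numpy as np

    from maite_datasets.image_classification import CIFAR10

    def scale(image):
        # Hypothetical transform: uint8 CHW image -> float32 in [0, 1]
        return (image / 255.0).astype(np.float32)

    # A single callable or a sequence of callables is now accepted.
    ds = CIFAR10(root="./data", image_set="train", transforms=[scale], download=True)
    image, label, metadata = ds[0]

The same `NumpyImageClassificationTransform | Sequence[...]` signature is applied to MNIST and Ships below, and its object-detection counterpart to AntiUAVDetection.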
maite_datasets/image_classification/_mnist.py
@@ -2,15 +2,20 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal,
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
 
-from maite_datasets._base import
-
-
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseICDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyImageClassificationTransform,
+)
 
 MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
@@ -34,7 +39,7 @@ CorruptionStringMap = Literal[
 ]
 
 
-class MNIST(BaseICDataset[
+class MNIST(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
     """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.
 
     There are 15 different styles of corruptions. This class downloads differently depending on if you
@@ -118,7 +123,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         root: str | Path,
         image_set: Literal["train", "test", "base"] = "train",
         corruption: CorruptionStringMap | None = None,
-        transforms:
+        transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -149,7 +154,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
         return index_strings, labels.tolist(), {}
 
-    def _load_corruption(self) -> tuple[
+    def _load_corruption(self) -> tuple[NumpyArray, NDArray[np.uintp]]:
         """Function to load in the file paths for the data and labels for the different corrupt data formats"""
         corruption = self.corruption if self.corruption is not None else "identity"
         base_path = self.path / "mnist_c" / corruption
@@ -176,7 +181,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
 
         return data, labels
 
-    def _grab_data(self, path: Path) -> tuple[
+    def _grab_data(self, path: Path) -> tuple[NumpyArray, NDArray[np.uintp]]:
         """Function to load in the data numpy array"""
         with np.load(path, allow_pickle=True) as data_array:
             if self.image_set == "base":
@@ -190,11 +195,11 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
             data = np.expand_dims(data, axis=1)
         return data, labels
 
-    def _grab_corruption_data(self, path: Path) ->
+    def _grab_corruption_data(self, path: Path) -> NumpyArray:
         """Function to load in the data numpy array for the previously chosen corrupt format"""
         return np.load(path, allow_pickle=False)
 
-    def _read_file(self, path: str) ->
+    def _read_file(self, path: str) -> NumpyArray:
         """
         Function to grab the correct image from the loaded data.
         Overwrite of the base `_read_file` because data is an all or nothing load.
maite_datasets/image_classification/_ships.py
@@ -2,18 +2,22 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any
+from typing import Any
 
 import numpy as np
-from numpy.typing import NDArray
 
-from maite_datasets._base import
-
-
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseICDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyImageClassificationTransform,
+)
 
 
-class Ships(BaseICDataset[
+class Ships(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
     """
     A dataset that focuses on identifying ships from satellite images.
 
@@ -76,7 +80,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
     def __init__(
         self,
        root: str | Path,
-        transforms:
+        transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
        download: bool = False,
        verbose: bool = False,
    ) -> None:
@@ -125,7 +129,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
        """Function to load in the file paths for the scene images"""
        return sorted(str(entry) for entry in (self.path / "scenes").glob("*.png"))
 
-    def get_scene(self, index: int) ->
+    def get_scene(self, index: int) -> NumpyArray:
        """
        Get the desired satellite image (scene) by passing in the index of the desired file.
 
maite_datasets/object_detection/__init__.py
@@ -1,20 +1,17 @@
 """Module for MAITE compliant Object Detection datasets."""
 
 from maite_datasets.object_detection._antiuav import AntiUAVDetection
+from maite_datasets.object_detection._coco import COCODatasetReader
 from maite_datasets.object_detection._milco import MILCO
 from maite_datasets.object_detection._seadrone import SeaDrone
 from maite_datasets.object_detection._voc import VOCDetection
+from maite_datasets.object_detection._yolo import YOLODatasetReader
 
 __all__ = [
     "AntiUAVDetection",
     "MILCO",
     "SeaDrone",
     "VOCDetection",
+    "COCODatasetReader",
+    "YOLODatasetReader",
 ]
-
-import importlib.util
-
-if importlib.util.find_spec("torch") is not None:
-    from maite_datasets.object_detection._voc_torch import VOCDetectionTorch
-
-    __all__ += ["VOCDetectionTorch"]
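The torch-gated `VOCDetectionTorch` export is removed; the COCO and YOLO readers formerly under `maite_datasets._reader` become unconditional exports, and Torch-backed wrapping appears to move to the new `maite_datasets.wrappers` package. Import sketch (reader constructor arguments are not shown in this diff and are omitted here):

    from maite_datasets.object_detection import COCODatasetReader, YOLODatasetReader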
maite_datasets/object_detection/_antiuav.py
@@ -2,19 +2,22 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal
 
-import numpy as np
 from defusedxml.ElementTree import parse
-from numpy.typing import NDArray
 
-from maite_datasets._base import
-
-
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseODDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyObjectDetectionTransform,
+)
 
 
-class AntiUAVDetection(BaseODDataset[
+class AntiUAVDetection(BaseODDataset[NumpyArray, list[str], str], BaseDatasetNumpyMixin):
     """
     A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
 
@@ -101,7 +104,7 @@ class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
         self,
         root: str | Path,
         image_set: Literal["train", "val", "test", "base"] = "train",
-        transforms:
+        transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
|