maite-datasets 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +2 -6
- maite_datasets/_base.py +169 -51
- maite_datasets/_builder.py +46 -55
- maite_datasets/_collate.py +2 -3
- maite_datasets/{_reader/_base.py → _reader.py} +62 -36
- maite_datasets/_validate.py +4 -2
- maite_datasets/adapters/__init__.py +3 -0
- maite_datasets/adapters/_huggingface.py +391 -0
- maite_datasets/image_classification/_cifar10.py +12 -7
- maite_datasets/image_classification/_mnist.py +15 -10
- maite_datasets/image_classification/_ships.py +12 -8
- maite_datasets/object_detection/__init__.py +4 -7
- maite_datasets/object_detection/_antiuav.py +11 -8
- maite_datasets/{_reader → object_detection}/_coco.py +29 -27
- maite_datasets/object_detection/_milco.py +11 -9
- maite_datasets/object_detection/_seadrone.py +11 -9
- maite_datasets/object_detection/_voc.py +11 -13
- maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
- maite_datasets/protocols.py +94 -0
- maite_datasets/wrappers/__init__.py +8 -0
- maite_datasets/wrappers/_torch.py +109 -0
- maite_datasets-0.0.7.dist-info/METADATA +181 -0
- maite_datasets-0.0.7.dist-info/RECORD +28 -0
- maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets/_mixin/_numpy.py +0 -28
- maite_datasets/_mixin/_torch.py +0 -28
- maite_datasets/_protocols.py +0 -217
- maite_datasets/_reader/__init__.py +0 -6
- maite_datasets/_reader/_factory.py +0 -64
- maite_datasets/_types.py +0 -50
- maite_datasets/object_detection/_voc_torch.py +0 -65
- maite_datasets-0.0.5.dist-info/METADATA +0 -91
- maite_datasets-0.0.5.dist-info/RECORD +0 -31
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/WHEEL +0 -0
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/licenses/LICENSE +0 -0
maite_datasets/__init__.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1
1
|
"""Module for MAITE compliant Computer Vision datasets."""
|
2
2
|
|
3
3
|
from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
|
4
|
-
from maite_datasets._collate import
|
4
|
+
from maite_datasets._collate import collate_as_list, collate_as_numpy, collate_as_torch
|
5
|
+
from maite_datasets._reader import create_dataset_reader
|
5
6
|
from maite_datasets._validate import validate_dataset
|
6
|
-
from maite_datasets._reader._factory import create_dataset_reader
|
7
|
-
from maite_datasets._reader._coco import COCODatasetReader
|
8
|
-
from maite_datasets._reader._yolo import YOLODatasetReader
|
9
7
|
|
10
8
|
__all__ = [
|
11
9
|
"collate_as_list",
|
@@ -15,6 +13,4 @@ __all__ = [
|
|
15
13
|
"to_image_classification_dataset",
|
16
14
|
"to_object_detection_dataset",
|
17
15
|
"validate_dataset",
|
18
|
-
"COCODatasetReader",
|
19
|
-
"YOLODatasetReader",
|
20
16
|
]
|
maite_datasets/_base.py
CHANGED
@@ -2,23 +2,24 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import inspect
|
6
|
+
import warnings
|
5
7
|
from abc import abstractmethod
|
8
|
+
from collections import namedtuple
|
9
|
+
from collections.abc import Iterator, Sequence
|
6
10
|
from pathlib import Path
|
7
|
-
from typing import Any,
|
11
|
+
from typing import Any, Callable, Generic, Literal, NamedTuple, TypeVar, cast
|
8
12
|
|
9
13
|
import numpy as np
|
14
|
+
from maite.protocols import DatasetMetadata, DatumMetadata
|
15
|
+
from numpy.typing import NDArray
|
16
|
+
from PIL import Image
|
10
17
|
|
11
18
|
from maite_datasets._fileio import _ensure_exists
|
12
|
-
from maite_datasets.
|
13
|
-
from maite_datasets._types import (
|
14
|
-
AnnotatedDataset,
|
15
|
-
DatasetMetadata,
|
16
|
-
DatumMetadata,
|
17
|
-
ImageClassificationDataset,
|
18
|
-
ObjectDetectionDataset,
|
19
|
-
ObjectDetectionTarget,
|
20
|
-
)
|
19
|
+
from maite_datasets.protocols import Array
|
21
20
|
|
21
|
+
_T = TypeVar("_T")
|
22
|
+
_T_co = TypeVar("_T_co", covariant=True)
|
22
23
|
_TArray = TypeVar("_TArray", bound=Array)
|
23
24
|
_TTarget = TypeVar("_TTarget")
|
24
25
|
_TRawTarget = TypeVar(
|
@@ -30,16 +31,7 @@ _TRawTarget = TypeVar(
|
|
30
31
|
_TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
|
31
32
|
|
32
33
|
|
33
|
-
|
34
|
-
_id = metadata.pop("id", index)
|
35
|
-
return DatumMetadata(id=_id, **metadata)
|
36
|
-
|
37
|
-
|
38
|
-
class DataLocation(NamedTuple):
|
39
|
-
url: str
|
40
|
-
filename: str
|
41
|
-
md5: bool
|
42
|
-
checksum: str
|
34
|
+
ObjectDetectionTarget = namedtuple("ObjectDetectionTarget", ["boxes", "labels", "scores"])
|
43
35
|
|
44
36
|
|
45
37
|
class BaseDatasetMixin(Generic[_TArray]):
|
@@ -50,8 +42,99 @@ class BaseDatasetMixin(Generic[_TArray]):
|
|
50
42
|
def _read_file(self, path: str) -> _TArray: ...
|
51
43
|
|
52
44
|
|
53
|
-
class
|
54
|
-
|
45
|
+
class Dataset(Generic[_T_co]):
|
46
|
+
"""Abstract generic base class for PyTorch style Dataset"""
|
47
|
+
|
48
|
+
def __getitem__(self, index: int) -> _T_co: ...
|
49
|
+
def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]: ...
|
50
|
+
|
51
|
+
|
52
|
+
class BaseDataset(Dataset[tuple[_TArray, _TTarget, DatumMetadata]]):
|
53
|
+
metadata: DatasetMetadata
|
54
|
+
|
55
|
+
def __init__(
|
56
|
+
self,
|
57
|
+
transforms: Callable[[_TArray], _TArray]
|
58
|
+
| Callable[
|
59
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
60
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
61
|
+
]
|
62
|
+
| Sequence[
|
63
|
+
Callable[[_TArray], _TArray]
|
64
|
+
| Callable[
|
65
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
66
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
67
|
+
]
|
68
|
+
]
|
69
|
+
| None,
|
70
|
+
) -> None:
|
71
|
+
self.transforms: Sequence[
|
72
|
+
Callable[
|
73
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
74
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
75
|
+
]
|
76
|
+
] = []
|
77
|
+
transforms = transforms if isinstance(transforms, Sequence) else [transforms] if transforms else []
|
78
|
+
for transform in transforms:
|
79
|
+
sig = inspect.signature(transform)
|
80
|
+
if len(sig.parameters) != 1:
|
81
|
+
warnings.warn(f"Dropping unrecognized transform: {str(transform)}")
|
82
|
+
elif "tuple" in str(sig.parameters.values()):
|
83
|
+
transform = cast(
|
84
|
+
Callable[
|
85
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
86
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
87
|
+
],
|
88
|
+
transform,
|
89
|
+
)
|
90
|
+
self.transforms.append(transform)
|
91
|
+
else:
|
92
|
+
transform = cast(Callable[[_TArray], _TArray], transform)
|
93
|
+
self.transforms.append(self._wrap_transform(transform))
|
94
|
+
|
95
|
+
def _wrap_transform(
|
96
|
+
self, transform: Callable[[_TArray], _TArray]
|
97
|
+
) -> Callable[
|
98
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
99
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
100
|
+
]:
|
101
|
+
def wrapper(
|
102
|
+
datum: tuple[_TArray, _TTarget, DatumMetadata],
|
103
|
+
) -> tuple[_TArray, _TTarget, DatumMetadata]:
|
104
|
+
image, target, metadata = datum
|
105
|
+
return (transform(image), target, metadata)
|
106
|
+
|
107
|
+
return wrapper
|
108
|
+
|
109
|
+
def _transform(self, datum: tuple[_TArray, _TTarget, DatumMetadata]) -> tuple[_TArray, _TTarget, DatumMetadata]:
|
110
|
+
"""Function to transform the image prior to returning based on parameters passed in."""
|
111
|
+
for transform in self.transforms:
|
112
|
+
datum = transform(datum)
|
113
|
+
return datum
|
114
|
+
|
115
|
+
def __len__(self) -> int: ...
|
116
|
+
|
117
|
+
def __str__(self) -> str:
|
118
|
+
nt = "\n "
|
119
|
+
title = f"{self.__class__.__name__.replace('Dataset', '')} Dataset"
|
120
|
+
sep = "-" * len(title)
|
121
|
+
attrs = [
|
122
|
+
f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
|
123
|
+
for k, v in self.__dict__.items()
|
124
|
+
if not k.startswith("_")
|
125
|
+
]
|
126
|
+
return f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
127
|
+
|
128
|
+
|
129
|
+
class DataLocation(NamedTuple):
|
130
|
+
url: str
|
131
|
+
filename: str
|
132
|
+
md5: bool
|
133
|
+
checksum: str
|
134
|
+
|
135
|
+
|
136
|
+
class BaseDownloadedDataset(
|
137
|
+
BaseDataset[_TArray, _TTarget],
|
55
138
|
Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
|
56
139
|
):
|
57
140
|
"""
|
@@ -72,13 +155,24 @@ class BaseDataset(
|
|
72
155
|
self,
|
73
156
|
root: str | Path,
|
74
157
|
image_set: Literal["train", "val", "test", "operational", "base"] = "train",
|
75
|
-
transforms:
|
158
|
+
transforms: Callable[[_TArray], _TArray]
|
159
|
+
| Callable[
|
160
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
161
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
162
|
+
]
|
163
|
+
| Sequence[
|
164
|
+
Callable[[_TArray], _TArray]
|
165
|
+
| Callable[
|
166
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
167
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
168
|
+
]
|
169
|
+
]
|
170
|
+
| None = None,
|
76
171
|
download: bool = False,
|
77
172
|
verbose: bool = False,
|
78
173
|
) -> None:
|
174
|
+
super().__init__(transforms)
|
79
175
|
self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
|
80
|
-
transforms = transforms if transforms is not None else []
|
81
|
-
self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
|
82
176
|
self.image_set = image_set
|
83
177
|
self._verbose = verbose
|
84
178
|
|
@@ -91,9 +185,11 @@ class BaseDataset(
|
|
91
185
|
self._label2index = {v: k for k, v in self.index2label.items()}
|
92
186
|
|
93
187
|
self.metadata: DatasetMetadata = DatasetMetadata(
|
94
|
-
|
95
|
-
|
96
|
-
|
188
|
+
**{
|
189
|
+
"id": self._unique_id(),
|
190
|
+
"index2label": self.index2label,
|
191
|
+
"split": self.image_set,
|
192
|
+
}
|
97
193
|
)
|
98
194
|
|
99
195
|
# Load the data
|
@@ -101,13 +197,6 @@ class BaseDataset(
|
|
101
197
|
self._filepaths, self._targets, self._datum_metadata = self._load_data()
|
102
198
|
self.size: int = len(self._filepaths)
|
103
199
|
|
104
|
-
def __str__(self) -> str:
|
105
|
-
nt = "\n "
|
106
|
-
title = f"{self.__class__.__name__} Dataset"
|
107
|
-
sep = "-" * len(title)
|
108
|
-
attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
|
109
|
-
return f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
110
|
-
|
111
200
|
@property
|
112
201
|
def label2index(self) -> dict[str, int]:
|
113
202
|
return self._label2index
|
@@ -148,20 +237,18 @@ class BaseDataset(
|
|
148
237
|
@abstractmethod
|
149
238
|
def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
|
150
239
|
|
151
|
-
def
|
152
|
-
|
153
|
-
|
154
|
-
image = transform(image)
|
155
|
-
return image
|
240
|
+
def _to_datum_metadata(self, index: int, metadata: dict[str, Any]) -> DatumMetadata:
|
241
|
+
_id = metadata.pop("id", index)
|
242
|
+
return DatumMetadata(id=_id, **metadata)
|
156
243
|
|
157
244
|
def __len__(self) -> int:
|
158
245
|
return self.size
|
159
246
|
|
160
247
|
|
161
248
|
class BaseICDataset(
|
162
|
-
|
249
|
+
BaseDownloadedDataset[_TArray, _TArray, list[int], int],
|
163
250
|
BaseDatasetMixin[_TArray],
|
164
|
-
|
251
|
+
BaseDataset[_TArray, _TArray],
|
165
252
|
):
|
166
253
|
"""
|
167
254
|
Base class for image classification datasets.
|
@@ -184,17 +271,16 @@ class BaseICDataset(
|
|
184
271
|
score = self._one_hot_encode(label)
|
185
272
|
# Get the image
|
186
273
|
img = self._read_file(self._filepaths[index])
|
187
|
-
img = self._transform(img)
|
188
274
|
|
189
275
|
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
190
276
|
|
191
|
-
return img, score, _to_datum_metadata(index, img_metadata)
|
277
|
+
return self._transform((img, score, self._to_datum_metadata(index, img_metadata)))
|
192
278
|
|
193
279
|
|
194
280
|
class BaseODDataset(
|
195
|
-
|
281
|
+
BaseDownloadedDataset[_TArray, ObjectDetectionTarget, _TRawTarget, _TAnnotation],
|
196
282
|
BaseDatasetMixin[_TArray],
|
197
|
-
|
283
|
+
BaseDataset[_TArray, ObjectDetectionTarget],
|
198
284
|
):
|
199
285
|
"""
|
200
286
|
Base class for object detection datasets.
|
@@ -202,7 +288,7 @@ class BaseODDataset(
|
|
202
288
|
|
203
289
|
_bboxes_per_size: bool = False
|
204
290
|
|
205
|
-
def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget
|
291
|
+
def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget, DatumMetadata]:
|
206
292
|
"""
|
207
293
|
Args
|
208
294
|
----
|
@@ -211,7 +297,7 @@ class BaseODDataset(
|
|
211
297
|
|
212
298
|
Returns
|
213
299
|
-------
|
214
|
-
tuple[TArray, ObjectDetectionTarget
|
300
|
+
tuple[TArray, ObjectDetectionTarget, DatumMetadata]
|
215
301
|
Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
|
216
302
|
"""
|
217
303
|
# Grab the bounding boxes and labels from the annotations
|
@@ -220,17 +306,49 @@ class BaseODDataset(
|
|
220
306
|
# Get the image
|
221
307
|
img = self._read_file(self._filepaths[index])
|
222
308
|
img_size = img.shape
|
223
|
-
img = self._transform(img)
|
224
309
|
# Adjust labels if necessary
|
225
310
|
if self._bboxes_per_size and boxes:
|
226
|
-
boxes = boxes * np.
|
311
|
+
boxes = boxes * np.asarray([[img_size[1], img_size[2], img_size[1], img_size[2]]])
|
227
312
|
# Create the Object Detection Target
|
228
313
|
target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
|
229
314
|
|
230
315
|
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
231
316
|
img_metadata = img_metadata | additional_metadata
|
232
317
|
|
233
|
-
return img, target, _to_datum_metadata(index, img_metadata)
|
318
|
+
return self._transform((img, target, self._to_datum_metadata(index, img_metadata)))
|
234
319
|
|
235
320
|
@abstractmethod
|
236
321
|
def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
322
|
+
|
323
|
+
|
324
|
+
NumpyArray = NDArray[np.floating[Any]] | NDArray[np.integer[Any]]
|
325
|
+
|
326
|
+
|
327
|
+
class BaseDatasetNumpyMixin(BaseDatasetMixin[NumpyArray]):
|
328
|
+
def _as_array(self, raw: list[Any]) -> NumpyArray:
|
329
|
+
return np.asarray(raw)
|
330
|
+
|
331
|
+
def _one_hot_encode(self, value: int | list[int]) -> NumpyArray:
|
332
|
+
if isinstance(value, int):
|
333
|
+
encoded = np.zeros(len(self.index2label))
|
334
|
+
encoded[value] = 1
|
335
|
+
else:
|
336
|
+
encoded = np.zeros((len(value), len(self.index2label)))
|
337
|
+
encoded[np.arange(len(value)), value] = 1
|
338
|
+
return encoded
|
339
|
+
|
340
|
+
def _read_file(self, path: str) -> NumpyArray:
|
341
|
+
return np.array(Image.open(path)).transpose(2, 0, 1)
|
342
|
+
|
343
|
+
|
344
|
+
NumpyImageTransform = Callable[[NumpyArray], NumpyArray]
|
345
|
+
NumpyImageClassificationDatumTransform = Callable[
|
346
|
+
[tuple[NumpyArray, NumpyArray, DatumMetadata]],
|
347
|
+
tuple[NumpyArray, NumpyArray, DatumMetadata],
|
348
|
+
]
|
349
|
+
NumpyObjectDetectionDatumTransform = Callable[
|
350
|
+
[tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]],
|
351
|
+
tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata],
|
352
|
+
]
|
353
|
+
NumpyImageClassificationTransform = NumpyImageTransform | NumpyImageClassificationDatumTransform
|
354
|
+
NumpyObjectDetectionTransform = NumpyImageTransform | NumpyObjectDetectionDatumTransform
|
maite_datasets/_builder.py
CHANGED
@@ -1,29 +1,24 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import numpy as np
|
4
|
-
|
5
3
|
__all__ = []
|
6
4
|
|
5
|
+
from collections.abc import Iterable, Sequence
|
7
6
|
from typing import (
|
8
7
|
Any,
|
9
8
|
Generic,
|
10
|
-
Iterable,
|
11
9
|
Literal,
|
12
|
-
Sequence,
|
13
10
|
SupportsFloat,
|
14
11
|
SupportsInt,
|
15
12
|
TypeVar,
|
16
13
|
cast,
|
17
14
|
)
|
18
15
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
DatumMetadata,
|
26
|
-
)
|
16
|
+
import maite.protocols.image_classification as ic
|
17
|
+
import maite.protocols.object_detection as od
|
18
|
+
import numpy as np
|
19
|
+
from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata
|
20
|
+
|
21
|
+
from maite_datasets.protocols import Array
|
27
22
|
|
28
23
|
|
29
24
|
def _ensure_id(index: int, metadata: dict[str, Any]) -> DatumMetadata:
|
@@ -97,6 +92,8 @@ _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
|
|
97
92
|
|
98
93
|
|
99
94
|
class BaseAnnotatedDataset(Generic[_TLabels]):
|
95
|
+
metadata: DatasetMetadata
|
96
|
+
|
100
97
|
def __init__(
|
101
98
|
self,
|
102
99
|
datum_type: Literal["ic", "od"],
|
@@ -112,16 +109,13 @@ class BaseAnnotatedDataset(Generic[_TLabels]):
|
|
112
109
|
self._labels = labels
|
113
110
|
self._metadata = metadata
|
114
111
|
self._id = name or f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"
|
115
|
-
|
116
|
-
@property
|
117
|
-
def metadata(self) -> DatasetMetadata:
|
118
|
-
return DatasetMetadata(id=self._id, index2label=self._index2label)
|
112
|
+
self.metadata = DatasetMetadata(id=self._id, index2label=self._index2label)
|
119
113
|
|
120
114
|
def __len__(self) -> int:
|
121
115
|
return len(self._images)
|
122
116
|
|
123
117
|
|
124
|
-
class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]],
|
118
|
+
class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ic.Dataset):
|
125
119
|
def __init__(
|
126
120
|
self,
|
127
121
|
images: Array | Sequence[Array],
|
@@ -152,33 +146,34 @@ class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], Imag
|
|
152
146
|
)
|
153
147
|
|
154
148
|
|
155
|
-
class
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
return self._bboxes
|
177
|
-
|
178
|
-
@property
|
179
|
-
def scores(self) -> Sequence[Sequence[float]]:
|
180
|
-
return self._scores
|
149
|
+
class CustomObjectDetectionTarget:
|
150
|
+
def __init__(
|
151
|
+
self,
|
152
|
+
labels: Sequence[int],
|
153
|
+
bboxes: Sequence[Sequence[float]],
|
154
|
+
class_count: int,
|
155
|
+
) -> None:
|
156
|
+
self._labels = labels
|
157
|
+
self._bboxes = bboxes
|
158
|
+
one_hot = [[0.0] * class_count] * len(labels)
|
159
|
+
for i, label in enumerate(labels):
|
160
|
+
one_hot[i][label] = 1.0
|
161
|
+
self._scores = one_hot
|
162
|
+
|
163
|
+
@property
|
164
|
+
def labels(self) -> Sequence[int]:
|
165
|
+
return self._labels
|
166
|
+
|
167
|
+
@property
|
168
|
+
def boxes(self) -> Sequence[Sequence[float]]:
|
169
|
+
return self._bboxes
|
181
170
|
|
171
|
+
@property
|
172
|
+
def scores(self) -> Sequence[Sequence[float]]:
|
173
|
+
return self._scores
|
174
|
+
|
175
|
+
|
176
|
+
class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], od.Dataset):
|
182
177
|
def __init__(
|
183
178
|
self,
|
184
179
|
images: Array | Sequence[Array],
|
@@ -203,14 +198,10 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
|
|
203
198
|
[np.asarray(box).tolist() if isinstance(box, Array) else box for box in bbox] for bbox in bboxes
|
204
199
|
]
|
205
200
|
|
206
|
-
|
207
|
-
def metadata(self) -> DatasetMetadata:
|
208
|
-
return DatasetMetadata(id=self._id, index2label=self._index2label)
|
209
|
-
|
210
|
-
def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, DatumMetadata]:
|
201
|
+
def __getitem__(self, idx: int, /) -> tuple[Array, CustomObjectDetectionTarget, DatumMetadata]:
|
211
202
|
return (
|
212
203
|
self._images[idx],
|
213
|
-
|
204
|
+
CustomObjectDetectionTarget(self._labels[idx], self._bboxes[idx], len(self._classes)),
|
214
205
|
_ensure_id(idx, self._metadata[idx] if self._metadata is not None else {}),
|
215
206
|
)
|
216
207
|
|
@@ -221,9 +212,9 @@ def to_image_classification_dataset(
|
|
221
212
|
metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
|
222
213
|
classes: Sequence[str] | None,
|
223
214
|
name: str | None = None,
|
224
|
-
) ->
|
215
|
+
) -> ic.Dataset:
|
225
216
|
"""
|
226
|
-
Helper function to create custom
|
217
|
+
Helper function to create custom image classification Dataset classes.
|
227
218
|
|
228
219
|
Parameters
|
229
220
|
----------
|
@@ -238,7 +229,7 @@ def to_image_classification_dataset(
|
|
238
229
|
|
239
230
|
Returns
|
240
231
|
-------
|
241
|
-
|
232
|
+
Dataset
|
242
233
|
"""
|
243
234
|
_validate_data("ic", images, labels, None, metadata)
|
244
235
|
return CustomImageClassificationDataset(images, labels, _listify_metadata(metadata), classes, name)
|
@@ -251,9 +242,9 @@ def to_object_detection_dataset(
|
|
251
242
|
metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
|
252
243
|
classes: Sequence[str] | None,
|
253
244
|
name: str | None = None,
|
254
|
-
) ->
|
245
|
+
) -> od.Dataset:
|
255
246
|
"""
|
256
|
-
Helper function to create custom
|
247
|
+
Helper function to create custom object detection Dataset classes.
|
257
248
|
|
258
249
|
Parameters
|
259
250
|
----------
|
@@ -270,7 +261,7 @@ def to_object_detection_dataset(
|
|
270
261
|
|
271
262
|
Returns
|
272
263
|
-------
|
273
|
-
|
264
|
+
Dataset
|
274
265
|
"""
|
275
266
|
_validate_data("od", images, labels, bboxes, metadata)
|
276
267
|
return CustomObjectDetectionDataset(images, labels, bboxes, _listify_metadata(metadata), classes, name)
|
maite_datasets/_collate.py
CHANGED
@@ -7,16 +7,15 @@ from __future__ import annotations
|
|
7
7
|
__all__ = []
|
8
8
|
|
9
9
|
from collections.abc import Iterable, Sequence
|
10
|
-
from typing import Any, TypeVar
|
10
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
11
11
|
|
12
12
|
import numpy as np
|
13
|
+
from maite.protocols import ArrayLike
|
13
14
|
from numpy.typing import NDArray
|
14
15
|
|
15
16
|
if TYPE_CHECKING:
|
16
17
|
import torch
|
17
18
|
|
18
|
-
from maite_datasets._protocols import ArrayLike
|
19
|
-
|
20
19
|
T_in = TypeVar("T_in")
|
21
20
|
T_tgt = TypeVar("T_tgt")
|
22
21
|
T_md = TypeVar("T_md")
|
@@ -1,39 +1,19 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
from abc import ABC, abstractmethod
|
4
3
|
import logging
|
4
|
+
from abc import ABC, abstractmethod
|
5
5
|
from pathlib import Path
|
6
|
-
from typing import Any
|
7
|
-
|
8
|
-
import numpy as np
|
6
|
+
from typing import Any, Generic, TypeVar
|
9
7
|
|
10
|
-
|
8
|
+
import maite.protocols.image_classification as ic
|
9
|
+
import maite.protocols.object_detection as od
|
11
10
|
|
12
11
|
_logger = logging.getLogger(__name__)
|
13
12
|
|
14
|
-
|
15
|
-
class _ObjectDetectionTarget:
|
16
|
-
"""Internal implementation of ObjectDetectionTarget protocol."""
|
17
|
-
|
18
|
-
def __init__(self, boxes: ArrayLike, labels: ArrayLike, scores: ArrayLike) -> None:
|
19
|
-
self._boxes = np.asarray(boxes)
|
20
|
-
self._labels = np.asarray(labels)
|
21
|
-
self._scores = np.asarray(scores)
|
22
|
-
|
23
|
-
@property
|
24
|
-
def boxes(self) -> ArrayLike:
|
25
|
-
return self._boxes
|
26
|
-
|
27
|
-
@property
|
28
|
-
def labels(self) -> ArrayLike:
|
29
|
-
return self._labels
|
30
|
-
|
31
|
-
@property
|
32
|
-
def scores(self) -> ArrayLike:
|
33
|
-
return self._scores
|
13
|
+
_TDataset = TypeVar("_TDataset", ic.Dataset, od.Dataset)
|
34
14
|
|
35
15
|
|
36
|
-
class BaseDatasetReader(ABC):
|
16
|
+
class BaseDatasetReader(Generic[_TDataset], ABC):
|
37
17
|
"""
|
38
18
|
Abstract base class for object detection dataset readers.
|
39
19
|
|
@@ -65,7 +45,7 @@ class BaseDatasetReader(ABC):
|
|
65
45
|
pass
|
66
46
|
|
67
47
|
@abstractmethod
|
68
|
-
def
|
48
|
+
def create_dataset(self) -> _TDataset:
|
69
49
|
"""Create the format-specific dataset implementation."""
|
70
50
|
pass
|
71
51
|
|
@@ -123,13 +103,59 @@ class BaseDatasetReader(ABC):
|
|
123
103
|
|
124
104
|
return {"is_valid": len(issues) == 0, "issues": issues, "stats": stats}
|
125
105
|
|
126
|
-
def get_dataset(self) -> ObjectDetectionDataset:
|
127
|
-
"""
|
128
|
-
Get dataset conforming to MAITE ObjectDetectionDataset protocol.
|
129
106
|
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
107
|
+
def create_dataset_reader(
|
108
|
+
dataset_path: str | Path, format_hint: str | None = None
|
109
|
+
) -> BaseDatasetReader[ic.Dataset] | BaseDatasetReader[od.Dataset]:
|
110
|
+
"""
|
111
|
+
Factory function to create appropriate dataset reader based on directory structure.
|
112
|
+
|
113
|
+
Parameters
|
114
|
+
----------
|
115
|
+
dataset_path : str or Path
|
116
|
+
Root directory containing dataset files
|
117
|
+
format_hint : str or None, default None
|
118
|
+
Format hint ("coco" or "yolo"). If None, auto-detects based on file structure
|
119
|
+
|
120
|
+
Returns
|
121
|
+
-------
|
122
|
+
BaseDatasetReader
|
123
|
+
Appropriate reader instance for the detected format
|
124
|
+
|
125
|
+
Raises
|
126
|
+
------
|
127
|
+
ValueError
|
128
|
+
If format cannot be determined or is unsupported
|
129
|
+
"""
|
130
|
+
from maite_datasets.object_detection._coco import COCODatasetReader
|
131
|
+
from maite_datasets.object_detection._yolo import YOLODatasetReader
|
132
|
+
|
133
|
+
dataset_path = Path(dataset_path)
|
134
|
+
|
135
|
+
if format_hint:
|
136
|
+
format_hint = format_hint.lower()
|
137
|
+
if format_hint == "coco":
|
138
|
+
return COCODatasetReader(dataset_path)
|
139
|
+
if format_hint == "yolo":
|
140
|
+
return YOLODatasetReader(dataset_path)
|
141
|
+
raise ValueError(f"Unsupported format hint: {format_hint}")
|
142
|
+
|
143
|
+
# Auto-detect format
|
144
|
+
has_annotations_json = (dataset_path / "annotations.json").exists()
|
145
|
+
has_labels_dir = (dataset_path / "labels").exists()
|
146
|
+
|
147
|
+
if has_annotations_json and not has_labels_dir:
|
148
|
+
_logger.info(f"Detected COCO format for {dataset_path}")
|
149
|
+
return COCODatasetReader(dataset_path)
|
150
|
+
if has_labels_dir and not has_annotations_json:
|
151
|
+
_logger.info(f"Detected YOLO format for {dataset_path}")
|
152
|
+
return YOLODatasetReader(dataset_path)
|
153
|
+
if has_annotations_json and has_labels_dir:
|
154
|
+
raise ValueError(
|
155
|
+
f"Ambiguous format in {dataset_path}: both annotations.json and labels/ exist. "
|
156
|
+
"Use format_hint parameter to specify format."
|
157
|
+
)
|
158
|
+
raise ValueError(
|
159
|
+
f"Cannot detect dataset format in {dataset_path}. "
|
160
|
+
"Expected either annotations.json (COCO) or labels/ directory (YOLO)."
|
161
|
+
)
|
maite_datasets/_validate.py
CHANGED
@@ -2,11 +2,13 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import numpy as np
|
6
5
|
from collections.abc import Sequence, Sized
|
7
6
|
from typing import Any, Literal
|
8
7
|
|
9
|
-
|
8
|
+
import numpy as np
|
9
|
+
from maite.protocols.object_detection import ObjectDetectionTarget
|
10
|
+
|
11
|
+
from maite_datasets.protocols import Array
|
10
12
|
|
11
13
|
|
12
14
|
class ValidationMessages:
|