maite-datasets 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. maite_datasets/__init__.py +2 -6
  2. maite_datasets/_base.py +169 -51
  3. maite_datasets/_builder.py +46 -55
  4. maite_datasets/_collate.py +2 -3
  5. maite_datasets/{_reader/_base.py → _reader.py} +62 -36
  6. maite_datasets/_validate.py +4 -2
  7. maite_datasets/adapters/__init__.py +3 -0
  8. maite_datasets/adapters/_huggingface.py +391 -0
  9. maite_datasets/image_classification/_cifar10.py +12 -7
  10. maite_datasets/image_classification/_mnist.py +15 -10
  11. maite_datasets/image_classification/_ships.py +12 -8
  12. maite_datasets/object_detection/__init__.py +4 -7
  13. maite_datasets/object_detection/_antiuav.py +11 -8
  14. maite_datasets/{_reader → object_detection}/_coco.py +29 -27
  15. maite_datasets/object_detection/_milco.py +11 -9
  16. maite_datasets/object_detection/_seadrone.py +11 -9
  17. maite_datasets/object_detection/_voc.py +11 -13
  18. maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
  19. maite_datasets/protocols.py +94 -0
  20. maite_datasets/wrappers/__init__.py +8 -0
  21. maite_datasets/wrappers/_torch.py +109 -0
  22. maite_datasets-0.0.7.dist-info/METADATA +181 -0
  23. maite_datasets-0.0.7.dist-info/RECORD +28 -0
  24. maite_datasets/_mixin/__init__.py +0 -0
  25. maite_datasets/_mixin/_numpy.py +0 -28
  26. maite_datasets/_mixin/_torch.py +0 -28
  27. maite_datasets/_protocols.py +0 -217
  28. maite_datasets/_reader/__init__.py +0 -6
  29. maite_datasets/_reader/_factory.py +0 -64
  30. maite_datasets/_types.py +0 -50
  31. maite_datasets/object_detection/_voc_torch.py +0 -65
  32. maite_datasets-0.0.5.dist-info/METADATA +0 -91
  33. maite_datasets-0.0.5.dist-info/RECORD +0 -31
  34. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/WHEEL +0 -0
  35. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/licenses/LICENSE +0 -0
maite_datasets/__init__.py CHANGED
@@ -1,11 +1,9 @@
  """Module for MAITE compliant Computer Vision datasets."""

  from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
- from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
+ from maite_datasets._collate import collate_as_list, collate_as_numpy, collate_as_torch
+ from maite_datasets._reader import create_dataset_reader
  from maite_datasets._validate import validate_dataset
- from maite_datasets._reader._factory import create_dataset_reader
- from maite_datasets._reader._coco import COCODatasetReader
- from maite_datasets._reader._yolo import YOLODatasetReader

  __all__ = [
      "collate_as_list",
@@ -15,6 +13,4 @@ __all__ = [
      "to_image_classification_dataset",
      "to_object_detection_dataset",
      "validate_dataset",
-     "COCODatasetReader",
-     "YOLODatasetReader",
  ]
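Read together with the file list, this hunk narrows the top-level API: create_dataset_reader is now re-exported from maite_datasets._reader, and COCODatasetReader / YOLODatasetReader drop out of the package namespace. A minimal migration sketch based only on this hunk and the reader diff further below; the dataset path is a placeholder:

# 0.0.5: from maite_datasets import COCODatasetReader, YOLODatasetReader  # no longer re-exported in 0.0.7
from maite_datasets import create_dataset_reader

reader = create_dataset_reader("path/to/dataset")  # placeholder directory
dataset = reader.create_dataset()                  # 0.0.7 renames get_dataset() to create_dataset()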
maite_datasets/_base.py CHANGED
@@ -2,23 +2,24 @@ from __future__ import annotations

  __all__ = []

+ import inspect
+ import warnings
  from abc import abstractmethod
+ from collections import namedtuple
+ from collections.abc import Iterator, Sequence
  from pathlib import Path
- from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
+ from typing import Any, Callable, Generic, Literal, NamedTuple, TypeVar, cast

  import numpy as np
+ from maite.protocols import DatasetMetadata, DatumMetadata
+ from numpy.typing import NDArray
+ from PIL import Image

  from maite_datasets._fileio import _ensure_exists
- from maite_datasets._protocols import Array, Transform
- from maite_datasets._types import (
-     AnnotatedDataset,
-     DatasetMetadata,
-     DatumMetadata,
-     ImageClassificationDataset,
-     ObjectDetectionDataset,
-     ObjectDetectionTarget,
- )
+ from maite_datasets.protocols import Array

+ _T = TypeVar("_T")
+ _T_co = TypeVar("_T_co", covariant=True)
  _TArray = TypeVar("_TArray", bound=Array)
  _TTarget = TypeVar("_TTarget")
  _TRawTarget = TypeVar(
@@ -30,16 +31,7 @@ _TRawTarget = TypeVar(
  _TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])


- def _to_datum_metadata(index: int, metadata: dict[str, Any]) -> DatumMetadata:
-     _id = metadata.pop("id", index)
-     return DatumMetadata(id=_id, **metadata)
-
-
- class DataLocation(NamedTuple):
-     url: str
-     filename: str
-     md5: bool
-     checksum: str
+ ObjectDetectionTarget = namedtuple("ObjectDetectionTarget", ["boxes", "labels", "scores"])


  class BaseDatasetMixin(Generic[_TArray]):
@@ -50,8 +42,99 @@ class BaseDatasetMixin(Generic[_TArray]):
      def _read_file(self, path: str) -> _TArray: ...


- class BaseDataset(
-     AnnotatedDataset[tuple[_TArray, _TTarget, DatumMetadata]],
+ class Dataset(Generic[_T_co]):
+     """Abstract generic base class for PyTorch style Dataset"""
+
+     def __getitem__(self, index: int) -> _T_co: ...
+     def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]: ...
+
+
+ class BaseDataset(Dataset[tuple[_TArray, _TTarget, DatumMetadata]]):
+     metadata: DatasetMetadata
+
+     def __init__(
+         self,
+         transforms: Callable[[_TArray], _TArray]
+         | Callable[
+             [tuple[_TArray, _TTarget, DatumMetadata]],
+             tuple[_TArray, _TTarget, DatumMetadata],
+         ]
+         | Sequence[
+             Callable[[_TArray], _TArray]
+             | Callable[
+                 [tuple[_TArray, _TTarget, DatumMetadata]],
+                 tuple[_TArray, _TTarget, DatumMetadata],
+             ]
+         ]
+         | None,
+     ) -> None:
+         self.transforms: Sequence[
+             Callable[
+                 [tuple[_TArray, _TTarget, DatumMetadata]],
+                 tuple[_TArray, _TTarget, DatumMetadata],
+             ]
+         ] = []
+         transforms = transforms if isinstance(transforms, Sequence) else [transforms] if transforms else []
+         for transform in transforms:
+             sig = inspect.signature(transform)
+             if len(sig.parameters) != 1:
+                 warnings.warn(f"Dropping unrecognized transform: {str(transform)}")
+             elif "tuple" in str(sig.parameters.values()):
+                 transform = cast(
+                     Callable[
+                         [tuple[_TArray, _TTarget, DatumMetadata]],
+                         tuple[_TArray, _TTarget, DatumMetadata],
+                     ],
+                     transform,
+                 )
+                 self.transforms.append(transform)
+             else:
+                 transform = cast(Callable[[_TArray], _TArray], transform)
+                 self.transforms.append(self._wrap_transform(transform))
+
+     def _wrap_transform(
+         self, transform: Callable[[_TArray], _TArray]
+     ) -> Callable[
+         [tuple[_TArray, _TTarget, DatumMetadata]],
+         tuple[_TArray, _TTarget, DatumMetadata],
+     ]:
+         def wrapper(
+             datum: tuple[_TArray, _TTarget, DatumMetadata],
+         ) -> tuple[_TArray, _TTarget, DatumMetadata]:
+             image, target, metadata = datum
+             return (transform(image), target, metadata)
+
+         return wrapper
+
+     def _transform(self, datum: tuple[_TArray, _TTarget, DatumMetadata]) -> tuple[_TArray, _TTarget, DatumMetadata]:
+         """Function to transform the image prior to returning based on parameters passed in."""
+         for transform in self.transforms:
+             datum = transform(datum)
+         return datum
+
+     def __len__(self) -> int: ...
+
+     def __str__(self) -> str:
+         nt = "\n "
+         title = f"{self.__class__.__name__.replace('Dataset', '')} Dataset"
+         sep = "-" * len(title)
+         attrs = [
+             f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
+             for k, v in self.__dict__.items()
+             if not k.startswith("_")
+         ]
+         return f"{title}\n{sep}{nt}{nt.join(attrs)}"
+
+
+ class DataLocation(NamedTuple):
+     url: str
+     filename: str
+     md5: bool
+     checksum: str
+
+
+ class BaseDownloadedDataset(
+     BaseDataset[_TArray, _TTarget],
      Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
  ):
      """
@@ -72,13 +155,24 @@ class BaseDataset(
          self,
          root: str | Path,
          image_set: Literal["train", "val", "test", "operational", "base"] = "train",
-         transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+         transforms: Callable[[_TArray], _TArray]
+         | Callable[
+             [tuple[_TArray, _TTarget, DatumMetadata]],
+             tuple[_TArray, _TTarget, DatumMetadata],
+         ]
+         | Sequence[
+             Callable[[_TArray], _TArray]
+             | Callable[
+                 [tuple[_TArray, _TTarget, DatumMetadata]],
+                 tuple[_TArray, _TTarget, DatumMetadata],
+             ]
+         ]
+         | None = None,
          download: bool = False,
          verbose: bool = False,
      ) -> None:
+         super().__init__(transforms)
          self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
-         transforms = transforms if transforms is not None else []
-         self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
          self.image_set = image_set
          self._verbose = verbose

@@ -91,9 +185,11 @@ class BaseDataset(
          self._label2index = {v: k for k, v in self.index2label.items()}

          self.metadata: DatasetMetadata = DatasetMetadata(
-             id=self._unique_id(),
-             index2label=self.index2label,
-             split=self.image_set,
+             **{
+                 "id": self._unique_id(),
+                 "index2label": self.index2label,
+                 "split": self.image_set,
+             }
          )

          # Load the data
@@ -101,13 +197,6 @@ class BaseDataset(
          self._filepaths, self._targets, self._datum_metadata = self._load_data()
          self.size: int = len(self._filepaths)

-     def __str__(self) -> str:
-         nt = "\n "
-         title = f"{self.__class__.__name__} Dataset"
-         sep = "-" * len(title)
-         attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
-         return f"{title}\n{sep}{nt}{nt.join(attrs)}"
-
      @property
      def label2index(self) -> dict[str, int]:
          return self._label2index
@@ -148,20 +237,18 @@ class BaseDataset(
      @abstractmethod
      def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...

-     def _transform(self, image: _TArray) -> _TArray:
-         """Function to transform the image prior to returning based on parameters passed in."""
-         for transform in self.transforms:
-             image = transform(image)
-         return image
+     def _to_datum_metadata(self, index: int, metadata: dict[str, Any]) -> DatumMetadata:
+         _id = metadata.pop("id", index)
+         return DatumMetadata(id=_id, **metadata)

      def __len__(self) -> int:
          return self.size


  class BaseICDataset(
-     BaseDataset[_TArray, _TArray, list[int], int],
+     BaseDownloadedDataset[_TArray, _TArray, list[int], int],
      BaseDatasetMixin[_TArray],
-     ImageClassificationDataset[_TArray],
+     BaseDataset[_TArray, _TArray],
  ):
      """
      Base class for image classification datasets.
@@ -184,17 +271,16 @@ class BaseICDataset(
          score = self._one_hot_encode(label)
          # Get the image
          img = self._read_file(self._filepaths[index])
-         img = self._transform(img)

          img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}

-         return img, score, _to_datum_metadata(index, img_metadata)
+         return self._transform((img, score, self._to_datum_metadata(index, img_metadata)))


  class BaseODDataset(
-     BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
+     BaseDownloadedDataset[_TArray, ObjectDetectionTarget, _TRawTarget, _TAnnotation],
      BaseDatasetMixin[_TArray],
-     ObjectDetectionDataset[_TArray],
+     BaseDataset[_TArray, ObjectDetectionTarget],
  ):
      """
      Base class for object detection datasets.
@@ -202,7 +288,7 @@ class BaseODDataset(

      _bboxes_per_size: bool = False

-     def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
+     def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget, DatumMetadata]:
          """
          Args
          ----
@@ -211,7 +297,7 @@ class BaseODDataset(

          Returns
          -------
-         tuple[TArray, ObjectDetectionTarget[TArray], DatumMetadata]
+         tuple[TArray, ObjectDetectionTarget, DatumMetadata]
              Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
          """
          # Grab the bounding boxes and labels from the annotations
@@ -220,17 +306,49 @@ class BaseODDataset(
          # Get the image
          img = self._read_file(self._filepaths[index])
          img_size = img.shape
-         img = self._transform(img)
          # Adjust labels if necessary
          if self._bboxes_per_size and boxes:
-             boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
+             boxes = boxes * np.asarray([[img_size[1], img_size[2], img_size[1], img_size[2]]])
          # Create the Object Detection Target
          target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))

          img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
          img_metadata = img_metadata | additional_metadata

-         return img, target, _to_datum_metadata(index, img_metadata)
+         return self._transform((img, target, self._to_datum_metadata(index, img_metadata)))

      @abstractmethod
      def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
+
+
+ NumpyArray = NDArray[np.floating[Any]] | NDArray[np.integer[Any]]
+
+
+ class BaseDatasetNumpyMixin(BaseDatasetMixin[NumpyArray]):
+     def _as_array(self, raw: list[Any]) -> NumpyArray:
+         return np.asarray(raw)
+
+     def _one_hot_encode(self, value: int | list[int]) -> NumpyArray:
+         if isinstance(value, int):
+             encoded = np.zeros(len(self.index2label))
+             encoded[value] = 1
+         else:
+             encoded = np.zeros((len(value), len(self.index2label)))
+             encoded[np.arange(len(value)), value] = 1
+         return encoded
+
+     def _read_file(self, path: str) -> NumpyArray:
+         return np.array(Image.open(path)).transpose(2, 0, 1)
+
+
+ NumpyImageTransform = Callable[[NumpyArray], NumpyArray]
+ NumpyImageClassificationDatumTransform = Callable[
+     [tuple[NumpyArray, NumpyArray, DatumMetadata]],
+     tuple[NumpyArray, NumpyArray, DatumMetadata],
+ ]
+ NumpyObjectDetectionDatumTransform = Callable[
+     [tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]],
+     tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata],
+ ]
+ NumpyImageClassificationTransform = NumpyImageTransform | NumpyImageClassificationDatumTransform
+ NumpyObjectDetectionTransform = NumpyImageTransform | NumpyObjectDetectionDatumTransform
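The reworked transform handling accepts either image-only callables or datum-level callables that receive the full (image, target, metadata) tuple, telling them apart by inspecting the single parameter's annotation; the new Numpy* aliases name both shapes. A short sketch of the two callable forms, with illustrative names, assuming a concrete subclass forwards its transforms argument to BaseDataset as BaseDownloadedDataset does:

import numpy as np
from maite.protocols import DatumMetadata

def scale_image(image: np.ndarray) -> np.ndarray:
    # Image-only transform: BaseDataset wraps it so it is applied to the image alone.
    return image / 255.0

def tag_datum(
    datum: tuple[np.ndarray, np.ndarray, DatumMetadata],
) -> tuple[np.ndarray, np.ndarray, DatumMetadata]:
    # Datum-level transform: the tuple annotation is what the signature inspection keys on.
    image, target, metadata = datum
    return image, target, {**metadata, "scaled": True}

# Both forms can be mixed in one sequence, e.g. transforms=[scale_image, tag_datum],
# when constructing any dataset built on BaseDataset.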
maite_datasets/_builder.py CHANGED
@@ -1,29 +1,24 @@
  from __future__ import annotations

- import numpy as np
-
  __all__ = []

+ from collections.abc import Iterable, Sequence
  from typing import (
      Any,
      Generic,
-     Iterable,
      Literal,
-     Sequence,
      SupportsFloat,
      SupportsInt,
      TypeVar,
      cast,
  )

- from maite_datasets._protocols import (
-     Array,
-     ArrayLike,
-     DatasetMetadata,
-     ImageClassificationDataset,
-     ObjectDetectionDataset,
-     DatumMetadata,
- )
+ import maite.protocols.image_classification as ic
+ import maite.protocols.object_detection as od
+ import numpy as np
+ from maite.protocols import ArrayLike, DatasetMetadata, DatumMetadata
+
+ from maite_datasets.protocols import Array


  def _ensure_id(index: int, metadata: dict[str, Any]) -> DatumMetadata:
@@ -97,6 +92,8 @@ _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])


  class BaseAnnotatedDataset(Generic[_TLabels]):
+     metadata: DatasetMetadata
+
      def __init__(
          self,
          datum_type: Literal["ic", "od"],
@@ -112,16 +109,13 @@ class BaseAnnotatedDataset(Generic[_TLabels]):
          self._labels = labels
          self._metadata = metadata
          self._id = name or f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"
-
-     @property
-     def metadata(self) -> DatasetMetadata:
-         return DatasetMetadata(id=self._id, index2label=self._index2label)
+         self.metadata = DatasetMetadata(id=self._id, index2label=self._index2label)

      def __len__(self) -> int:
          return len(self._images)


- class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
+ class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ic.Dataset):
      def __init__(
          self,
          images: Array | Sequence[Array],
@@ -152,33 +146,34 @@ class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], Imag
          )


- class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
-     class ObjectDetectionTarget:
-         def __init__(
-             self,
-             labels: Sequence[int],
-             bboxes: Sequence[Sequence[float]],
-             class_count: int,
-         ) -> None:
-             self._labels = labels
-             self._bboxes = bboxes
-             one_hot = [[0.0] * class_count] * len(labels)
-             for i, label in enumerate(labels):
-                 one_hot[i][label] = 1.0
-             self._scores = one_hot
-
-         @property
-         def labels(self) -> Sequence[int]:
-             return self._labels
-
-         @property
-         def boxes(self) -> Sequence[Sequence[float]]:
-             return self._bboxes
-
-         @property
-         def scores(self) -> Sequence[Sequence[float]]:
-             return self._scores
+ class CustomObjectDetectionTarget:
+     def __init__(
+         self,
+         labels: Sequence[int],
+         bboxes: Sequence[Sequence[float]],
+         class_count: int,
+     ) -> None:
+         self._labels = labels
+         self._bboxes = bboxes
+         one_hot = [[0.0] * class_count] * len(labels)
+         for i, label in enumerate(labels):
+             one_hot[i][label] = 1.0
+         self._scores = one_hot
+
+     @property
+     def labels(self) -> Sequence[int]:
+         return self._labels
+
+     @property
+     def boxes(self) -> Sequence[Sequence[float]]:
+         return self._bboxes

+     @property
+     def scores(self) -> Sequence[Sequence[float]]:
+         return self._scores
+
+
+ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], od.Dataset):
      def __init__(
          self,
          images: Array | Sequence[Array],
@@ -203,14 +198,10 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
              [np.asarray(box).tolist() if isinstance(box, Array) else box for box in bbox] for bbox in bboxes
          ]

-     @property
-     def metadata(self) -> DatasetMetadata:
-         return DatasetMetadata(id=self._id, index2label=self._index2label)
-
-     def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, DatumMetadata]:
+     def __getitem__(self, idx: int, /) -> tuple[Array, CustomObjectDetectionTarget, DatumMetadata]:
          return (
              self._images[idx],
-             self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx], len(self._classes)),
+             CustomObjectDetectionTarget(self._labels[idx], self._bboxes[idx], len(self._classes)),
              _ensure_id(idx, self._metadata[idx] if self._metadata is not None else {}),
          )

@@ -221,9 +212,9 @@ def to_image_classification_dataset(
      metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
      classes: Sequence[str] | None,
      name: str | None = None,
- ) -> ImageClassificationDataset:
+ ) -> ic.Dataset:
      """
-     Helper function to create custom ImageClassificationDataset classes.
+     Helper function to create custom image classification Dataset classes.

      Parameters
      ----------
@@ -238,7 +229,7 @@

      Returns
      -------
-     ImageClassificationDataset
+     Dataset
      """
      _validate_data("ic", images, labels, None, metadata)
      return CustomImageClassificationDataset(images, labels, _listify_metadata(metadata), classes, name)
@@ -251,9 +242,9 @@ def to_object_detection_dataset(
      metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
      classes: Sequence[str] | None,
      name: str | None = None,
- ) -> ObjectDetectionDataset:
+ ) -> od.Dataset:
      """
-     Helper function to create custom ObjectDetectionDataset classes.
+     Helper function to create custom object detection Dataset classes.

      Parameters
      ----------
@@ -270,7 +261,7 @@

      Returns
      -------
-     ObjectDetectionDataset
+     Dataset
      """
      _validate_data("od", images, labels, bboxes, metadata)
      return CustomObjectDetectionDataset(images, labels, bboxes, _listify_metadata(metadata), classes, name)
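The builder helpers keep their parameters but now return maite.protocols dataset types. A minimal sketch with synthetic arrays, using only the parameter names visible in this diff:

import numpy as np
from maite_datasets import to_image_classification_dataset, to_object_detection_dataset

# Three placeholder CHW images standing in for real data.
images = [np.zeros((3, 16, 16)) for _ in range(3)]

ic_ds = to_image_classification_dataset(
    images=images,
    labels=[0, 1, 1],
    metadata=None,
    classes=["cat", "dog"],
)

od_ds = to_object_detection_dataset(
    images=images,
    labels=[[0], [1], [0, 1]],
    bboxes=[[[0, 0, 4, 4]], [[1, 1, 5, 5]], [[0, 0, 2, 2], [3, 3, 8, 8]]],
    metadata=None,
    classes=["cat", "dog"],
)

image, target, datum_metadata = od_ds[2]  # target exposes .boxes, .labels, .scores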
maite_datasets/_collate.py CHANGED
@@ -7,16 +7,15 @@ from __future__ import annotations
  __all__ = []

  from collections.abc import Iterable, Sequence
- from typing import Any, TypeVar, TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any, TypeVar

  import numpy as np
+ from maite.protocols import ArrayLike
  from numpy.typing import NDArray

  if TYPE_CHECKING:
      import torch

-     from maite_datasets._protocols import ArrayLike
-
  T_in = TypeVar("T_in")
  T_tgt = TypeVar("T_tgt")
  T_md = TypeVar("T_md")
maite_datasets/{_reader/_base.py → _reader.py} RENAMED
@@ -1,39 +1,19 @@
  from __future__ import annotations

- from abc import ABC, abstractmethod
  import logging
+ from abc import ABC, abstractmethod
  from pathlib import Path
- from typing import Any
-
- import numpy as np
+ from typing import Any, Generic, TypeVar

- from maite_datasets._protocols import ArrayLike, ObjectDetectionDataset
+ import maite.protocols.image_classification as ic
+ import maite.protocols.object_detection as od

  _logger = logging.getLogger(__name__)

-
- class _ObjectDetectionTarget:
-     """Internal implementation of ObjectDetectionTarget protocol."""
-
-     def __init__(self, boxes: ArrayLike, labels: ArrayLike, scores: ArrayLike) -> None:
-         self._boxes = np.asarray(boxes)
-         self._labels = np.asarray(labels)
-         self._scores = np.asarray(scores)
-
-     @property
-     def boxes(self) -> ArrayLike:
-         return self._boxes
-
-     @property
-     def labels(self) -> ArrayLike:
-         return self._labels
-
-     @property
-     def scores(self) -> ArrayLike:
-         return self._scores
+ _TDataset = TypeVar("_TDataset", ic.Dataset, od.Dataset)


- class BaseDatasetReader(ABC):
+ class BaseDatasetReader(Generic[_TDataset], ABC):
      """
      Abstract base class for object detection dataset readers.

@@ -65,7 +45,7 @@ class BaseDatasetReader(ABC):
          pass

      @abstractmethod
-     def _create_dataset_implementation(self) -> ObjectDetectionDataset:
+     def create_dataset(self) -> _TDataset:
          """Create the format-specific dataset implementation."""
          pass

@@ -123,13 +103,59 @@ class BaseDatasetReader(ABC):

          return {"is_valid": len(issues) == 0, "issues": issues, "stats": stats}

-     def get_dataset(self) -> ObjectDetectionDataset:
-         """
-         Get dataset conforming to MAITE ObjectDetectionDataset protocol.

-         Returns
-         -------
-         ObjectDetectionDataset
-             Dataset instance with MAITE-compatible interface
-         """
-         return self._create_dataset_implementation()
+ def create_dataset_reader(
+     dataset_path: str | Path, format_hint: str | None = None
+ ) -> BaseDatasetReader[ic.Dataset] | BaseDatasetReader[od.Dataset]:
+     """
+     Factory function to create appropriate dataset reader based on directory structure.
+
+     Parameters
+     ----------
+     dataset_path : str or Path
+         Root directory containing dataset files
+     format_hint : str or None, default None
+         Format hint ("coco" or "yolo"). If None, auto-detects based on file structure
+
+     Returns
+     -------
+     BaseDatasetReader
+         Appropriate reader instance for the detected format
+
+     Raises
+     ------
+     ValueError
+         If format cannot be determined or is unsupported
+     """
+     from maite_datasets.object_detection._coco import COCODatasetReader
+     from maite_datasets.object_detection._yolo import YOLODatasetReader
+
+     dataset_path = Path(dataset_path)
+
+     if format_hint:
+         format_hint = format_hint.lower()
+         if format_hint == "coco":
+             return COCODatasetReader(dataset_path)
+         if format_hint == "yolo":
+             return YOLODatasetReader(dataset_path)
+         raise ValueError(f"Unsupported format hint: {format_hint}")
+
+     # Auto-detect format
+     has_annotations_json = (dataset_path / "annotations.json").exists()
+     has_labels_dir = (dataset_path / "labels").exists()
+
+     if has_annotations_json and not has_labels_dir:
+         _logger.info(f"Detected COCO format for {dataset_path}")
+         return COCODatasetReader(dataset_path)
+     if has_labels_dir and not has_annotations_json:
+         _logger.info(f"Detected YOLO format for {dataset_path}")
+         return YOLODatasetReader(dataset_path)
+     if has_annotations_json and has_labels_dir:
+         raise ValueError(
+             f"Ambiguous format in {dataset_path}: both annotations.json and labels/ exist. "
+             "Use format_hint parameter to specify format."
+         )
+     raise ValueError(
+         f"Cannot detect dataset format in {dataset_path}. "
+         "Expected either annotations.json (COCO) or labels/ directory (YOLO)."
+     )
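The factory formerly in _reader/_factory.py now lives alongside BaseDatasetReader, and the public method for materializing a dataset is create_dataset() rather than get_dataset(). A usage sketch grounded in the function above; the directory names are placeholders:

from maite_datasets import create_dataset_reader

# Auto-detection: annotations.json -> COCO, labels/ -> YOLO.
reader = create_dataset_reader("datasets/my_coco")

# When a directory contains both layouts, the format must be named explicitly.
yolo_reader = create_dataset_reader("datasets/mixed", format_hint="yolo")

dataset = reader.create_dataset()  # replaces the removed get_dataset()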
maite_datasets/_validate.py CHANGED
@@ -2,11 +2,13 @@ from __future__ import annotations

  __all__ = []

- import numpy as np
  from collections.abc import Sequence, Sized
  from typing import Any, Literal

- from maite_datasets._protocols import Array, ObjectDetectionTarget
+ import numpy as np
+ from maite.protocols.object_detection import ObjectDetectionTarget
+
+ from maite_datasets.protocols import Array


  class ValidationMessages: