maite-datasets 0.0.5-py3-none-any.whl → 0.0.6-py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (32)
  1. maite_datasets/__init__.py +2 -6
  2. maite_datasets/_base.py +169 -51
  3. maite_datasets/_builder.py +46 -55
  4. maite_datasets/_collate.py +2 -3
  5. maite_datasets/{_reader/_base.py → _reader.py} +62 -36
  6. maite_datasets/_validate.py +4 -2
  7. maite_datasets/image_classification/_cifar10.py +12 -7
  8. maite_datasets/image_classification/_mnist.py +15 -10
  9. maite_datasets/image_classification/_ships.py +12 -8
  10. maite_datasets/object_detection/__init__.py +4 -7
  11. maite_datasets/object_detection/_antiuav.py +11 -8
  12. maite_datasets/{_reader → object_detection}/_coco.py +29 -27
  13. maite_datasets/object_detection/_milco.py +11 -9
  14. maite_datasets/object_detection/_seadrone.py +11 -9
  15. maite_datasets/object_detection/_voc.py +11 -13
  16. maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
  17. maite_datasets/protocols.py +23 -0
  18. maite_datasets/wrappers/__init__.py +8 -0
  19. maite_datasets/wrappers/_torch.py +111 -0
  20. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/METADATA +56 -3
  21. maite_datasets-0.0.6.dist-info/RECORD +26 -0
  22. maite_datasets/_mixin/__init__.py +0 -0
  23. maite_datasets/_mixin/_numpy.py +0 -28
  24. maite_datasets/_mixin/_torch.py +0 -28
  25. maite_datasets/_protocols.py +0 -217
  26. maite_datasets/_reader/__init__.py +0 -6
  27. maite_datasets/_reader/_factory.py +0 -64
  28. maite_datasets/_types.py +0 -50
  29. maite_datasets/object_detection/_voc_torch.py +0 -65
  30. maite_datasets-0.0.5.dist-info/RECORD +0 -31
  31. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/WHEEL +0 -0
  32. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/licenses/LICENSE +0 -0
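
The moves above amount to an API reorganization: the dataset readers leave the private _reader package for object_detection, the _mixin and _protocols/_types modules fold into _base and a new public protocols.py, and optional Torch support moves into a new wrappers package. A minimal migration sketch for downstream imports, inferred from the file moves above (the old private paths are shown as comments; treat the exact 0.0.5 import sites as assumptions):

# Hypothetical 0.0.5 -> 0.0.6 import migration, based on the file list above.

# 0.0.5 (modules removed in 0.0.6):
# from maite_datasets._reader._coco import COCODatasetReader
# from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
# from maite_datasets._protocols import Transform

# 0.0.6 equivalents:
from maite_datasets.object_detection import COCODatasetReader, YOLODatasetReader
from maite_datasets._base import BaseDatasetNumpyMixin, NumpyArray
from maite_datasets.protocols import Array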

maite_datasets/image_classification/_cifar10.py
@@ -2,15 +2,20 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Sequence, TypeVar
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
 
-from maite_datasets._base import BaseICDataset, DataLocation
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseICDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyImageClassificationTransform,
+)
 
 CIFARClassStringMap = Literal[
     "airplane",
@@ -27,7 +32,7 @@ CIFARClassStringMap = Literal[
 TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
 
 
-class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
+class CIFAR10(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
     """
     `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
 
@@ -89,7 +94,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         self,
         root: str | Path,
         image_set: Literal["train", "test", "base"] = "train",
-        transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+        transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -214,7 +219,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
             images[i, 2] = blue_channel  # Blue channel
         return images, labels
 
-    def _read_file(self, path: str) -> NDArray[np.number[Any]]:
+    def _read_file(self, path: str) -> NumpyArray:
         """
         Function to grab the correct image from the loaded data.
         Overwrite of the base `_read_file` because data is an all or nothing load.
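
The transforms parameter now takes NumpyImageClassificationTransform callables (or a sequence of them) instead of the generic Transform[NDArray[np.number[Any]]]. A minimal usage sketch, assuming a transform is any callable over a single CHW NumPy image (as the 0.0.5 Transform protocol was) and that CIFAR10 is re-exported from the image_classification subpackage the way the object-detection names are:

import numpy as np

from maite_datasets.image_classification import CIFAR10

def to_unit_interval(image: np.ndarray) -> np.ndarray:
    # Hypothetical transform: scale uint8 pixels into [0, 1].
    return image.astype(np.float32) / 255.0

dataset = CIFAR10(root="./data", image_set="train", transforms=to_unit_interval, download=True)
image, label, metadata = dataset[0]  # MAITE-style (input, target, metadata) triple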

maite_datasets/image_classification/_mnist.py
@@ -2,15 +2,20 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Sequence, TypeVar
+from typing import Any, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
 
-from maite_datasets._base import BaseICDataset, DataLocation
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseICDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyImageClassificationTransform,
+)
 
 MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
@@ -34,7 +39,7 @@ CorruptionStringMap = Literal[
 ]
 
 
-class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
+class MNIST(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
     """`MNIST <https://en.wikipedia.org/wiki/MNIST_database>`_ Dataset and `Corruptions <https://arxiv.org/abs/1906.02337>`_.
 
     There are 15 different styles of corruptions. This class downloads differently depending on if you
@@ -118,7 +123,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         root: str | Path,
         image_set: Literal["train", "test", "base"] = "train",
         corruption: CorruptionStringMap | None = None,
-        transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+        transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -149,7 +154,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         index_strings = np.arange(self._loaded_data.shape[0]).astype(str).tolist()
         return index_strings, labels.tolist(), {}
 
-    def _load_corruption(self) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
+    def _load_corruption(self) -> tuple[NumpyArray, NDArray[np.uintp]]:
         """Function to load in the file paths for the data and labels for the different corrupt data formats"""
         corruption = self.corruption if self.corruption is not None else "identity"
         base_path = self.path / "mnist_c" / corruption
@@ -176,7 +181,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
 
         return data, labels
 
-    def _grab_data(self, path: Path) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
+    def _grab_data(self, path: Path) -> tuple[NumpyArray, NDArray[np.uintp]]:
         """Function to load in the data numpy array"""
         with np.load(path, allow_pickle=True) as data_array:
             if self.image_set == "base":
@@ -190,11 +195,11 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
             data = np.expand_dims(data, axis=1)
         return data, labels
 
-    def _grab_corruption_data(self, path: Path) -> NDArray[np.number[Any]]:
+    def _grab_corruption_data(self, path: Path) -> NumpyArray:
         """Function to load in the data numpy array for the previously chosen corrupt format"""
         return np.load(path, allow_pickle=False)
 
-    def _read_file(self, path: str) -> NDArray[np.number[Any]]:
+    def _read_file(self, path: str) -> NumpyArray:
         """
         Function to grab the correct image from the loaded data.
         Overwrite of the base `_read_file` because data is an all or nothing load.
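
MNIST gets the same transform rework plus the corruption selector visible in _load_corruption, which falls back to the "identity" (uncorrupted) variant when none is chosen. A short sketch under the same re-export assumption as the CIFAR10 example:

from maite_datasets.image_classification import MNIST

# "identity" is the baseline folder _load_corruption defaults to; the other
# CorruptionStringMap values are elided from this diff.
corrupted = MNIST(root="./data", image_set="test", corruption="identity", download=True)
print(len(corrupted))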

maite_datasets/image_classification/_ships.py
@@ -2,18 +2,22 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Sequence
+from typing import Any
 
 import numpy as np
-from numpy.typing import NDArray
 
-from maite_datasets._base import BaseICDataset, DataLocation
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseICDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyImageClassificationTransform,
+)
 
 
-class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
+class Ships(BaseICDataset[NumpyArray], BaseDatasetNumpyMixin):
     """
     A dataset that focuses on identifying ships from satellite images.
 
@@ -76,7 +80,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
     def __init__(
         self,
         root: str | Path,
-        transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+        transforms: NumpyImageClassificationTransform | Sequence[NumpyImageClassificationTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -125,7 +129,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
         """Function to load in the file paths for the scene images"""
         return sorted(str(entry) for entry in (self.path / "scenes").glob("*.png"))
 
-    def get_scene(self, index: int) -> NDArray[np.number[Any]]:
+    def get_scene(self, index: int) -> NumpyArray:
         """
         Get the desired satellite image (scene) by passing in the index of the desired file.
 

maite_datasets/object_detection/__init__.py
@@ -1,20 +1,17 @@
 """Module for MAITE compliant Object Detection datasets."""
 
 from maite_datasets.object_detection._antiuav import AntiUAVDetection
+from maite_datasets.object_detection._coco import COCODatasetReader
 from maite_datasets.object_detection._milco import MILCO
 from maite_datasets.object_detection._seadrone import SeaDrone
 from maite_datasets.object_detection._voc import VOCDetection
+from maite_datasets.object_detection._yolo import YOLODatasetReader
 
 __all__ = [
     "AntiUAVDetection",
     "MILCO",
     "SeaDrone",
     "VOCDetection",
+    "COCODatasetReader",
+    "YOLODatasetReader",
 ]
-
-import importlib.util
-
-if importlib.util.find_spec("torch") is not None:
-    from maite_datasets.object_detection._voc_torch import VOCDetectionTorch
-
-    __all__ += ["VOCDetectionTorch"]
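
With the readers promoted out of _reader, the conditional Torch import (and VOCDetectionTorch, whose module is deleted in this release, file 29 above) no longer lives here; optional Torch support moves to the new wrappers package shown later. Everything in this namespace now imports unconditionally:

from maite_datasets.object_detection import (
    MILCO,
    AntiUAVDetection,
    COCODatasetReader,
    SeaDrone,
    VOCDetection,
    YOLODatasetReader,
)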

maite_datasets/object_detection/_antiuav.py
@@ -2,19 +2,22 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Sequence
+from typing import Any, Literal
 
-import numpy as np
 from defusedxml.ElementTree import parse
-from numpy.typing import NDArray
 
-from maite_datasets._base import BaseODDataset, DataLocation
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseODDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyObjectDetectionTransform,
+)
 
 
-class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
+class AntiUAVDetection(BaseODDataset[NumpyArray, list[str], str], BaseDatasetNumpyMixin):
     """
     A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
 
@@ -101,7 +104,7 @@ class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], B
         self,
         root: str | Path,
         image_set: Literal["train", "val", "test", "base"] = "train",
-        transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+        transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:

maite_datasets/{_reader → object_detection}/_coco.py
@@ -3,20 +3,19 @@
 from __future__ import annotations
 
 import json
-import logging
 from pathlib import Path
 from typing import Any
 
+import maite.protocols.object_detection as od
 import numpy as np
+from maite.protocols import DatasetMetadata, DatumMetadata
 from PIL import Image
 
-from maite_datasets._protocols import DatasetMetadata, DatumMetadata, ObjectDetectionDataset, ObjectDetectionDatum
-from maite_datasets._reader._base import _ObjectDetectionTarget, BaseDatasetReader
+from maite_datasets._base import BaseDataset, ObjectDetectionTarget
+from maite_datasets._reader import BaseDatasetReader
 
-_logger = logging.getLogger(__name__)
 
-
-class COCODatasetReader(BaseDatasetReader):
+class COCODatasetReader(BaseDatasetReader[od.Dataset]):
     """
     COCO format dataset reader conforming to MAITE protocols.
 
@@ -132,9 +131,9 @@ class COCODatasetReader(BaseDatasetReader):
         """Mapping from class index to class name."""
         return self._index2label
 
-    def _create_dataset_implementation(self) -> ObjectDetectionDataset:
+    def create_dataset(self) -> od.Dataset:
         """Create COCO dataset implementation."""
-        return _COCODataset(self)
+        return COCODataset(self)
 
     def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
         """Validate COCO format specific files and structure."""
@@ -198,37 +197,40 @@
         else:
             class_names = [cat["name"] for cat in self._coco_data["categories"]]
 
-        self._index2label = {idx: name for idx, name in enumerate(class_names)}
+        self._index2label = dict(enumerate(class_names))
 
 
-class _COCODataset:
+class COCODataset(BaseDataset):
     """Internal COCO dataset implementation."""
 
     def __init__(self, reader: COCODatasetReader) -> None:
-        self.reader = reader
-        self.image_ids = list(reader._image_id_to_info.keys())
-
-    @property
-    def metadata(self) -> DatasetMetadata:
-        return DatasetMetadata(
-            id=self.reader.dataset_id,
-            index2label=self.reader.index2label,
+        self._reader = reader
+        self._image_ids = list(reader._image_id_to_info.keys())
+
+        self.root = reader.dataset_path
+        self.images_path = reader._images_path
+        self.annotation_path = reader._annotation_path
+        self.size = len(reader._image_id_to_info)
+        self.classes = reader.index2label
+        self.metadata = DatasetMetadata(
+            id=self._reader.dataset_id,
+            index2label=self._reader.index2label,
         )
 
     def __len__(self) -> int:
-        return len(self.image_ids)
+        return len(self._image_ids)
 
-    def __getitem__(self, index: int) -> ObjectDetectionDatum:
-        image_id = self.image_ids[index]
-        image_info = self.reader._image_id_to_info[image_id]
+    def __getitem__(self, index: int) -> tuple[od.InputType, od.ObjectDetectionTarget, DatumMetadata]:
+        image_id = self._image_ids[index]
+        image_info = self._reader._image_id_to_info[image_id]
 
         # Load image
-        image_path = self.reader._images_path / image_info["file_name"]
+        image_path = self._reader._images_path / image_info["file_name"]
         image = np.array(Image.open(image_path).convert("RGB"))
         image = np.transpose(image, (2, 0, 1))  # Convert to CHW format
 
         # Get annotations for this image
-        annotations = self.reader.image_id_to_annotations.get(image_id, [])
+        annotations = self._reader.image_id_to_annotations.get(image_id, [])
 
         if annotations:
             boxes = []
@@ -241,7 +243,7 @@ class _COCODataset:
                 boxes.append([x, y, x + w, y + h])
 
                 # Map category_id to class index
-                cat_idx = self.reader._category_id_to_idx[ann["category_id"]]
+                cat_idx = self._reader._category_id_to_idx[ann["category_id"]]
                 labels.append(cat_idx)
 
                 # Collect annotation metadata
@@ -267,12 +269,12 @@
             scores = np.empty(0, dtype=np.float32)
             annotation_metadata = []
 
-        target = _ObjectDetectionTarget(boxes, labels, scores)
+        target = ObjectDetectionTarget(boxes, labels, scores)
 
         # Create comprehensive datum metadata
         datum_metadata = DatumMetadata(
             **{
-                "id": f"{self.reader.dataset_id}_{image_id}",
+                "id": f"{self._reader.dataset_id}_{image_id}",
                 # Image-level metadata
                 "coco_image_id": image_id,
                 "file_name": image_info["file_name"],

maite_datasets/object_detection/_milco.py
@@ -2,18 +2,20 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Sequence
+from typing import Any, Literal
 
-import numpy as np
-from numpy.typing import NDArray
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseODDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyObjectDetectionTransform,
+)
 
-from maite_datasets._base import BaseODDataset, DataLocation
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
 
-
-class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
+class MILCO(BaseODDataset[NumpyArray, list[str], str], BaseDatasetNumpyMixin):
     """
     A side-scan sonar dataset focused on mine-like object detection.
 
@@ -116,7 +118,7 @@ class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetN
         self,
         root: str | Path,
         image_set: Literal["train", "operational", "base"] = "train",
-        transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+        transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:

maite_datasets/object_detection/_seadrone.py
@@ -3,21 +3,23 @@ from __future__ import annotations
 __all__ = []
 
 import json
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Sequence
+from typing import Any, Literal
 
-import numpy as np
-from numpy.typing import NDArray
-
-from maite_datasets._base import BaseODDataset, DataLocation
+from maite_datasets._base import (
+    BaseDatasetNumpyMixin,
+    BaseODDataset,
+    DataLocation,
+    NumpyArray,
+    NumpyObjectDetectionTransform,
+)
 from maite_datasets._fileio import _ensure_exists
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
 
 
 class SeaDrone(
     BaseODDataset[
-        NDArray[np.number[Any]],
+        NumpyArray,
         list[tuple[list[int], list[list[float]]]],
         tuple[list[int], list[list[float]]],
     ],
@@ -313,7 +315,7 @@ class SeaDrone(
         self,
         root: str | Path,
         image_set: Literal["train", "val", "test", "base"] = "train",
-        transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
+        transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:

maite_datasets/object_detection/_voc.py
@@ -4,24 +4,22 @@ __all__ = []
 
 import os
 import shutil
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Any, Literal, Sequence, TypeVar
+from typing import Any, Literal, TypeVar
 
-import numpy as np
 from defusedxml.ElementTree import parse
-from numpy.typing import NDArray
 
 from maite_datasets._base import (
-    BaseDataset,
+    BaseDatasetNumpyMixin,
+    BaseDownloadedDataset,
     BaseODDataset,
     DataLocation,
+    NumpyArray,
+    NumpyObjectDetectionTransform,
+    ObjectDetectionTarget,
     _ensure_exists,
-    _TArray,
-    _TTarget,
 )
-from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
-from maite_datasets._protocols import Transform
-from maite_datasets._types import ObjectDetectionTarget
 
 VOCClassStringMap = Literal[
     "aeroplane",
@@ -48,7 +46,7 @@ VOCClassStringMap = Literal[
 TVOCClassMap = TypeVar("TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int])
 
 
-class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
+class BaseVOCDataset(BaseDownloadedDataset[NumpyArray, ObjectDetectionTarget, list[str], str]):
     _resources = [
         DataLocation(
             url="https://data.brainchip.com/dataset-mirror/voc/VOCtrainval_11-May-2012.tar",
@@ -130,7 +128,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
         root: str | Path,
         image_set: Literal["train", "val", "test", "base"] = "train",
         year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
-        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+        transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
@@ -432,8 +430,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
 
 
 class VOCDetection(
-    BaseVOCDataset[NDArray[np.number[Any]], ObjectDetectionTarget[NDArray[np.number[Any]]]],
-    BaseODDataset[NDArray[np.number[Any]], list[str], str],
+    BaseVOCDataset,
+    BaseODDataset[NumpyArray, list[str], str],
     BaseDatasetNumpyMixin,
 ):
     """

maite_datasets/{_reader → object_detection}/_yolo.py
@@ -7,14 +7,16 @@ __all__ = []
 from pathlib import Path
 from typing import Any
 
+import maite.protocols.object_detection as od
 import numpy as np
+from maite.protocols import DatasetMetadata, DatumMetadata
 from PIL import Image
 
-from maite_datasets._protocols import DatasetMetadata, DatumMetadata, ObjectDetectionDataset, ObjectDetectionDatum
-from maite_datasets._reader._base import _ObjectDetectionTarget, BaseDatasetReader
+from maite_datasets._base import BaseDataset, ObjectDetectionTarget
+from maite_datasets._reader import BaseDatasetReader
 
 
-class YOLODatasetReader(BaseDatasetReader):
+class YOLODatasetReader(BaseDatasetReader[od.Dataset]):
     """
     YOLO format dataset reader conforming to MAITE protocols.
 
@@ -120,9 +122,9 @@ class YOLODatasetReader(BaseDatasetReader):
         """Mapping from class index to class name."""
         return self._index2label
 
-    def _create_dataset_implementation(self) -> ObjectDetectionDataset:
+    def create_dataset(self) -> od.Dataset:
         """Create YOLO dataset implementation."""
-        return _YOLODataset(self)
+        return YOLODataset(self)
 
     def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
         """Validate YOLO format specific files and structure."""
@@ -200,7 +202,7 @@ class YOLODatasetReader(BaseDatasetReader):
         """Load class names from classes file."""
         with open(self._classes_path) as f:
             class_names = [line.strip() for line in f if line.strip()]
-        self._index2label = {idx: name for idx, name in enumerate(class_names)}
+        self._index2label = dict(enumerate(class_names))
 
     def _find_image_files(self) -> None:
         """Find all valid image files."""
@@ -213,32 +215,35 @@
             raise ValueError(f"No image files found in {self._images_path}")
 
 
-class _YOLODataset:
+class YOLODataset(BaseDataset):
     """Internal YOLO dataset implementation."""
 
     def __init__(self, reader: YOLODatasetReader) -> None:
-        self.reader = reader
-
-    @property
-    def metadata(self) -> DatasetMetadata:
-        return DatasetMetadata(
-            id=self.reader.dataset_id,
-            index2label=self.reader.index2label,
+        self._reader = reader
+
+        self.root = reader.dataset_path
+        self.images_path = reader._images_path
+        self.annotation_path = reader._labels_path
+        self.size = len(reader._image_files)
+        self.classes = reader.index2label
+        self.metadata = DatasetMetadata(
+            id=self._reader.dataset_id,
+            index2label=self._reader.index2label,
         )
 
     def __len__(self) -> int:
-        return len(self.reader._image_files)
+        return len(self._reader._image_files)
 
-    def __getitem__(self, index: int) -> ObjectDetectionDatum:
-        image_path = self.reader._image_files[index]
+    def __getitem__(self, index: int) -> tuple[od.InputType, od.ObjectDetectionTarget, DatumMetadata]:
+        image_path = self._reader._image_files[index]
 
         # Load image
-        image = np.array(Image.open(image_path).convert("RGB"))
+        image = np.asarray(Image.open(image_path).convert("RGB"), dtype=np.uint8)
         img_height, img_width = image.shape[:2]
         image = np.transpose(image, (2, 0, 1))  # Convert to CHW format
 
         # Load corresponding label file
-        label_path = self.reader._labels_path / f"{image_path.stem}.txt"
+        label_path = self._reader._labels_path / f"{image_path.stem}.txt"
 
         annotation_metadata = []
         if label_path.exists():
@@ -292,12 +297,12 @@
             labels = np.empty(0, dtype=np.int64)
             scores = np.empty(0, dtype=np.float32)
 
-        target = _ObjectDetectionTarget(boxes, labels, scores)
+        target = ObjectDetectionTarget(boxes, labels, scores)
 
         # Create comprehensive datum metadata
         datum_metadata = DatumMetadata(
             **{
-                "id": f"{self.reader.dataset_id}_{image_path.stem}",
+                "id": f"{self._reader.dataset_id}_{image_path.stem}",
                 # Image-level metadata
                 "file_name": image_path.name,
                 "file_path": str(image_path),

maite_datasets/protocols.py
@@ -0,0 +1,23 @@
+"""
+Common type protocols used for interoperability.
+"""
+
+from collections.abc import Iterator
+from typing import Any, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class Array(Protocol):
+    """
+    Protocol for interoperable array objects.
+
+    Supports common array representations with popular libraries like
+    PyTorch, Tensorflow and JAX, as well as NumPy arrays.
+    """
+
+    @property
+    def shape(self) -> tuple[int, ...]: ...
+    def __array__(self) -> Any: ...
+    def __getitem__(self, key: Any, /) -> Any: ...
+    def __iter__(self) -> Iterator[Any]: ...
+    def __len__(self) -> int: ...
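
Because Array is runtime_checkable, isinstance checks test only for the presence of these five members, which is what lets NumPy arrays and Torch/TensorFlow/JAX tensors pass interchangeably. A quick demonstration with NumPy:

import numpy as np

from maite_datasets.protocols import Array

x = np.zeros((3, 32, 32))
# np.ndarray provides shape, __array__, __getitem__, __iter__, and __len__,
# so it satisfies the structural protocol at runtime.
assert isinstance(x, Array)
assert x.shape == (3, 32, 32)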

maite_datasets/wrappers/__init__.py
@@ -0,0 +1,8 @@
+import importlib.util
+
+__all__ = []
+
+if importlib.util.find_spec("torch") is not None and importlib.util.find_spec("torchvision") is not None:
+    from ._torch import TorchvisionWrapper
+
+    __all__ += ["TorchvisionWrapper"]
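
TorchvisionWrapper is only exported when both torch and torchvision resolve, so downstream code that treats Torch as optional can guard the import the same way. A defensive sketch; the wrapper's behavior lives in wrappers/_torch.py (+111 lines), which this diff does not show:

try:
    from maite_datasets.wrappers import TorchvisionWrapper
except ImportError:  # raised when torch/torchvision are absent and the name is not exported
    TorchvisionWrapper = None

if TorchvisionWrapper is not None:
    print("Torch-backed wrapping available")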