maite-datasets 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. maite_datasets/__init__.py +2 -6
  2. maite_datasets/_base.py +169 -51
  3. maite_datasets/_builder.py +46 -55
  4. maite_datasets/_collate.py +2 -3
  5. maite_datasets/{_reader/_base.py → _reader.py} +62 -36
  6. maite_datasets/_validate.py +4 -2
  7. maite_datasets/adapters/__init__.py +3 -0
  8. maite_datasets/adapters/_huggingface.py +391 -0
  9. maite_datasets/image_classification/_cifar10.py +12 -7
  10. maite_datasets/image_classification/_mnist.py +15 -10
  11. maite_datasets/image_classification/_ships.py +12 -8
  12. maite_datasets/object_detection/__init__.py +4 -7
  13. maite_datasets/object_detection/_antiuav.py +11 -8
  14. maite_datasets/{_reader → object_detection}/_coco.py +29 -27
  15. maite_datasets/object_detection/_milco.py +11 -9
  16. maite_datasets/object_detection/_seadrone.py +11 -9
  17. maite_datasets/object_detection/_voc.py +11 -13
  18. maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
  19. maite_datasets/protocols.py +94 -0
  20. maite_datasets/wrappers/__init__.py +8 -0
  21. maite_datasets/wrappers/_torch.py +109 -0
  22. maite_datasets-0.0.7.dist-info/METADATA +181 -0
  23. maite_datasets-0.0.7.dist-info/RECORD +28 -0
  24. maite_datasets/_mixin/__init__.py +0 -0
  25. maite_datasets/_mixin/_numpy.py +0 -28
  26. maite_datasets/_mixin/_torch.py +0 -28
  27. maite_datasets/_protocols.py +0 -217
  28. maite_datasets/_reader/__init__.py +0 -6
  29. maite_datasets/_reader/_factory.py +0 -64
  30. maite_datasets/_types.py +0 -50
  31. maite_datasets/object_detection/_voc_torch.py +0 -65
  32. maite_datasets-0.0.5.dist-info/METADATA +0 -91
  33. maite_datasets-0.0.5.dist-info/RECORD +0 -31
  34. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/WHEEL +0 -0
  35. {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -3,20 +3,19 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import json
6
- import logging
7
6
  from pathlib import Path
8
7
  from typing import Any
9
8
 
9
+ import maite.protocols.object_detection as od
10
10
  import numpy as np
11
+ from maite.protocols import DatasetMetadata, DatumMetadata
11
12
  from PIL import Image
12
13
 
13
- from maite_datasets._protocols import DatasetMetadata, DatumMetadata, ObjectDetectionDataset, ObjectDetectionDatum
14
- from maite_datasets._reader._base import _ObjectDetectionTarget, BaseDatasetReader
14
+ from maite_datasets._base import BaseDataset, ObjectDetectionTarget
15
+ from maite_datasets._reader import BaseDatasetReader
15
16
 
16
- _logger = logging.getLogger(__name__)
17
17
 
18
-
19
- class COCODatasetReader(BaseDatasetReader):
18
+ class COCODatasetReader(BaseDatasetReader[od.Dataset]):
20
19
  """
21
20
  COCO format dataset reader conforming to MAITE protocols.
22
21
 
@@ -132,9 +131,9 @@ class COCODatasetReader(BaseDatasetReader):
132
131
  """Mapping from class index to class name."""
133
132
  return self._index2label
134
133
 
135
- def _create_dataset_implementation(self) -> ObjectDetectionDataset:
134
+ def create_dataset(self) -> od.Dataset:
136
135
  """Create COCO dataset implementation."""
137
- return _COCODataset(self)
136
+ return COCODataset(self)
138
137
 
139
138
  def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
140
139
  """Validate COCO format specific files and structure."""
@@ -198,37 +197,40 @@ class COCODatasetReader(BaseDatasetReader):
198
197
  else:
199
198
  class_names = [cat["name"] for cat in self._coco_data["categories"]]
200
199
 
201
- self._index2label = {idx: name for idx, name in enumerate(class_names)}
200
+ self._index2label = dict(enumerate(class_names))
202
201
 
203
202
 
204
- class _COCODataset:
203
+ class COCODataset(BaseDataset):
205
204
  """Internal COCO dataset implementation."""
206
205
 
207
206
  def __init__(self, reader: COCODatasetReader) -> None:
208
- self.reader = reader
209
- self.image_ids = list(reader._image_id_to_info.keys())
210
-
211
- @property
212
- def metadata(self) -> DatasetMetadata:
213
- return DatasetMetadata(
214
- id=self.reader.dataset_id,
215
- index2label=self.reader.index2label,
207
+ self._reader = reader
208
+ self._image_ids = list(reader._image_id_to_info.keys())
209
+
210
+ self.root = reader.dataset_path
211
+ self.images_path = reader._images_path
212
+ self.annotation_path = reader._annotation_path
213
+ self.size = len(reader._image_id_to_info)
214
+ self.classes = reader.index2label
215
+ self.metadata = DatasetMetadata(
216
+ id=self._reader.dataset_id,
217
+ index2label=self._reader.index2label,
216
218
  )
217
219
 
218
220
  def __len__(self) -> int:
219
- return len(self.image_ids)
221
+ return len(self._image_ids)
220
222
 
221
- def __getitem__(self, index: int) -> ObjectDetectionDatum:
222
- image_id = self.image_ids[index]
223
- image_info = self.reader._image_id_to_info[image_id]
223
+ def __getitem__(self, index: int) -> tuple[od.InputType, od.ObjectDetectionTarget, DatumMetadata]:
224
+ image_id = self._image_ids[index]
225
+ image_info = self._reader._image_id_to_info[image_id]
224
226
 
225
227
  # Load image
226
- image_path = self.reader._images_path / image_info["file_name"]
228
+ image_path = self._reader._images_path / image_info["file_name"]
227
229
  image = np.array(Image.open(image_path).convert("RGB"))
228
230
  image = np.transpose(image, (2, 0, 1)) # Convert to CHW format
229
231
 
230
232
  # Get annotations for this image
231
- annotations = self.reader.image_id_to_annotations.get(image_id, [])
233
+ annotations = self._reader.image_id_to_annotations.get(image_id, [])
232
234
 
233
235
  if annotations:
234
236
  boxes = []
@@ -241,7 +243,7 @@ class _COCODataset:
241
243
  boxes.append([x, y, x + w, y + h])
242
244
 
243
245
  # Map category_id to class index
244
- cat_idx = self.reader._category_id_to_idx[ann["category_id"]]
246
+ cat_idx = self._reader._category_id_to_idx[ann["category_id"]]
245
247
  labels.append(cat_idx)
246
248
 
247
249
  # Collect annotation metadata
@@ -267,12 +269,12 @@ class _COCODataset:
267
269
  scores = np.empty(0, dtype=np.float32)
268
270
  annotation_metadata = []
269
271
 
270
- target = _ObjectDetectionTarget(boxes, labels, scores)
272
+ target = ObjectDetectionTarget(boxes, labels, scores)
271
273
 
272
274
  # Create comprehensive datum metadata
273
275
  datum_metadata = DatumMetadata(
274
276
  **{
275
- "id": f"{self.reader.dataset_id}_{image_id}",
277
+ "id": f"{self._reader.dataset_id}_{image_id}",
276
278
  # Image-level metadata
277
279
  "coco_image_id": image_id,
278
280
  "file_name": image_info["file_name"],
@@ -2,18 +2,20 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ from collections.abc import Sequence
5
6
  from pathlib import Path
6
- from typing import Any, Literal, Sequence
7
+ from typing import Any, Literal
7
8
 
8
- import numpy as np
9
- from numpy.typing import NDArray
9
+ from maite_datasets._base import (
10
+ BaseDatasetNumpyMixin,
11
+ BaseODDataset,
12
+ DataLocation,
13
+ NumpyArray,
14
+ NumpyObjectDetectionTransform,
15
+ )
10
16
 
11
- from maite_datasets._base import BaseODDataset, DataLocation
12
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
13
- from maite_datasets._protocols import Transform
14
17
 
15
-
16
- class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
18
+ class MILCO(BaseODDataset[NumpyArray, list[str], str], BaseDatasetNumpyMixin):
17
19
  """
18
20
  A side-scan sonar dataset focused on mine-like object detection.
19
21
 
@@ -116,7 +118,7 @@ class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetN
116
118
  self,
117
119
  root: str | Path,
118
120
  image_set: Literal["train", "operational", "base"] = "train",
119
- transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
121
+ transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
120
122
  download: bool = False,
121
123
  verbose: bool = False,
122
124
  ) -> None:
@@ -3,21 +3,23 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import json
6
+ from collections.abc import Sequence
6
7
  from pathlib import Path
7
- from typing import Any, Literal, Sequence
8
+ from typing import Any, Literal
8
9
 
9
- import numpy as np
10
- from numpy.typing import NDArray
11
-
12
- from maite_datasets._base import BaseODDataset, DataLocation
10
+ from maite_datasets._base import (
11
+ BaseDatasetNumpyMixin,
12
+ BaseODDataset,
13
+ DataLocation,
14
+ NumpyArray,
15
+ NumpyObjectDetectionTransform,
16
+ )
13
17
  from maite_datasets._fileio import _ensure_exists
14
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
15
- from maite_datasets._protocols import Transform
16
18
 
17
19
 
18
20
  class SeaDrone(
19
21
  BaseODDataset[
20
- NDArray[np.number[Any]],
22
+ NumpyArray,
21
23
  list[tuple[list[int], list[list[float]]]],
22
24
  tuple[list[int], list[list[float]]],
23
25
  ],
@@ -313,7 +315,7 @@ class SeaDrone(
313
315
  self,
314
316
  root: str | Path,
315
317
  image_set: Literal["train", "val", "test", "base"] = "train",
316
- transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
318
+ transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
317
319
  download: bool = False,
318
320
  verbose: bool = False,
319
321
  ) -> None:
@@ -4,24 +4,22 @@ __all__ = []
4
4
 
5
5
  import os
6
6
  import shutil
7
+ from collections.abc import Sequence
7
8
  from pathlib import Path
8
- from typing import Any, Literal, Sequence, TypeVar
9
+ from typing import Any, Literal, TypeVar
9
10
 
10
- import numpy as np
11
11
  from defusedxml.ElementTree import parse
12
- from numpy.typing import NDArray
13
12
 
14
13
  from maite_datasets._base import (
15
- BaseDataset,
14
+ BaseDatasetNumpyMixin,
15
+ BaseDownloadedDataset,
16
16
  BaseODDataset,
17
17
  DataLocation,
18
+ NumpyArray,
19
+ NumpyObjectDetectionTransform,
20
+ ObjectDetectionTarget,
18
21
  _ensure_exists,
19
- _TArray,
20
- _TTarget,
21
22
  )
22
- from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
23
- from maite_datasets._protocols import Transform
24
- from maite_datasets._types import ObjectDetectionTarget
25
23
 
26
24
  VOCClassStringMap = Literal[
27
25
  "aeroplane",
@@ -48,7 +46,7 @@ VOCClassStringMap = Literal[
48
46
  TVOCClassMap = TypeVar("TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int])
49
47
 
50
48
 
51
- class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
49
+ class BaseVOCDataset(BaseDownloadedDataset[NumpyArray, ObjectDetectionTarget, list[str], str]):
52
50
  _resources = [
53
51
  DataLocation(
54
52
  url="https://data.brainchip.com/dataset-mirror/voc/VOCtrainval_11-May-2012.tar",
@@ -130,7 +128,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
130
128
  root: str | Path,
131
129
  image_set: Literal["train", "val", "test", "base"] = "train",
132
130
  year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
133
- transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
131
+ transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
134
132
  download: bool = False,
135
133
  verbose: bool = False,
136
134
  ) -> None:
@@ -432,8 +430,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
432
430
 
433
431
 
434
432
  class VOCDetection(
435
- BaseVOCDataset[NDArray[np.number[Any]], ObjectDetectionTarget[NDArray[np.number[Any]]]],
436
- BaseODDataset[NDArray[np.number[Any]], list[str], str],
433
+ BaseVOCDataset,
434
+ BaseODDataset[NumpyArray, list[str], str],
437
435
  BaseDatasetNumpyMixin,
438
436
  ):
439
437
  """
@@ -7,14 +7,16 @@ __all__ = []
7
7
  from pathlib import Path
8
8
  from typing import Any
9
9
 
10
+ import maite.protocols.object_detection as od
10
11
  import numpy as np
12
+ from maite.protocols import DatasetMetadata, DatumMetadata
11
13
  from PIL import Image
12
14
 
13
- from maite_datasets._protocols import DatasetMetadata, DatumMetadata, ObjectDetectionDataset, ObjectDetectionDatum
14
- from maite_datasets._reader._base import _ObjectDetectionTarget, BaseDatasetReader
15
+ from maite_datasets._base import BaseDataset, ObjectDetectionTarget
16
+ from maite_datasets._reader import BaseDatasetReader
15
17
 
16
18
 
17
- class YOLODatasetReader(BaseDatasetReader):
19
+ class YOLODatasetReader(BaseDatasetReader[od.Dataset]):
18
20
  """
19
21
  YOLO format dataset reader conforming to MAITE protocols.
20
22
 
@@ -120,9 +122,9 @@ class YOLODatasetReader(BaseDatasetReader):
120
122
  """Mapping from class index to class name."""
121
123
  return self._index2label
122
124
 
123
- def _create_dataset_implementation(self) -> ObjectDetectionDataset:
125
+ def create_dataset(self) -> od.Dataset:
124
126
  """Create YOLO dataset implementation."""
125
- return _YOLODataset(self)
127
+ return YOLODataset(self)
126
128
 
127
129
  def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
128
130
  """Validate YOLO format specific files and structure."""
@@ -200,7 +202,7 @@ class YOLODatasetReader(BaseDatasetReader):
200
202
  """Load class names from classes file."""
201
203
  with open(self._classes_path) as f:
202
204
  class_names = [line.strip() for line in f if line.strip()]
203
- self._index2label = {idx: name for idx, name in enumerate(class_names)}
205
+ self._index2label = dict(enumerate(class_names))
204
206
 
205
207
  def _find_image_files(self) -> None:
206
208
  """Find all valid image files."""
@@ -213,32 +215,35 @@ class YOLODatasetReader(BaseDatasetReader):
213
215
  raise ValueError(f"No image files found in {self._images_path}")
214
216
 
215
217
 
216
- class _YOLODataset:
218
+ class YOLODataset(BaseDataset):
217
219
  """Internal YOLO dataset implementation."""
218
220
 
219
221
  def __init__(self, reader: YOLODatasetReader) -> None:
220
- self.reader = reader
221
-
222
- @property
223
- def metadata(self) -> DatasetMetadata:
224
- return DatasetMetadata(
225
- id=self.reader.dataset_id,
226
- index2label=self.reader.index2label,
222
+ self._reader = reader
223
+
224
+ self.root = reader.dataset_path
225
+ self.images_path = reader._images_path
226
+ self.annotation_path = reader._labels_path
227
+ self.size = len(reader._image_files)
228
+ self.classes = reader.index2label
229
+ self.metadata = DatasetMetadata(
230
+ id=self._reader.dataset_id,
231
+ index2label=self._reader.index2label,
227
232
  )
228
233
 
229
234
  def __len__(self) -> int:
230
- return len(self.reader._image_files)
235
+ return len(self._reader._image_files)
231
236
 
232
- def __getitem__(self, index: int) -> ObjectDetectionDatum:
233
- image_path = self.reader._image_files[index]
237
+ def __getitem__(self, index: int) -> tuple[od.InputType, od.ObjectDetectionTarget, DatumMetadata]:
238
+ image_path = self._reader._image_files[index]
234
239
 
235
240
  # Load image
236
- image = np.array(Image.open(image_path).convert("RGB"))
241
+ image = np.asarray(Image.open(image_path).convert("RGB"), dtype=np.uint8)
237
242
  img_height, img_width = image.shape[:2]
238
243
  image = np.transpose(image, (2, 0, 1)) # Convert to CHW format
239
244
 
240
245
  # Load corresponding label file
241
- label_path = self.reader._labels_path / f"{image_path.stem}.txt"
246
+ label_path = self._reader._labels_path / f"{image_path.stem}.txt"
242
247
 
243
248
  annotation_metadata = []
244
249
  if label_path.exists():
@@ -292,12 +297,12 @@ class _YOLODataset:
292
297
  labels = np.empty(0, dtype=np.int64)
293
298
  scores = np.empty(0, dtype=np.float32)
294
299
 
295
- target = _ObjectDetectionTarget(boxes, labels, scores)
300
+ target = ObjectDetectionTarget(boxes, labels, scores)
296
301
 
297
302
  # Create comprehensive datum metadata
298
303
  datum_metadata = DatumMetadata(
299
304
  **{
300
- "id": f"{self.reader.dataset_id}_{image_path.stem}",
305
+ "id": f"{self._reader.dataset_id}_{image_path.stem}",
301
306
  # Image-level metadata
302
307
  "file_name": image_path.name,
303
308
  "file_path": str(image_path),
@@ -0,0 +1,94 @@
1
+ """
2
+ Common type protocols used for interoperability.
3
+ """
4
+
5
+ from collections.abc import Iterable, Iterator, Mapping, Sequence
6
+ from typing import Any, Protocol, overload, runtime_checkable
7
+
8
+
9
@runtime_checkable
class Array(Protocol):
    """
    Structural type for array-like objects shared across libraries.

    Any object exposing a ``shape`` property together with ``__array__``,
    ``__getitem__``, ``__iter__`` and ``__len__`` satisfies this protocol,
    which is meant to cover NumPy arrays as well as PyTorch, TensorFlow and
    JAX array types. Because the protocol is runtime-checkable, ``isinstance``
    only verifies that these members exist, not their signatures.
    """

    # Array dimensions.
    @property
    def shape(self) -> tuple[int, ...]:
        ...

    def __len__(self) -> int:
        ...

    def __iter__(self) -> Iterator[Any]:
        ...

    def __getitem__(self, key: Any, /) -> Any:
        ...

    # Conversion hook consulted by NumPy when converting the object.
    def __array__(self) -> Any:
        ...
24
+
25
+
26
@runtime_checkable
class HFDatasetInfo(Protocol):
    """Minimal structural view of a Hugging Face dataset ``info`` object."""

    @property
    def dataset_name(self) -> str:
        """Name of the dataset."""
        ...
30
+
31
+
32
@runtime_checkable
class HFDataset(Protocol):
    """
    Minimal structural view of a Hugging Face dataset object.

    Indexing with an ``int``, a ``slice`` or an iterable of ints is typed as
    returning ``dict[str, Any]``; indexing with a ``str`` is typed as returning
    ``Sequence[int]`` (presumably a column of values -- confirm against usage).
    """

    def __len__(self) -> int: ...

    @overload
    def __getitem__(self, key: int | slice | Iterable[int]) -> dict[str, Any]: ...
    @overload
    def __getitem__(self, key: str) -> Sequence[int]: ...
    def __getitem__(self, key: str | int | slice | Iterable[int]) -> dict[str, Any] | Sequence[int]: ...

    @property
    def info(self) -> HFDatasetInfo: ...

    @property
    def builder_name(self) -> str | None: ...

    # Mapping from column name to its feature descriptor (typed loosely as Any).
    @property
    def features(self) -> Mapping[str, Any]: ...
50
+
51
+
52
@runtime_checkable
class HFFeature(Protocol):
    """Base structural type for Hugging Face feature descriptors, recognised by a ``_type`` tag."""

    @property
    def _type(self) -> str:
        ...
56
+
57
+
58
+ @runtime_checkable
59
+ class HFClassLabel(HFFeature, Protocol):
60
+ @property
61
+ def names(self) -> list[str]: ...
62
+
63
+ @property
64
+ def num_classes(self) -> int: ...
65
+
66
+
67
+ @runtime_checkable
68
+ class HFImage(HFFeature, Protocol):
69
+ @property
70
+ def decode(self) -> bool: ...
71
+
72
+
73
+ @runtime_checkable
74
+ class HFArray(HFFeature, Protocol):
75
+ @property
76
+ def shape(self) -> tuple[int, ...]: ...
77
+ @property
78
+ def dtype(self) -> str: ...
79
+
80
+
81
+ @runtime_checkable
82
+ class HFList(HFFeature, Protocol):
83
+ @property
84
+ def feature(self) -> Any: ...
85
+ @property
86
+ def length(self) -> int: ...
87
+
88
+
89
+ @runtime_checkable
90
+ class HFValue(HFFeature, Protocol):
91
+ @property
92
+ def pa_type(self) -> Any: ... # pyarrow type ... not documented
93
+ @property
94
+ def dtype(self) -> str: ...
@@ -0,0 +1,8 @@
1
+ import importlib.util
2
+
3
# Public names of this package. Wrappers with optional dependencies are only
# exported when those dependencies can be found.
__all__ = []

# The torchvision wrapper needs both torch and torchvision; probing with
# find_spec avoids importing either when one of them is missing.
if all(importlib.util.find_spec(pkg) is not None for pkg in ("torch", "torchvision")):
    from ._torch import TorchvisionWrapper

    __all__.append("TorchvisionWrapper")
@@ -0,0 +1,109 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Generic, TypeAlias, TypeVar, cast, overload
4
+
5
+ import torch
6
+ from maite.protocols import DatasetMetadata, DatumMetadata
7
+ from maite.protocols.object_detection import ObjectDetectionTarget as _ObjectDetectionTarget
8
+ from torch import Tensor
9
+ from torchvision.tv_tensors import BoundingBoxes, Image
10
+
11
+ from maite_datasets._base import BaseDataset, ObjectDetectionTarget
12
+ from maite_datasets.protocols import Array
13
+
14
# Target type of the wrapped dataset (unconstrained).
TTarget = TypeVar("TTarget")
# Image/input type of the wrapped dataset; must satisfy the Array protocol.
TArray = TypeVar("TArray", bound=Array)

# Converted datum types returned by TorchvisionWrapper.__getitem__.
TorchvisionObjectDetectionDatum: TypeAlias = tuple[Image, ObjectDetectionTarget, DatumMetadata]
TorchvisionImageClassificationDatum: TypeAlias = tuple[Image, Tensor, DatumMetadata]
19
+
20
+
21
class TorchvisionWrapper(Generic[TArray, TTarget]):
    """
    Lightweight wrapper converting numpy-based datasets to Torchvision tensors.

    Each datum is converted when it is accessed. The image becomes a
    ``torchvision.tv_tensors.Image``, and the target is converted by type:

    - an :class:`Array` target (image classification) becomes a ``float32``
      tensor;
    - a MAITE object detection target becomes an ``ObjectDetectionTarget``.
      Its boxes become ``BoundingBoxes`` in ``XYXY`` format, with the canvas
      taken from the image's last two dimensions. Labels become ``int64`` and
      scores ``float32``.

    The converted ``(image, target, metadata)`` tuple is then passed to
    ``transforms`` when one is given. Attributes not found on the wrapper are
    looked up on the wrapped dataset.

    Parameters
    ----------
    dataset : BaseDataset
        Source dataset returning ``(image, target, metadata)`` tuples of numpy
        (or other :class:`Array`) data.
    transforms : callable, optional
        Torchvision v2 transform applied to the whole converted
        ``(image, target, metadata)`` tuple.
    """

    def __init__(
        self,
        dataset: BaseDataset[TArray, TTarget],
        transforms: Callable[[Any], Any] | None = None,
    ) -> None:
        self._dataset = dataset
        self.transforms = transforms
        # The id records the wrapping; index2label is passed through unchanged.
        self.metadata: DatasetMetadata = {
            "id": f"TorchvisionWrapper({dataset.metadata['id']})",
            "index2label": dataset.metadata.get("index2label", {}),
        }

    def __getattr__(self, name: str) -> Any:
        """
        Forward unknown attributes to the wrapped dataset.

        Python only calls this after normal attribute lookup fails. ``_dataset``
        itself is never forwarded. On an instance whose ``__init__`` has not run,
        ``_dataset`` is missing: ``pickle`` and ``copy`` create such instances
        and then probe attributes like ``__setstate__``. Forwarding there would
        re-enter this method for ``_dataset`` forever and raise
        ``RecursionError`` instead of the ``AttributeError`` those modules expect.
        """
        if name == "_dataset":
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self._dataset, name)

    def __dir__(self) -> list[str]:
        """Include wrapped dataset attributes in dir() for IDE support."""
        wrapper_attrs = set(super().__dir__())
        dataset_attrs = set(dir(self._dataset))
        return sorted(wrapper_attrs | dataset_attrs)

    def _transform(self, datum: Any) -> Any:
        """Apply ``self.transforms`` to ``datum`` when a transform was provided."""
        # Compare against None: a configured callable must run even if it happens to be falsy.
        return datum if self.transforms is None else self.transforms(datum)

    @overload
    def __getitem__(self: TorchvisionWrapper[TArray, TArray], index: int) -> tuple[Image, Tensor, DatumMetadata]: ...
    @overload
    def __getitem__(
        self: TorchvisionWrapper[TArray, TTarget], index: int
    ) -> tuple[Image, ObjectDetectionTarget, DatumMetadata]: ...

    def __getitem__(self, index: int) -> tuple[Image, Tensor | ObjectDetectionTarget, DatumMetadata]:
        """
        Return datum ``index`` of the wrapped dataset converted to torch tensors.

        Raises
        ------
        TypeError
            If the target is neither an :class:`Array` nor an object detection target.
        """
        image, target, metadata = self._dataset[index]

        torch_image = Image(torch.tensor(image))

        if isinstance(target, Array):
            # Image classification case
            torch_target = torch.tensor(target, dtype=torch.float32)
            torch_datum = self._transform((torch_image, torch_target, metadata))
            return cast(TorchvisionImageClassificationDatum, torch_datum)

        if isinstance(target, _ObjectDetectionTarget):
            # Object detection case: the box canvas is the image's last two dimensions.
            canvas_size = (torch_image.shape[-2], torch_image.shape[-1])
            torch_boxes = BoundingBoxes(torch.tensor(target.boxes), format="XYXY", canvas_size=canvas_size)  # type: ignore
            torch_labels = torch.tensor(target.labels, dtype=torch.int64)
            torch_scores = torch.tensor(target.scores, dtype=torch.float32)
            torch_target = ObjectDetectionTarget(torch_boxes, torch_labels, torch_scores)
            torch_datum = self._transform((torch_image, torch_target, metadata))
            return cast(TorchvisionObjectDetectionDatum, torch_datum)

        raise TypeError(f"Unsupported target type: {type(target)}")

    def __str__(self) -> str:
        """Render a title, the wrapper's public attributes, then the wrapped dataset."""
        nt = "\n "
        base_name = f"{self._dataset.__class__.__name__.replace('Dataset', '')} Dataset"
        title = f"Torchvision Wrapped {base_name}" if not base_name.startswith("Torchvision") else base_name
        sep = "-" * len(title)
        attrs = [
            f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
            for k, v in self.__dict__.items()
            if not k.startswith("_")
        ]
        wrapped = f"{title}\n{sep}{nt}{nt.join(attrs)}"
        return f"{wrapped}\n\n{self._dataset}"

    def __len__(self) -> int:
        """Number of datums in the wrapped dataset."""
        return len(self._dataset)