maite-datasets 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +2 -6
- maite_datasets/_base.py +169 -51
- maite_datasets/_builder.py +46 -55
- maite_datasets/_collate.py +2 -3
- maite_datasets/{_reader/_base.py → _reader.py} +62 -36
- maite_datasets/_validate.py +4 -2
- maite_datasets/adapters/__init__.py +3 -0
- maite_datasets/adapters/_huggingface.py +391 -0
- maite_datasets/image_classification/_cifar10.py +12 -7
- maite_datasets/image_classification/_mnist.py +15 -10
- maite_datasets/image_classification/_ships.py +12 -8
- maite_datasets/object_detection/__init__.py +4 -7
- maite_datasets/object_detection/_antiuav.py +11 -8
- maite_datasets/{_reader → object_detection}/_coco.py +29 -27
- maite_datasets/object_detection/_milco.py +11 -9
- maite_datasets/object_detection/_seadrone.py +11 -9
- maite_datasets/object_detection/_voc.py +11 -13
- maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
- maite_datasets/protocols.py +94 -0
- maite_datasets/wrappers/__init__.py +8 -0
- maite_datasets/wrappers/_torch.py +109 -0
- maite_datasets-0.0.7.dist-info/METADATA +181 -0
- maite_datasets-0.0.7.dist-info/RECORD +28 -0
- maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets/_mixin/_numpy.py +0 -28
- maite_datasets/_mixin/_torch.py +0 -28
- maite_datasets/_protocols.py +0 -217
- maite_datasets/_reader/__init__.py +0 -6
- maite_datasets/_reader/_factory.py +0 -64
- maite_datasets/_types.py +0 -50
- maite_datasets/object_detection/_voc_torch.py +0 -65
- maite_datasets-0.0.5.dist-info/METADATA +0 -91
- maite_datasets-0.0.5.dist-info/RECORD +0 -31
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/WHEEL +0 -0
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.7.dist-info}/licenses/LICENSE +0 -0
@@ -3,20 +3,19 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import json
|
6
|
-
import logging
|
7
6
|
from pathlib import Path
|
8
7
|
from typing import Any
|
9
8
|
|
9
|
+
import maite.protocols.object_detection as od
|
10
10
|
import numpy as np
|
11
|
+
from maite.protocols import DatasetMetadata, DatumMetadata
|
11
12
|
from PIL import Image
|
12
13
|
|
13
|
-
from maite_datasets.
|
14
|
-
from maite_datasets._reader
|
14
|
+
from maite_datasets._base import BaseDataset, ObjectDetectionTarget
|
15
|
+
from maite_datasets._reader import BaseDatasetReader
|
15
16
|
|
16
|
-
_logger = logging.getLogger(__name__)
|
17
17
|
|
18
|
-
|
19
|
-
class COCODatasetReader(BaseDatasetReader):
|
18
|
+
class COCODatasetReader(BaseDatasetReader[od.Dataset]):
|
20
19
|
"""
|
21
20
|
COCO format dataset reader conforming to MAITE protocols.
|
22
21
|
|
@@ -132,9 +131,9 @@ class COCODatasetReader(BaseDatasetReader):
|
|
132
131
|
"""Mapping from class index to class name."""
|
133
132
|
return self._index2label
|
134
133
|
|
135
|
-
def
|
134
|
+
def create_dataset(self) -> od.Dataset:
|
136
135
|
"""Create COCO dataset implementation."""
|
137
|
-
return
|
136
|
+
return COCODataset(self)
|
138
137
|
|
139
138
|
def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
|
140
139
|
"""Validate COCO format specific files and structure."""
|
@@ -198,37 +197,40 @@ class COCODatasetReader(BaseDatasetReader):
|
|
198
197
|
else:
|
199
198
|
class_names = [cat["name"] for cat in self._coco_data["categories"]]
|
200
199
|
|
201
|
-
self._index2label =
|
200
|
+
self._index2label = dict(enumerate(class_names))
|
202
201
|
|
203
202
|
|
204
|
-
class
|
203
|
+
class COCODataset(BaseDataset):
|
205
204
|
"""Internal COCO dataset implementation."""
|
206
205
|
|
207
206
|
def __init__(self, reader: COCODatasetReader) -> None:
|
208
|
-
self.
|
209
|
-
self.
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
207
|
+
self._reader = reader
|
208
|
+
self._image_ids = list(reader._image_id_to_info.keys())
|
209
|
+
|
210
|
+
self.root = reader.dataset_path
|
211
|
+
self.images_path = reader._images_path
|
212
|
+
self.annotation_path = reader._annotation_path
|
213
|
+
self.size = len(reader._image_id_to_info)
|
214
|
+
self.classes = reader.index2label
|
215
|
+
self.metadata = DatasetMetadata(
|
216
|
+
id=self._reader.dataset_id,
|
217
|
+
index2label=self._reader.index2label,
|
216
218
|
)
|
217
219
|
|
218
220
|
def __len__(self) -> int:
|
219
|
-
return len(self.
|
221
|
+
return len(self._image_ids)
|
220
222
|
|
221
|
-
def __getitem__(self, index: int) ->
|
222
|
-
image_id = self.
|
223
|
-
image_info = self.
|
223
|
+
def __getitem__(self, index: int) -> tuple[od.InputType, od.ObjectDetectionTarget, DatumMetadata]:
|
224
|
+
image_id = self._image_ids[index]
|
225
|
+
image_info = self._reader._image_id_to_info[image_id]
|
224
226
|
|
225
227
|
# Load image
|
226
|
-
image_path = self.
|
228
|
+
image_path = self._reader._images_path / image_info["file_name"]
|
227
229
|
image = np.array(Image.open(image_path).convert("RGB"))
|
228
230
|
image = np.transpose(image, (2, 0, 1)) # Convert to CHW format
|
229
231
|
|
230
232
|
# Get annotations for this image
|
231
|
-
annotations = self.
|
233
|
+
annotations = self._reader.image_id_to_annotations.get(image_id, [])
|
232
234
|
|
233
235
|
if annotations:
|
234
236
|
boxes = []
|
@@ -241,7 +243,7 @@ class _COCODataset:
|
|
241
243
|
boxes.append([x, y, x + w, y + h])
|
242
244
|
|
243
245
|
# Map category_id to class index
|
244
|
-
cat_idx = self.
|
246
|
+
cat_idx = self._reader._category_id_to_idx[ann["category_id"]]
|
245
247
|
labels.append(cat_idx)
|
246
248
|
|
247
249
|
# Collect annotation metadata
|
@@ -267,12 +269,12 @@ class _COCODataset:
|
|
267
269
|
scores = np.empty(0, dtype=np.float32)
|
268
270
|
annotation_metadata = []
|
269
271
|
|
270
|
-
target =
|
272
|
+
target = ObjectDetectionTarget(boxes, labels, scores)
|
271
273
|
|
272
274
|
# Create comprehensive datum metadata
|
273
275
|
datum_metadata = DatumMetadata(
|
274
276
|
**{
|
275
|
-
"id": f"{self.
|
277
|
+
"id": f"{self._reader.dataset_id}_{image_id}",
|
276
278
|
# Image-level metadata
|
277
279
|
"coco_image_id": image_id,
|
278
280
|
"file_name": image_info["file_name"],
|
@@ -2,18 +2,20 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
from collections.abc import Sequence
|
5
6
|
from pathlib import Path
|
6
|
-
from typing import Any, Literal
|
7
|
+
from typing import Any, Literal
|
7
8
|
|
8
|
-
|
9
|
-
|
9
|
+
from maite_datasets._base import (
|
10
|
+
BaseDatasetNumpyMixin,
|
11
|
+
BaseODDataset,
|
12
|
+
DataLocation,
|
13
|
+
NumpyArray,
|
14
|
+
NumpyObjectDetectionTransform,
|
15
|
+
)
|
10
16
|
|
11
|
-
from maite_datasets._base import BaseODDataset, DataLocation
|
12
|
-
from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
13
|
-
from maite_datasets._protocols import Transform
|
14
17
|
|
15
|
-
|
16
|
-
class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
|
18
|
+
class MILCO(BaseODDataset[NumpyArray, list[str], str], BaseDatasetNumpyMixin):
|
17
19
|
"""
|
18
20
|
A side-scan sonar dataset focused on mine-like object detection.
|
19
21
|
|
@@ -116,7 +118,7 @@ class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetN
|
|
116
118
|
self,
|
117
119
|
root: str | Path,
|
118
120
|
image_set: Literal["train", "operational", "base"] = "train",
|
119
|
-
transforms:
|
121
|
+
transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
|
120
122
|
download: bool = False,
|
121
123
|
verbose: bool = False,
|
122
124
|
) -> None:
|
@@ -3,21 +3,23 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import json
|
6
|
+
from collections.abc import Sequence
|
6
7
|
from pathlib import Path
|
7
|
-
from typing import Any, Literal
|
8
|
+
from typing import Any, Literal
|
8
9
|
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
10
|
+
from maite_datasets._base import (
|
11
|
+
BaseDatasetNumpyMixin,
|
12
|
+
BaseODDataset,
|
13
|
+
DataLocation,
|
14
|
+
NumpyArray,
|
15
|
+
NumpyObjectDetectionTransform,
|
16
|
+
)
|
13
17
|
from maite_datasets._fileio import _ensure_exists
|
14
|
-
from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
15
|
-
from maite_datasets._protocols import Transform
|
16
18
|
|
17
19
|
|
18
20
|
class SeaDrone(
|
19
21
|
BaseODDataset[
|
20
|
-
|
22
|
+
NumpyArray,
|
21
23
|
list[tuple[list[int], list[list[float]]]],
|
22
24
|
tuple[list[int], list[list[float]]],
|
23
25
|
],
|
@@ -313,7 +315,7 @@ class SeaDrone(
|
|
313
315
|
self,
|
314
316
|
root: str | Path,
|
315
317
|
image_set: Literal["train", "val", "test", "base"] = "train",
|
316
|
-
transforms:
|
318
|
+
transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
|
317
319
|
download: bool = False,
|
318
320
|
verbose: bool = False,
|
319
321
|
) -> None:
|
@@ -4,24 +4,22 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import os
|
6
6
|
import shutil
|
7
|
+
from collections.abc import Sequence
|
7
8
|
from pathlib import Path
|
8
|
-
from typing import Any, Literal,
|
9
|
+
from typing import Any, Literal, TypeVar
|
9
10
|
|
10
|
-
import numpy as np
|
11
11
|
from defusedxml.ElementTree import parse
|
12
|
-
from numpy.typing import NDArray
|
13
12
|
|
14
13
|
from maite_datasets._base import (
|
15
|
-
|
14
|
+
BaseDatasetNumpyMixin,
|
15
|
+
BaseDownloadedDataset,
|
16
16
|
BaseODDataset,
|
17
17
|
DataLocation,
|
18
|
+
NumpyArray,
|
19
|
+
NumpyObjectDetectionTransform,
|
20
|
+
ObjectDetectionTarget,
|
18
21
|
_ensure_exists,
|
19
|
-
_TArray,
|
20
|
-
_TTarget,
|
21
22
|
)
|
22
|
-
from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
23
|
-
from maite_datasets._protocols import Transform
|
24
|
-
from maite_datasets._types import ObjectDetectionTarget
|
25
23
|
|
26
24
|
VOCClassStringMap = Literal[
|
27
25
|
"aeroplane",
|
@@ -48,7 +46,7 @@ VOCClassStringMap = Literal[
|
|
48
46
|
TVOCClassMap = TypeVar("TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int])
|
49
47
|
|
50
48
|
|
51
|
-
class BaseVOCDataset(
|
49
|
+
class BaseVOCDataset(BaseDownloadedDataset[NumpyArray, ObjectDetectionTarget, list[str], str]):
|
52
50
|
_resources = [
|
53
51
|
DataLocation(
|
54
52
|
url="https://data.brainchip.com/dataset-mirror/voc/VOCtrainval_11-May-2012.tar",
|
@@ -130,7 +128,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
130
128
|
root: str | Path,
|
131
129
|
image_set: Literal["train", "val", "test", "base"] = "train",
|
132
130
|
year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
|
133
|
-
transforms:
|
131
|
+
transforms: NumpyObjectDetectionTransform | Sequence[NumpyObjectDetectionTransform] | None = None,
|
134
132
|
download: bool = False,
|
135
133
|
verbose: bool = False,
|
136
134
|
) -> None:
|
@@ -432,8 +430,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
432
430
|
|
433
431
|
|
434
432
|
class VOCDetection(
|
435
|
-
BaseVOCDataset
|
436
|
-
BaseODDataset[
|
433
|
+
BaseVOCDataset,
|
434
|
+
BaseODDataset[NumpyArray, list[str], str],
|
437
435
|
BaseDatasetNumpyMixin,
|
438
436
|
):
|
439
437
|
"""
|
@@ -7,14 +7,16 @@ __all__ = []
|
|
7
7
|
from pathlib import Path
|
8
8
|
from typing import Any
|
9
9
|
|
10
|
+
import maite.protocols.object_detection as od
|
10
11
|
import numpy as np
|
12
|
+
from maite.protocols import DatasetMetadata, DatumMetadata
|
11
13
|
from PIL import Image
|
12
14
|
|
13
|
-
from maite_datasets.
|
14
|
-
from maite_datasets._reader
|
15
|
+
from maite_datasets._base import BaseDataset, ObjectDetectionTarget
|
16
|
+
from maite_datasets._reader import BaseDatasetReader
|
15
17
|
|
16
18
|
|
17
|
-
class YOLODatasetReader(BaseDatasetReader):
|
19
|
+
class YOLODatasetReader(BaseDatasetReader[od.Dataset]):
|
18
20
|
"""
|
19
21
|
YOLO format dataset reader conforming to MAITE protocols.
|
20
22
|
|
@@ -120,9 +122,9 @@ class YOLODatasetReader(BaseDatasetReader):
|
|
120
122
|
"""Mapping from class index to class name."""
|
121
123
|
return self._index2label
|
122
124
|
|
123
|
-
def
|
125
|
+
def create_dataset(self) -> od.Dataset:
|
124
126
|
"""Create YOLO dataset implementation."""
|
125
|
-
return
|
127
|
+
return YOLODataset(self)
|
126
128
|
|
127
129
|
def _validate_format_specific(self) -> tuple[list[str], dict[str, Any]]:
|
128
130
|
"""Validate YOLO format specific files and structure."""
|
@@ -200,7 +202,7 @@ class YOLODatasetReader(BaseDatasetReader):
|
|
200
202
|
"""Load class names from classes file."""
|
201
203
|
with open(self._classes_path) as f:
|
202
204
|
class_names = [line.strip() for line in f if line.strip()]
|
203
|
-
self._index2label =
|
205
|
+
self._index2label = dict(enumerate(class_names))
|
204
206
|
|
205
207
|
def _find_image_files(self) -> None:
|
206
208
|
"""Find all valid image files."""
|
@@ -213,32 +215,35 @@ class YOLODatasetReader(BaseDatasetReader):
|
|
213
215
|
raise ValueError(f"No image files found in {self._images_path}")
|
214
216
|
|
215
217
|
|
216
|
-
class
|
218
|
+
class YOLODataset(BaseDataset):
|
217
219
|
"""Internal YOLO dataset implementation."""
|
218
220
|
|
219
221
|
def __init__(self, reader: YOLODatasetReader) -> None:
|
220
|
-
self.
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
222
|
+
self._reader = reader
|
223
|
+
|
224
|
+
self.root = reader.dataset_path
|
225
|
+
self.images_path = reader._images_path
|
226
|
+
self.annotation_path = reader._labels_path
|
227
|
+
self.size = len(reader._image_files)
|
228
|
+
self.classes = reader.index2label
|
229
|
+
self.metadata = DatasetMetadata(
|
230
|
+
id=self._reader.dataset_id,
|
231
|
+
index2label=self._reader.index2label,
|
227
232
|
)
|
228
233
|
|
229
234
|
def __len__(self) -> int:
|
230
|
-
return len(self.
|
235
|
+
return len(self._reader._image_files)
|
231
236
|
|
232
|
-
def __getitem__(self, index: int) ->
|
233
|
-
image_path = self.
|
237
|
+
def __getitem__(self, index: int) -> tuple[od.InputType, od.ObjectDetectionTarget, DatumMetadata]:
|
238
|
+
image_path = self._reader._image_files[index]
|
234
239
|
|
235
240
|
# Load image
|
236
|
-
image = np.
|
241
|
+
image = np.asarray(Image.open(image_path).convert("RGB"), dtype=np.uint8)
|
237
242
|
img_height, img_width = image.shape[:2]
|
238
243
|
image = np.transpose(image, (2, 0, 1)) # Convert to CHW format
|
239
244
|
|
240
245
|
# Load corresponding label file
|
241
|
-
label_path = self.
|
246
|
+
label_path = self._reader._labels_path / f"{image_path.stem}.txt"
|
242
247
|
|
243
248
|
annotation_metadata = []
|
244
249
|
if label_path.exists():
|
@@ -292,12 +297,12 @@ class _YOLODataset:
|
|
292
297
|
labels = np.empty(0, dtype=np.int64)
|
293
298
|
scores = np.empty(0, dtype=np.float32)
|
294
299
|
|
295
|
-
target =
|
300
|
+
target = ObjectDetectionTarget(boxes, labels, scores)
|
296
301
|
|
297
302
|
# Create comprehensive datum metadata
|
298
303
|
datum_metadata = DatumMetadata(
|
299
304
|
**{
|
300
|
-
"id": f"{self.
|
305
|
+
"id": f"{self._reader.dataset_id}_{image_path.stem}",
|
301
306
|
# Image-level metadata
|
302
307
|
"file_name": image_path.name,
|
303
308
|
"file_path": str(image_path),
|
@@ -0,0 +1,94 @@
|
|
1
|
+
"""
|
2
|
+
Common type protocols used for interoperability.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
6
|
+
from typing import Any, Protocol, overload, runtime_checkable
|
7
|
+
|
8
|
+
|
9
|
+
@runtime_checkable
|
10
|
+
class Array(Protocol):
|
11
|
+
"""
|
12
|
+
Protocol for interoperable array objects.
|
13
|
+
|
14
|
+
Supports common array representations with popular libraries like
|
15
|
+
PyTorch, Tensorflow and JAX, as well as NumPy arrays.
|
16
|
+
"""
|
17
|
+
|
18
|
+
@property
|
19
|
+
def shape(self) -> tuple[int, ...]: ...
|
20
|
+
def __array__(self) -> Any: ...
|
21
|
+
def __getitem__(self, key: Any, /) -> Any: ...
|
22
|
+
def __iter__(self) -> Iterator[Any]: ...
|
23
|
+
def __len__(self) -> int: ...
|
24
|
+
|
25
|
+
|
26
|
+
@runtime_checkable
|
27
|
+
class HFDatasetInfo(Protocol):
|
28
|
+
@property
|
29
|
+
def dataset_name(self) -> str: ...
|
30
|
+
|
31
|
+
|
32
|
+
@runtime_checkable
|
33
|
+
class HFDataset(Protocol):
|
34
|
+
@property
|
35
|
+
def features(self) -> Mapping[str, Any]: ...
|
36
|
+
|
37
|
+
@property
|
38
|
+
def builder_name(self) -> str | None: ...
|
39
|
+
|
40
|
+
@property
|
41
|
+
def info(self) -> HFDatasetInfo: ...
|
42
|
+
|
43
|
+
@overload
|
44
|
+
def __getitem__(self, key: int | slice | Iterable[int]) -> dict[str, Any]: ...
|
45
|
+
@overload
|
46
|
+
def __getitem__(self, key: str) -> Sequence[int]: ...
|
47
|
+
def __getitem__(self, key: str | int | slice | Iterable[int]) -> dict[str, Any] | Sequence[int]: ...
|
48
|
+
|
49
|
+
def __len__(self) -> int: ...
|
50
|
+
|
51
|
+
|
52
|
+
@runtime_checkable
|
53
|
+
class HFFeature(Protocol):
|
54
|
+
@property
|
55
|
+
def _type(self) -> str: ...
|
56
|
+
|
57
|
+
|
58
|
+
@runtime_checkable
|
59
|
+
class HFClassLabel(HFFeature, Protocol):
|
60
|
+
@property
|
61
|
+
def names(self) -> list[str]: ...
|
62
|
+
|
63
|
+
@property
|
64
|
+
def num_classes(self) -> int: ...
|
65
|
+
|
66
|
+
|
67
|
+
@runtime_checkable
|
68
|
+
class HFImage(HFFeature, Protocol):
|
69
|
+
@property
|
70
|
+
def decode(self) -> bool: ...
|
71
|
+
|
72
|
+
|
73
|
+
@runtime_checkable
|
74
|
+
class HFArray(HFFeature, Protocol):
|
75
|
+
@property
|
76
|
+
def shape(self) -> tuple[int, ...]: ...
|
77
|
+
@property
|
78
|
+
def dtype(self) -> str: ...
|
79
|
+
|
80
|
+
|
81
|
+
@runtime_checkable
|
82
|
+
class HFList(HFFeature, Protocol):
|
83
|
+
@property
|
84
|
+
def feature(self) -> Any: ...
|
85
|
+
@property
|
86
|
+
def length(self) -> int: ...
|
87
|
+
|
88
|
+
|
89
|
+
@runtime_checkable
|
90
|
+
class HFValue(HFFeature, Protocol):
|
91
|
+
@property
|
92
|
+
def pa_type(self) -> Any: ... # pyarrow type ... not documented
|
93
|
+
@property
|
94
|
+
def dtype(self) -> str: ...
|
@@ -0,0 +1,109 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import Any, Callable, Generic, TypeAlias, TypeVar, cast, overload
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from maite.protocols import DatasetMetadata, DatumMetadata
|
7
|
+
from maite.protocols.object_detection import ObjectDetectionTarget as _ObjectDetectionTarget
|
8
|
+
from torch import Tensor
|
9
|
+
from torchvision.tv_tensors import BoundingBoxes, Image
|
10
|
+
|
11
|
+
from maite_datasets._base import BaseDataset, ObjectDetectionTarget
|
12
|
+
from maite_datasets.protocols import Array
|
13
|
+
|
14
|
+
TArray = TypeVar("TArray", bound=Array)
|
15
|
+
TTarget = TypeVar("TTarget")
|
16
|
+
|
17
|
+
TorchvisionImageClassificationDatum: TypeAlias = tuple[Image, Tensor, DatumMetadata]
|
18
|
+
TorchvisionObjectDetectionDatum: TypeAlias = tuple[Image, ObjectDetectionTarget, DatumMetadata]
|
19
|
+
|
20
|
+
|
21
|
+
class TorchvisionWrapper(Generic[TArray, TTarget]):
|
22
|
+
"""
|
23
|
+
Lightweight wrapper converting numpy-based datasets to Torchvision tensors.
|
24
|
+
|
25
|
+
Converts images to tv_tensor.Image and targets to the appropriate torchvision format.
|
26
|
+
|
27
|
+
Parameters
|
28
|
+
----------
|
29
|
+
dataset : Dataset
|
30
|
+
Source dataset with numpy arrays
|
31
|
+
transforms : callable, optional
|
32
|
+
Torchvision v2 transform functions for targets
|
33
|
+
"""
|
34
|
+
|
35
|
+
def __init__(
|
36
|
+
self,
|
37
|
+
dataset: BaseDataset[TArray, TTarget],
|
38
|
+
transforms: Callable[[Any], Any] | None = None,
|
39
|
+
) -> None:
|
40
|
+
self._dataset = dataset
|
41
|
+
self.transforms = transforms
|
42
|
+
self.metadata: DatasetMetadata = {
|
43
|
+
"id": f"TorchvisionWrapper({dataset.metadata['id']})",
|
44
|
+
"index2label": dataset.metadata.get("index2label", {}),
|
45
|
+
}
|
46
|
+
|
47
|
+
def __getattr__(self, name: str) -> Any:
|
48
|
+
"""Forward unknown attributes to wrapped dataset."""
|
49
|
+
return getattr(self._dataset, name)
|
50
|
+
|
51
|
+
def __dir__(self) -> list[str]:
|
52
|
+
"""Include wrapped dataset attributes in dir() for IDE support."""
|
53
|
+
wrapper_attrs = set(super().__dir__())
|
54
|
+
dataset_attrs = set(dir(self._dataset))
|
55
|
+
return sorted(wrapper_attrs | dataset_attrs)
|
56
|
+
|
57
|
+
def _transform(self, datum: Any) -> Any:
|
58
|
+
return self.transforms(datum) if self.transforms else datum
|
59
|
+
|
60
|
+
@overload
|
61
|
+
def __getitem__(self: TorchvisionWrapper[TArray, TArray], index: int) -> tuple[Image, Tensor, DatumMetadata]: ...
|
62
|
+
@overload
|
63
|
+
def __getitem__(
|
64
|
+
self: TorchvisionWrapper[TArray, TTarget], index: int
|
65
|
+
) -> tuple[Image, ObjectDetectionTarget, DatumMetadata]: ...
|
66
|
+
|
67
|
+
def __getitem__(self, index: int) -> tuple[Image, Tensor | ObjectDetectionTarget, DatumMetadata]:
|
68
|
+
"""Get item with torch tensor conversion."""
|
69
|
+
image, target, metadata = self._dataset[index]
|
70
|
+
|
71
|
+
# Convert image to torch tensor
|
72
|
+
torch_image = Image(torch.tensor(image))
|
73
|
+
|
74
|
+
# Handle different target types
|
75
|
+
if isinstance(target, Array):
|
76
|
+
# Image classification case
|
77
|
+
torch_target = torch.tensor(target, dtype=torch.float32)
|
78
|
+
torch_datum = self._transform((torch_image, torch_target, metadata))
|
79
|
+
return cast(TorchvisionImageClassificationDatum, torch_datum)
|
80
|
+
|
81
|
+
if isinstance(target, _ObjectDetectionTarget):
|
82
|
+
# Object detection case
|
83
|
+
torch_boxes = BoundingBoxes(
|
84
|
+
torch.tensor(target.boxes), format="XYXY", canvas_size=(torch_image.shape[-2], torch_image.shape[-1])
|
85
|
+
) # type: ignore
|
86
|
+
torch_labels = torch.tensor(target.labels, dtype=torch.int64)
|
87
|
+
torch_scores = torch.tensor(target.scores, dtype=torch.float32)
|
88
|
+
torch_target = ObjectDetectionTarget(torch_boxes, torch_labels, torch_scores)
|
89
|
+
torch_datum = self._transform((torch_image, torch_target, metadata))
|
90
|
+
return cast(TorchvisionObjectDetectionDatum, torch_datum)
|
91
|
+
|
92
|
+
raise TypeError(f"Unsupported target type: {type(target)}")
|
93
|
+
|
94
|
+
def __str__(self) -> str:
|
95
|
+
"""String representation showing torch version."""
|
96
|
+
nt = "\n "
|
97
|
+
base_name = f"{self._dataset.__class__.__name__.replace('Dataset', '')} Dataset"
|
98
|
+
title = f"Torchvision Wrapped {base_name}" if not base_name.startswith("Torchvision") else base_name
|
99
|
+
sep = "-" * len(title)
|
100
|
+
attrs = [
|
101
|
+
f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
|
102
|
+
for k, v in self.__dict__.items()
|
103
|
+
if not k.startswith("_")
|
104
|
+
]
|
105
|
+
wrapped = f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
106
|
+
return f"{wrapped}\n\n{self._dataset}"
|
107
|
+
|
108
|
+
def __len__(self) -> int:
|
109
|
+
return self._dataset.__len__()
|