maite-datasets 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +2 -6
- maite_datasets/_base.py +169 -51
- maite_datasets/_builder.py +46 -55
- maite_datasets/_collate.py +2 -3
- maite_datasets/{_reader/_base.py → _reader.py} +62 -36
- maite_datasets/_validate.py +4 -2
- maite_datasets/image_classification/_cifar10.py +12 -7
- maite_datasets/image_classification/_mnist.py +15 -10
- maite_datasets/image_classification/_ships.py +12 -8
- maite_datasets/object_detection/__init__.py +4 -7
- maite_datasets/object_detection/_antiuav.py +11 -8
- maite_datasets/{_reader → object_detection}/_coco.py +29 -27
- maite_datasets/object_detection/_milco.py +11 -9
- maite_datasets/object_detection/_seadrone.py +11 -9
- maite_datasets/object_detection/_voc.py +11 -13
- maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
- maite_datasets/protocols.py +23 -0
- maite_datasets/wrappers/__init__.py +8 -0
- maite_datasets/wrappers/_torch.py +111 -0
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/METADATA +56 -3
- maite_datasets-0.0.6.dist-info/RECORD +26 -0
- maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets/_mixin/_numpy.py +0 -28
- maite_datasets/_mixin/_torch.py +0 -28
- maite_datasets/_protocols.py +0 -217
- maite_datasets/_reader/__init__.py +0 -6
- maite_datasets/_reader/_factory.py +0 -64
- maite_datasets/_types.py +0 -50
- maite_datasets/object_detection/_voc_torch.py +0 -65
- maite_datasets-0.0.5.dist-info/RECORD +0 -31
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/WHEEL +0 -0
- {maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/licenses/LICENSE +0 -0
maite_datasets/wrappers/_torch.py ADDED
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+from typing import Any, Callable, Generic, TypeAlias, TypeVar, cast, overload
+
+import numpy as np
+import torch
+from maite.protocols import DatasetMetadata, DatumMetadata
+from maite.protocols.object_detection import ObjectDetectionTarget as _ObjectDetectionTarget
+from torch import Tensor
+from torchvision.tv_tensors import BoundingBoxes, Image
+
+from maite_datasets._base import BaseDataset, ObjectDetectionTarget
+from maite_datasets.protocols import Array
+
+TArray = TypeVar("TArray", bound=Array)
+TTarget = TypeVar("TTarget")
+
+TorchvisionImageClassificationDatum: TypeAlias = tuple[Image, Tensor, DatumMetadata]
+TorchvisionObjectDetectionDatum: TypeAlias = tuple[Image, ObjectDetectionTarget, DatumMetadata]
+
+
+class TorchvisionWrapper(Generic[TArray, TTarget]):
+    """
+    Lightweight wrapper converting numpy-based datasets to Torchvision tensors.
+
+    Converts images to tv_tensor.Image and targets to the appropriate torchvision format.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        Source dataset with numpy arrays
+    transforms : callable, optional
+        Torchvision v2 transform functions for targets
+    """
+
+    def __init__(
+        self,
+        dataset: BaseDataset[TArray, TTarget],
+        transforms: Callable[[Any], Any] | None = None,
+    ) -> None:
+        self._dataset = dataset
+        self.transforms = transforms
+        self.metadata: DatasetMetadata = {
+            "id": f"TorchvisionWrapper({dataset.metadata['id']})",
+            "index2label": dataset.metadata.get("index2label", {}),
+        }
+
+    def __getattr__(self, name: str) -> Any:
+        """Forward unknown attributes to wrapped dataset."""
+        return getattr(self._dataset, name)
+
+    def __dir__(self) -> list[str]:
+        """Include wrapped dataset attributes in dir() for IDE support."""
+        wrapper_attrs = set(super().__dir__())
+        dataset_attrs = set(dir(self._dataset))
+        return sorted(wrapper_attrs | dataset_attrs)
+
+    def _transform(self, datum: Any) -> Any:
+        return self.transforms(datum) if self.transforms else datum
+
+    @overload
+    def __getitem__(self: TorchvisionWrapper[TArray, TArray], index: int) -> tuple[Image, Tensor, DatumMetadata]: ...
+    @overload
+    def __getitem__(
+        self: TorchvisionWrapper[TArray, TTarget], index: int
+    ) -> tuple[Image, ObjectDetectionTarget, DatumMetadata]: ...
+
+    def __getitem__(self, index: int) -> tuple[Image, Tensor | ObjectDetectionTarget, DatumMetadata]:
+        """Get item with torch tensor conversion."""
+        image, target, metadata = self._dataset[index]
+
+        # Convert image to torch tensor
+        torch_image = torch.from_numpy(image) if isinstance(image, np.ndarray) else torch.as_tensor(image)
+        torch_image = Image(torch_image)
+
+        # Handle different target types
+        if isinstance(target, Array):
+            # Image classification case
+            torch_target = torch.as_tensor(target, dtype=torch.float32)
+            torch_datum = self._transform((torch_image, torch_target, metadata))
+            return cast(TorchvisionImageClassificationDatum, torch_datum)
+
+        if isinstance(target, _ObjectDetectionTarget):
+            # Object detection case
+            torch_boxes = BoundingBoxes(
+                torch.as_tensor(target.boxes), format="XYXY", canvas_size=(torch_image.shape[-2], torch_image.shape[-1])
+            )  # type: ignore
+            torch_labels = torch.as_tensor(target.labels, dtype=torch.int64)
+            torch_scores = torch.as_tensor(target.scores, dtype=torch.float32)
+            torch_target = ObjectDetectionTarget(torch_boxes, torch_labels, torch_scores)
+            torch_datum = self._transform((torch_image, torch_target, metadata))
+            return cast(TorchvisionObjectDetectionDatum, torch_datum)
+
+        raise TypeError(f"Unsupported target type: {type(target)}")
+
+    def __str__(self) -> str:
+        """String representation showing torch version."""
+        nt = "\n    "
+        base_name = f"{self._dataset.__class__.__name__.replace('Dataset', '')} Dataset"
+        title = f"Torchvision Wrapped {base_name}" if not base_name.startswith("Torchvision") else base_name
+        sep = "-" * len(title)
+        attrs = [
+            f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
+            for k, v in self.__dict__.items()
+            if not k.startswith("_")
+        ]
+        wrapped = f"{title}\n{sep}{nt}{nt.join(attrs)}"
+        return f"{wrapped}\n\n{self._dataset}"
+
+    def __len__(self) -> int:
+        return self._dataset.__len__()
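For orientation (not part of the diff itself): a minimal sketch of how the new wrapper could be used with an image-classification dataset, complementing the object-detection example in the README further below. The `MNIST` import path and constructor arguments are assumptions inferred from the file list above and the `MILCO(root="data", download=True)` call shown in the METADATA diff.

```python
# Minimal usage sketch for TorchvisionWrapper. Assumptions: MNIST lives in
# maite_datasets.image_classification and accepts root/download like MILCO.
from maite_datasets.image_classification import MNIST
from maite_datasets.wrappers import TorchvisionWrapper
from torchvision.transforms.v2 import RandomHorizontalFlip

mnist = MNIST(root="data", download=True)
wrapped = TorchvisionWrapper(mnist, transforms=RandomHorizontalFlip(p=0.5))

# __getitem__ converts the numpy image to tv_tensors.Image and the one-hot
# label to a float32 torch.Tensor before applying the v2 transform.
image, label, metadata = wrapped[0]
print(type(image).__name__, tuple(image.shape), label.dtype)
```

Because torchvision v2 transforms operate on arbitrary nested structures, the `(image, target, metadata)` tuple passes through `self.transforms` intact, with only the tensor leaves transformed.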
{maite_datasets-0.0.5.dist-info → maite_datasets-0.0.6.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: maite-datasets
-Version: 0.0.5
+Version: 0.0.6
 Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
 Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
 License-Expression: MIT
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Requires-Python: >=3.9
 Requires-Dist: defusedxml>=0.7.1
+Requires-Dist: maite<0.9,>=0.7
 Requires-Dist: numpy>=1.24.2
 Requires-Dist: pillow>=10.3.0
 Requires-Dist: requests>=2.32.3
@@ -42,7 +43,7 @@ For status bar indicators when downloading, you can include the extra `tqdm` whe
 pip install maite-datasets[tqdm]
 ```
 
-## Available Datasets
+## Available Downloadable Datasets
 
 | Task | Dataset | Description |
 |----------------|------------------|---------------------------------------------------------------------|
@@ -54,7 +55,7 @@ pip install maite-datasets[tqdm]
 | Detection | Seadrone | A UAV dataset focused on open water object detection. |
 | Detection | VOCDetection | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset. |
 
-
+### Usage
 
 Here is an example of how to import MNIST for usage with your workflow.
 
@@ -76,6 +77,58 @@ MNIST Dataset
 tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
 ```
 
+## Dataset Wrappers
+
+Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
+
+`TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
+`torchvision` transforms to the dataset.
+
+**NOTE:** `TorchvisionWrapper` requires _torch_ and _torchvision_ to be installed.
+
+```python
+>>> from maite_datasets.object_detection import MILCO
+
+>>> milco = MILCO(root="data", download=True)
+>>> print(milco)
+MILCO Dataset
+-------------
+    Transforms: []
+    Image Set: train
+    Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+    Path: /home/user/maite-datasets/data/milco
+    Size: 261
+
+>>> print(f"type={milco[0][0].__class__.__name__}, shape={milco[0][0].shape}")
+type=ndarray, shape=(3, 1024, 1024)
+
+>>> print(milco[0][1].boxes[0])
+[ 75. 217. 130. 247.]
+
+>>> from maite_datasets.wrappers import TorchvisionWrapper
+>>> from torchvision.transforms.v2 import Resize
+
+>>> milco_torch = TorchvisionWrapper(milco, transforms=Resize(224))
+>>> print(milco_torch)
+Torchvision Wrapped MILCO Dataset
+---------------------------------
+    Transforms: Resize(size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
+
+MILCO Dataset
+-------------
+    Transforms: []
+    Image Set: train
+    Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+    Path: /home/user/maite-datasets/data/milco
+    Size: 261
+
+>>> print(f"type={milco_torch[0][0].__class__.__name__}, shape={milco_torch[0][0].shape}")
+type=Image, shape=torch.Size([3, 224, 224])
+
+>>> print(milco_torch[0][1].boxes[0])
+tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
+```
+
 ## Additional Information
 
 For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
maite_datasets-0.0.6.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+maite_datasets/__init__.py,sha256=Z_HyAe08HaHMjzZS2afFumBXYFRFj0ny5ZAIp0hcj4w,569
+maite_datasets/_base.py,sha256=VEd4ipHPAOCbz4Zm8zdI2yQwQ_x9O4Wq01xoZ2QvNYo,12366
+maite_datasets/_builder.py,sha256=MnCh6z5hSINlzBnK_pdbgI5zSg5d1uq4UvXt3cjn9hs,9820
+maite_datasets/_collate.py,sha256=pwUnmrbJH5olFjSwF-ZkGdfopTWUUlwmq0d5KzERcy8,4052
+maite_datasets/_fileio.py,sha256=7S-hF3xU60AdcsPsfYR7rjbeGZUlv3JjGEZhGJOxGYU,5622
+maite_datasets/_reader.py,sha256=tJqsjfXaK-mrs0Ed4BktombFMmNwCur35W7tuYCflKM,5569
+maite_datasets/_validate.py,sha256=Uokbolmv1uSv98sph44HON0HEieeK3s2mqbPMP1d5xs,6948
+maite_datasets/protocols.py,sha256=YGXb-WxlneXdIQBfBy5OdbylHSVfM-RBXeGvpiWwfLU,607
+maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+maite_datasets/image_classification/__init__.py,sha256=pcZojkdsiMoLgY4mKjoQY6WyEwiGYHxNrAGpnvn3zsY,308
+maite_datasets/image_classification/_cifar10.py,sha256=rrkJZ70NBYOSGxligXyakVxNOyGAglN6PaGQazuWNO4,8453
+maite_datasets/image_classification/_mnist.py,sha256=9isgdi-YXgs6nXoh1j8uOgh4_sIhBIky72Vyl866rTE,8192
+maite_datasets/image_classification/_ships.py,sha256=nWhte8592lpybhQCCdgT36LnuMQ0PRJWlDxT5-IPUtk,5137
+maite_datasets/object_detection/__init__.py,sha256=171KT_X6I4YGy18G240N_-ZsKvXJ6YqcBDzkhTiBj2E,587
+maite_datasets/object_detection/_antiuav.py,sha256=B20JrbouDM1o5f1ct9Zfbkks8NaVqYrxu5x-rBZvGx8,8265
+maite_datasets/object_detection/_coco.py,sha256=3abRQJ9ATcZOeqK-4pnMfr-pv7aGcRum88SRlLLXTzk,10309
+maite_datasets/object_detection/_milco.py,sha256=brxxYs5ak0vEpOSd2IW5AMMVkuadVmXCJBFPvXTmNlo,7928
+maite_datasets/object_detection/_seadrone.py,sha256=JdHL0eRZoe7pXVInOq5Xpnz3-vgeBxbO25oTYgGZ44o,271213
+maite_datasets/object_detection/_voc.py,sha256=vgRn-sa_r2-hxwpM3veRZQMcWyqJz9OGalABOccZeow,19589
+maite_datasets/object_detection/_yolo.py,sha256=Luojzhanh6AK949910jN0yTpy8zwF5_At6nThj3Zw9Q,11867
+maite_datasets/wrappers/__init__.py,sha256=6uI0ztOB2IlMWln9JkVke4OhU2HQ8i6YCaCNq_q5qb0,225
+maite_datasets/wrappers/_torch.py,sha256=dmY6nSyLyVPOzpOE4BDTyOomWdFpN0x5dmH3XUzNetc,4588
+maite_datasets-0.0.6.dist-info/METADATA,sha256=cmDnRwPTu1xWbFIpnNQ0jmhEI7XTu6CQ3cm18y0dMDk,5505
+maite_datasets-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+maite_datasets-0.0.6.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
+maite_datasets-0.0.6.dist-info/RECORD,,
maite_datasets/_mixin/__init__.py DELETED
File without changes
maite_datasets/_mixin/_numpy.py DELETED
@@ -1,28 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from typing import Any
-
-import numpy as np
-from numpy.typing import NDArray
-from PIL import Image
-
-from maite_datasets._base import BaseDatasetMixin
-
-
-class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[np.number[Any]]]):
-    def _as_array(self, raw: list[Any]) -> NDArray[np.number[Any]]:
-        return np.asarray(raw)
-
-    def _one_hot_encode(self, value: int | list[int]) -> NDArray[np.number[Any]]:
-        if isinstance(value, int):
-            encoded = np.zeros(len(self.index2label))
-            encoded[value] = 1
-        else:
-            encoded = np.zeros((len(value), len(self.index2label)))
-            encoded[np.arange(len(value)), value] = 1
-        return encoded
-
-    def _read_file(self, path: str) -> NDArray[np.number[Any]]:
-        return np.array(Image.open(path)).transpose(2, 0, 1)
maite_datasets/_mixin/_torch.py DELETED
@@ -1,28 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from typing import Any
-
-import numpy as np
-import torch
-from PIL import Image
-
-from maite_datasets._base import BaseDatasetMixin
-
-
-class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
-    def _as_array(self, raw: list[Any]) -> torch.Tensor:
-        return torch.as_tensor(raw)
-
-    def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
-        if isinstance(value, int):
-            encoded = torch.zeros(len(self.index2label))
-            encoded[value] = 1
-        else:
-            encoded = torch.zeros((len(value), len(self.index2label)))
-            encoded[torch.arange(len(value)), value] = 1
-        return encoded
-
-    def _read_file(self, path: str) -> torch.Tensor:
-        return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
maite_datasets/_protocols.py DELETED
@@ -1,217 +0,0 @@
-"""
-Common type protocols used for interoperability.
-"""
-
-from collections.abc import Iterator
-import sys
-from typing import (
-    Any,
-    Generic,
-    Protocol,
-    TypeAlias,
-    TypedDict,
-    TypeVar,
-    runtime_checkable,
-)
-
-import numpy.typing
-from typing_extensions import NotRequired, ReadOnly, Required
-
-if sys.version_info >= (3, 10):
-    from typing import TypeAlias
-else:
-    from typing_extensions import TypeAlias
-
-
-ArrayLike: TypeAlias = numpy.typing.ArrayLike
-"""
-Type alias for a `Union` representing objects that can be coerced into an array.
-
-See Also
---------
-`NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
-"""
-
-
-@runtime_checkable
-class Array(Protocol):
-    """
-    Protocol for interoperable array objects.
-
-    Supports common array representations with popular libraries like
-    PyTorch, Tensorflow and JAX, as well as NumPy arrays.
-    """
-
-    @property
-    def shape(self) -> tuple[int, ...]: ...
-    def __array__(self) -> Any: ...
-    def __getitem__(self, key: Any, /) -> Any: ...
-    def __iter__(self) -> Iterator[Any]: ...
-    def __len__(self) -> int: ...
-
-
-_T = TypeVar("_T")
-_T_co = TypeVar("_T_co", covariant=True)
-_T_cn = TypeVar("_T_cn", contravariant=True)
-
-
-class DatasetMetadata(TypedDict, total=False):
-    """
-    Dataset level metadata required for all `AnnotatedDataset` classes.
-
-    Attributes
-    ----------
-    id : Required[str]
-        A unique identifier for the dataset
-    index2label : NotRequired[dict[int, str]]
-        A lookup table converting label value to class name
-    """
-
-    id: Required[ReadOnly[str]]
-    index2label: NotRequired[ReadOnly[dict[int, str]]]
-
-
-class DatumMetadata(TypedDict, total=False):
-    """
-    Datum level metadata required for all `AnnotatedDataset` classes.
-
-    Attributes
-    ----------
-    id : Required[str]
-        A unique identifier for the datum
-    """
-
-    id: Required[ReadOnly[str]]
-
-
-@runtime_checkable
-class Dataset(Generic[_T_co], Protocol):
-    """
-    Protocol for a generic `Dataset`.
-
-    Methods
-    -------
-    __getitem__(index: int)
-        Returns datum at specified index.
-    __len__()
-        Returns dataset length.
-    """
-
-    def __getitem__(self, index: int, /) -> _T_co: ...
-    def __len__(self) -> int: ...
-
-
-@runtime_checkable
-class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
-    """
-    Protocol for a generic `AnnotatedDataset`.
-
-    Attributes
-    ----------
-    metadata : :class:`.DatasetMetadata` or derivatives.
-
-    Methods
-    -------
-    __getitem__(index: int)
-        Returns datum at specified index.
-    __len__()
-        Returns dataset length.
-
-    Notes
-    -----
-    Inherits from :class:`.Dataset`.
-    """
-
-    @property
-    def metadata(self) -> DatasetMetadata: ...
-
-
-# ========== IMAGE CLASSIFICATION DATASETS ==========
-
-
-ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, DatumMetadata]
-"""
-Type alias for an image classification datum tuple.
-
-- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
-- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
-- dict[str, Any] - Datum level metadata.
-"""
-
-
-ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
-"""
-Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
-"""
-
-# ========== OBJECT DETECTION DATASETS ==========
-
-
-@runtime_checkable
-class ObjectDetectionTarget(Protocol):
-    """
-    Protocol for targets in an Object Detection dataset.
-
-    Attributes
-    ----------
-    boxes : :class:`ArrayLike` of shape (N, 4)
-    labels : :class:`ArrayLike` of shape (N,)
-    scores : :class:`ArrayLike` of shape (N, M)
-    """
-
-    @property
-    def boxes(self) -> ArrayLike: ...
-
-    @property
-    def labels(self) -> ArrayLike: ...
-
-    @property
-    def scores(self) -> ArrayLike: ...
-
-
-ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, DatumMetadata]
-"""
-Type alias for an object detection datum tuple.
-
-- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
-- :class:`ObjectDetectionTarget` - Object detection target information for the image.
-- dict[str, Any] - Datum level metadata.
-"""
-
-
-ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
-"""
-Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
-"""
-
-
-# ========== TRANSFORM ==========
-
-
-@runtime_checkable
-class Transform(Generic[_T], Protocol):
-    """
-    Protocol defining a transform function.
-
-    Requires a `__call__` method that returns transformed data.
-
-    Example
-    -------
-    >>> from typing import Any
-    >>> from numpy.typing import NDArray
-
-    >>> class MyTransform:
-    ...     def __init__(self, divisor: float) -> None:
-    ...         self.divisor = divisor
-    ...
-    ...     def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
-    ...         return data / self.divisor
-
-    >>> my_transform = MyTransform(divisor=255.0)
-    >>> isinstance(my_transform, Transform)
-    True
-    >>> my_transform(np.array([1, 2, 3]))
-    array([0.004, 0.008, 0.012])
-    """
-
-    def __call__(self, data: _T, /) -> _T: ...
maite_datasets/_reader/_factory.py DELETED
@@ -1,64 +0,0 @@
-from __future__ import annotations
-
-import logging
-from pathlib import Path
-
-from maite_datasets._reader._base import BaseDatasetReader
-from maite_datasets._reader._yolo import YOLODatasetReader
-from maite_datasets._reader._coco import COCODatasetReader
-
-_logger = logging.getLogger(__name__)
-
-
-def create_dataset_reader(dataset_path: str | Path, format_hint: str | None = None) -> BaseDatasetReader:
-    """
-    Factory function to create appropriate dataset reader based on directory structure.
-
-    Parameters
-    ----------
-    dataset_path : str or Path
-        Root directory containing dataset files
-    format_hint : str or None, default None
-        Format hint ("coco" or "yolo"). If None, auto-detects based on file structure
-
-    Returns
-    -------
-    BaseDatasetReader
-        Appropriate reader instance for the detected format
-
-    Raises
-    ------
-    ValueError
-        If format cannot be determined or is unsupported
-    """
-    dataset_path = Path(dataset_path)
-
-    if format_hint:
-        format_hint = format_hint.lower()
-        if format_hint == "coco":
-            return COCODatasetReader(dataset_path)
-        elif format_hint == "yolo":
-            return YOLODatasetReader(dataset_path)
-        else:
-            raise ValueError(f"Unsupported format hint: {format_hint}")
-
-    # Auto-detect format
-    has_annotations_json = (dataset_path / "annotations.json").exists()
-    has_labels_dir = (dataset_path / "labels").exists()
-
-    if has_annotations_json and not has_labels_dir:
-        _logger.info(f"Detected COCO format for {dataset_path}")
-        return COCODatasetReader(dataset_path)
-    elif has_labels_dir and not has_annotations_json:
-        _logger.info(f"Detected YOLO format for {dataset_path}")
-        return YOLODatasetReader(dataset_path)
-    elif has_annotations_json and has_labels_dir:
-        raise ValueError(
-            f"Ambiguous format in {dataset_path}: both annotations.json and labels/ exist. "
-            "Use format_hint parameter to specify format."
-        )
-    else:
-        raise ValueError(
-            f"Cannot detect dataset format in {dataset_path}. "
-            "Expected either annotations.json (COCO) or labels/ directory (YOLO)."
-        )
maite_datasets/_types.py DELETED
@@ -1,50 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from dataclasses import dataclass
-from typing import Generic, TypedDict, TypeVar
-
-from typing_extensions import NotRequired, Required
-
-_T_co = TypeVar("_T_co", covariant=True)
-
-
-class Dataset(Generic[_T_co]):
-    """Abstract generic base class for PyTorch style Dataset"""
-
-    def __getitem__(self, index: int) -> _T_co: ...
-    def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]: ...
-
-
-class DatasetMetadata(TypedDict):
-    id: Required[str]
-    index2label: NotRequired[dict[int, str]]
-    split: NotRequired[str]
-
-
-class DatumMetadata(TypedDict, total=False):
-    id: Required[str]
-
-
-_TDatum = TypeVar("_TDatum")
-_TArray = TypeVar("_TArray")
-
-
-class AnnotatedDataset(Dataset[_TDatum]):
-    metadata: DatasetMetadata
-
-    def __len__(self) -> int: ...
-
-
-class ImageClassificationDataset(AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]): ...
-
-
-@dataclass
-class ObjectDetectionTarget(Generic[_TArray]):
-    boxes: _TArray
-    labels: _TArray
-    scores: _TArray
-
-
-class ObjectDetectionDataset(AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]): ...