maite-datasets 0.0.4a0__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. maite_datasets/__init__.py +2 -6
  2. maite_datasets/_base.py +169 -51
  3. maite_datasets/_builder.py +46 -55
  4. maite_datasets/_collate.py +2 -3
  5. maite_datasets/{_reader/_base.py → _reader.py} +62 -36
  6. maite_datasets/_validate.py +4 -2
  7. maite_datasets/image_classification/_cifar10.py +12 -7
  8. maite_datasets/image_classification/_mnist.py +15 -10
  9. maite_datasets/image_classification/_ships.py +12 -8
  10. maite_datasets/object_detection/__init__.py +4 -7
  11. maite_datasets/object_detection/_antiuav.py +11 -8
  12. maite_datasets/{_reader → object_detection}/_coco.py +29 -27
  13. maite_datasets/object_detection/_milco.py +11 -9
  14. maite_datasets/object_detection/_seadrone.py +11 -9
  15. maite_datasets/object_detection/_voc.py +11 -13
  16. maite_datasets/{_reader → object_detection}/_yolo.py +26 -21
  17. maite_datasets/protocols.py +23 -0
  18. maite_datasets/wrappers/__init__.py +8 -0
  19. maite_datasets/wrappers/_torch.py +111 -0
  20. {maite_datasets-0.0.4a0.dist-info → maite_datasets-0.0.6.dist-info}/METADATA +56 -3
  21. maite_datasets-0.0.6.dist-info/RECORD +26 -0
  22. maite_datasets/_mixin/__init__.py +0 -0
  23. maite_datasets/_mixin/_numpy.py +0 -28
  24. maite_datasets/_mixin/_torch.py +0 -28
  25. maite_datasets/_protocols.py +0 -217
  26. maite_datasets/_reader/__init__.py +0 -6
  27. maite_datasets/_reader/_factory.py +0 -64
  28. maite_datasets/_types.py +0 -50
  29. maite_datasets/object_detection/_voc_torch.py +0 -65
  30. maite_datasets-0.0.4a0.dist-info/RECORD +0 -31
  31. {maite_datasets-0.0.4a0.dist-info → maite_datasets-0.0.6.dist-info}/WHEEL +0 -0
  32. {maite_datasets-0.0.4a0.dist-info → maite_datasets-0.0.6.dist-info}/licenses/LICENSE +0 -0
maite_datasets/wrappers/_torch.py ADDED
@@ -0,0 +1,111 @@
+ from __future__ import annotations
+
+ from typing import Any, Callable, Generic, TypeAlias, TypeVar, cast, overload
+
+ import numpy as np
+ import torch
+ from maite.protocols import DatasetMetadata, DatumMetadata
+ from maite.protocols.object_detection import ObjectDetectionTarget as _ObjectDetectionTarget
+ from torch import Tensor
+ from torchvision.tv_tensors import BoundingBoxes, Image
+
+ from maite_datasets._base import BaseDataset, ObjectDetectionTarget
+ from maite_datasets.protocols import Array
+
+ TArray = TypeVar("TArray", bound=Array)
+ TTarget = TypeVar("TTarget")
+
+ TorchvisionImageClassificationDatum: TypeAlias = tuple[Image, Tensor, DatumMetadata]
+ TorchvisionObjectDetectionDatum: TypeAlias = tuple[Image, ObjectDetectionTarget, DatumMetadata]
+
+
+ class TorchvisionWrapper(Generic[TArray, TTarget]):
+     """
+     Lightweight wrapper converting numpy-based datasets to Torchvision tensors.
+
+     Converts images to tv_tensor.Image and targets to the appropriate torchvision format.
+
+     Parameters
+     ----------
+     dataset : Dataset
+         Source dataset with numpy arrays
+     transforms : callable, optional
+         Torchvision v2 transform functions for targets
+     """
+
+     def __init__(
+         self,
+         dataset: BaseDataset[TArray, TTarget],
+         transforms: Callable[[Any], Any] | None = None,
+     ) -> None:
+         self._dataset = dataset
+         self.transforms = transforms
+         self.metadata: DatasetMetadata = {
+             "id": f"TorchvisionWrapper({dataset.metadata['id']})",
+             "index2label": dataset.metadata.get("index2label", {}),
+         }
+
+     def __getattr__(self, name: str) -> Any:
+         """Forward unknown attributes to wrapped dataset."""
+         return getattr(self._dataset, name)
+
+     def __dir__(self) -> list[str]:
+         """Include wrapped dataset attributes in dir() for IDE support."""
+         wrapper_attrs = set(super().__dir__())
+         dataset_attrs = set(dir(self._dataset))
+         return sorted(wrapper_attrs | dataset_attrs)
+
+     def _transform(self, datum: Any) -> Any:
+         return self.transforms(datum) if self.transforms else datum
+
+     @overload
+     def __getitem__(self: TorchvisionWrapper[TArray, TArray], index: int) -> tuple[Image, Tensor, DatumMetadata]: ...
+     @overload
+     def __getitem__(
+         self: TorchvisionWrapper[TArray, TTarget], index: int
+     ) -> tuple[Image, ObjectDetectionTarget, DatumMetadata]: ...
+
+     def __getitem__(self, index: int) -> tuple[Image, Tensor | ObjectDetectionTarget, DatumMetadata]:
+         """Get item with torch tensor conversion."""
+         image, target, metadata = self._dataset[index]
+
+         # Convert image to torch tensor
+         torch_image = torch.from_numpy(image) if isinstance(image, np.ndarray) else torch.as_tensor(image)
+         torch_image = Image(torch_image)
+
+         # Handle different target types
+         if isinstance(target, Array):
+             # Image classification case
+             torch_target = torch.as_tensor(target, dtype=torch.float32)
+             torch_datum = self._transform((torch_image, torch_target, metadata))
+             return cast(TorchvisionImageClassificationDatum, torch_datum)
+
+         if isinstance(target, _ObjectDetectionTarget):
+             # Object detection case
+             torch_boxes = BoundingBoxes(
+                 torch.as_tensor(target.boxes), format="XYXY", canvas_size=(torch_image.shape[-2], torch_image.shape[-1])
+             )  # type: ignore
+             torch_labels = torch.as_tensor(target.labels, dtype=torch.int64)
+             torch_scores = torch.as_tensor(target.scores, dtype=torch.float32)
+             torch_target = ObjectDetectionTarget(torch_boxes, torch_labels, torch_scores)
+             torch_datum = self._transform((torch_image, torch_target, metadata))
+             return cast(TorchvisionObjectDetectionDatum, torch_datum)
+
+         raise TypeError(f"Unsupported target type: {type(target)}")
+
+     def __str__(self) -> str:
+         """String representation showing torch version."""
+         nt = "\n "
+         base_name = f"{self._dataset.__class__.__name__.replace('Dataset', '')} Dataset"
+         title = f"Torchvision Wrapped {base_name}" if not base_name.startswith("Torchvision") else base_name
+         sep = "-" * len(title)
+         attrs = [
+             f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
+             for k, v in self.__dict__.items()
+             if not k.startswith("_")
+         ]
+         wrapped = f"{title}\n{sep}{nt}{nt.join(attrs)}"
+         return f"{wrapped}\n\n{self._dataset}"
+
+     def __len__(self) -> int:
+         return self._dataset.__len__()
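Editor's note: the README example later in this diff exercises only the object detection path of this wrapper. For orientation, here is a minimal sketch of the classification path, not taken from the package docs. It assumes `CIFAR10` (from `maite_datasets/image_classification/_cifar10.py` in the file list) is exported with the same `root`/`download` constructor arguments the README's MILCO example uses; those names are assumptions, not confirmed API.

```python
# Sketch only: classification path of TorchvisionWrapper.
# Assumes CIFAR10 is exported from maite_datasets.image_classification and
# accepts root/download like the MILCO example in the README; adjust as needed.
from maite_datasets.image_classification import CIFAR10
from maite_datasets.wrappers import TorchvisionWrapper
from torchvision.transforms.v2 import RandomHorizontalFlip

cifar = CIFAR10(root="data", download=True)
wrapped = TorchvisionWrapper(cifar, transforms=RandomHorizontalFlip(p=1.0))

# Per __getitem__ above: the image becomes a tv_tensors.Image, the one-hot
# label a float32 Tensor, and the transform is applied to the whole datum.
image, label, metadata = wrapped[0]
print(type(image).__name__, image.shape)  # Image, torch.Size([3, H, W])
print(label.dtype)                        # torch.float32
```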
{maite_datasets-0.0.4a0.dist-info → maite_datasets-0.0.6.dist-info}/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: maite-datasets
- Version: 0.0.4a0
+ Version: 0.0.6
  Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
  Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
  License-Expression: MIT
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Requires-Python: >=3.9
  Requires-Dist: defusedxml>=0.7.1
+ Requires-Dist: maite<0.9,>=0.7
  Requires-Dist: numpy>=1.24.2
  Requires-Dist: pillow>=10.3.0
  Requires-Dist: requests>=2.32.3
@@ -42,7 +43,7 @@ For status bar indicators when downloading, you can include the extra `tqdm` whe
  pip install maite-datasets[tqdm]
  ```

- ## Available Datasets
+ ## Available Downloadable Datasets

  | Task | Dataset | Description |
  |----------------|------------------|---------------------------------------------------------------------|
@@ -54,7 +55,7 @@ pip install maite-datasets[tqdm]
  | Detection | Seadrone | A UAV dataset focused on open water object detection. |
  | Detection | VOCDetection | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset. |

- ## Usage
+ ### Usage

  Here is an example of how to import MNIST for usage with your workflow.

@@ -76,6 +77,58 @@ MNIST Dataset
  tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
  ```

+ ## Dataset Wrappers
+
+ Wrappers convert datasets so they can be used with tools from a specific backend framework.
+
+ `TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
+ `torchvision` transforms to the dataset.
+
+ **NOTE:** `TorchvisionWrapper` requires _torch_ and _torchvision_ to be installed.
+
+ ```python
+ >>> from maite_datasets.object_detection import MILCO
+
+ >>> milco = MILCO(root="data", download=True)
+ >>> print(milco)
+ MILCO Dataset
+ -------------
+ Transforms: []
+ Image Set: train
+ Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+ Path: /home/user/maite-datasets/data/milco
+ Size: 261
+
+ >>> print(f"type={milco[0][0].__class__.__name__}, shape={milco[0][0].shape}")
+ type=ndarray, shape=(3, 1024, 1024)
+
+ >>> print(milco[0][1].boxes[0])
+ [ 75. 217. 130. 247.]
+
+ >>> from maite_datasets.wrappers import TorchvisionWrapper
+ >>> from torchvision.transforms.v2 import Resize
+
+ >>> milco_torch = TorchvisionWrapper(milco, transforms=Resize(224))
+ >>> print(milco_torch)
+ Torchvision Wrapped MILCO Dataset
+ ---------------------------------
+ Transforms: Resize(size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
+
+ MILCO Dataset
+ -------------
+ Transforms: []
+ Image Set: train
+ Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+ Path: /home/user/maite-datasets/data/milco
+ Size: 261
+
+ >>> print(f"type={milco_torch[0][0].__class__.__name__}, shape={milco_torch[0][0].shape}")
+ type=Image, shape=torch.Size([3, 224, 224])
+
+ >>> print(milco_torch[0][1].boxes[0])
+ tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
+ ```
+
  ## Additional Information

  For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
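Editor's note: since the wrapper implements `__getitem__` and `__len__` (see the `_torch.py` hunk above), it satisfies PyTorch's map-style dataset interface. Below is a minimal sketch of batching with `torch.utils.data.DataLoader`. The `list_collate` helper is hypothetical: detection targets vary in box count per image, so torch's default stacking collate would fail. The package ships its own collate utilities in `_collate.py`, but their API is not shown in this diff.

```python
# Sketch only: batching a wrapped detection dataset with a DataLoader.
# list_collate is a hypothetical helper, not an API from maite_datasets.
from torch.utils.data import DataLoader

from maite_datasets.object_detection import MILCO
from maite_datasets.wrappers import TorchvisionWrapper

dataset = TorchvisionWrapper(MILCO(root="data", download=True))

def list_collate(batch):
    # Keep images, targets, and metadata as parallel lists; per-image box
    # counts differ, so default tensor stacking cannot be used here.
    images, targets, metadata = zip(*batch)
    return list(images), list(targets), list(metadata)

loader = DataLoader(dataset, batch_size=4, collate_fn=list_collate)
images, targets, metadata = next(iter(loader))
print(len(images), targets[0].labels.shape)
```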
maite_datasets-0.0.6.dist-info/RECORD ADDED
@@ -0,0 +1,26 @@
+ maite_datasets/__init__.py,sha256=Z_HyAe08HaHMjzZS2afFumBXYFRFj0ny5ZAIp0hcj4w,569
+ maite_datasets/_base.py,sha256=VEd4ipHPAOCbz4Zm8zdI2yQwQ_x9O4Wq01xoZ2QvNYo,12366
+ maite_datasets/_builder.py,sha256=MnCh6z5hSINlzBnK_pdbgI5zSg5d1uq4UvXt3cjn9hs,9820
+ maite_datasets/_collate.py,sha256=pwUnmrbJH5olFjSwF-ZkGdfopTWUUlwmq0d5KzERcy8,4052
+ maite_datasets/_fileio.py,sha256=7S-hF3xU60AdcsPsfYR7rjbeGZUlv3JjGEZhGJOxGYU,5622
+ maite_datasets/_reader.py,sha256=tJqsjfXaK-mrs0Ed4BktombFMmNwCur35W7tuYCflKM,5569
+ maite_datasets/_validate.py,sha256=Uokbolmv1uSv98sph44HON0HEieeK3s2mqbPMP1d5xs,6948
+ maite_datasets/protocols.py,sha256=YGXb-WxlneXdIQBfBy5OdbylHSVfM-RBXeGvpiWwfLU,607
+ maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ maite_datasets/image_classification/__init__.py,sha256=pcZojkdsiMoLgY4mKjoQY6WyEwiGYHxNrAGpnvn3zsY,308
+ maite_datasets/image_classification/_cifar10.py,sha256=rrkJZ70NBYOSGxligXyakVxNOyGAglN6PaGQazuWNO4,8453
+ maite_datasets/image_classification/_mnist.py,sha256=9isgdi-YXgs6nXoh1j8uOgh4_sIhBIky72Vyl866rTE,8192
+ maite_datasets/image_classification/_ships.py,sha256=nWhte8592lpybhQCCdgT36LnuMQ0PRJWlDxT5-IPUtk,5137
+ maite_datasets/object_detection/__init__.py,sha256=171KT_X6I4YGy18G240N_-ZsKvXJ6YqcBDzkhTiBj2E,587
+ maite_datasets/object_detection/_antiuav.py,sha256=B20JrbouDM1o5f1ct9Zfbkks8NaVqYrxu5x-rBZvGx8,8265
+ maite_datasets/object_detection/_coco.py,sha256=3abRQJ9ATcZOeqK-4pnMfr-pv7aGcRum88SRlLLXTzk,10309
+ maite_datasets/object_detection/_milco.py,sha256=brxxYs5ak0vEpOSd2IW5AMMVkuadVmXCJBFPvXTmNlo,7928
+ maite_datasets/object_detection/_seadrone.py,sha256=JdHL0eRZoe7pXVInOq5Xpnz3-vgeBxbO25oTYgGZ44o,271213
+ maite_datasets/object_detection/_voc.py,sha256=vgRn-sa_r2-hxwpM3veRZQMcWyqJz9OGalABOccZeow,19589
+ maite_datasets/object_detection/_yolo.py,sha256=Luojzhanh6AK949910jN0yTpy8zwF5_At6nThj3Zw9Q,11867
+ maite_datasets/wrappers/__init__.py,sha256=6uI0ztOB2IlMWln9JkVke4OhU2HQ8i6YCaCNq_q5qb0,225
+ maite_datasets/wrappers/_torch.py,sha256=dmY6nSyLyVPOzpOE4BDTyOomWdFpN0x5dmH3XUzNetc,4588
+ maite_datasets-0.0.6.dist-info/METADATA,sha256=cmDnRwPTu1xWbFIpnNQ0jmhEI7XTu6CQ3cm18y0dMDk,5505
+ maite_datasets-0.0.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ maite_datasets-0.0.6.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
+ maite_datasets-0.0.6.dist-info/RECORD,,
maite_datasets/_mixin/_numpy.py DELETED
@@ -1,28 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from typing import Any
-
- import numpy as np
- from numpy.typing import NDArray
- from PIL import Image
-
- from maite_datasets._base import BaseDatasetMixin
-
-
- class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[np.number[Any]]]):
-     def _as_array(self, raw: list[Any]) -> NDArray[np.number[Any]]:
-         return np.asarray(raw)
-
-     def _one_hot_encode(self, value: int | list[int]) -> NDArray[np.number[Any]]:
-         if isinstance(value, int):
-             encoded = np.zeros(len(self.index2label))
-             encoded[value] = 1
-         else:
-             encoded = np.zeros((len(value), len(self.index2label)))
-             encoded[np.arange(len(value)), value] = 1
-         return encoded
-
-     def _read_file(self, path: str) -> NDArray[np.number[Any]]:
-         return np.array(Image.open(path)).transpose(2, 0, 1)
maite_datasets/_mixin/_torch.py DELETED
@@ -1,28 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from typing import Any
-
- import numpy as np
- import torch
- from PIL import Image
-
- from maite_datasets._base import BaseDatasetMixin
-
-
- class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
-     def _as_array(self, raw: list[Any]) -> torch.Tensor:
-         return torch.as_tensor(raw)
-
-     def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
-         if isinstance(value, int):
-             encoded = torch.zeros(len(self.index2label))
-             encoded[value] = 1
-         else:
-             encoded = torch.zeros((len(value), len(self.index2label)))
-             encoded[torch.arange(len(value)), value] = 1
-         return encoded
-
-     def _read_file(self, path: str) -> torch.Tensor:
-         return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
maite_datasets/_protocols.py DELETED
@@ -1,217 +0,0 @@
- """
- Common type protocols used for interoperability.
- """
-
- from collections.abc import Iterator
- import sys
- from typing import (
-     Any,
-     Generic,
-     Protocol,
-     TypeAlias,
-     TypedDict,
-     TypeVar,
-     runtime_checkable,
- )
-
- import numpy.typing
- from typing_extensions import NotRequired, ReadOnly, Required
-
- if sys.version_info >= (3, 10):
-     from typing import TypeAlias
- else:
-     from typing_extensions import TypeAlias
-
-
- ArrayLike: TypeAlias = numpy.typing.ArrayLike
- """
- Type alias for a `Union` representing objects that can be coerced into an array.
-
- See Also
- --------
- `NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
- """
-
-
- @runtime_checkable
- class Array(Protocol):
-     """
-     Protocol for interoperable array objects.
-
-     Supports common array representations with popular libraries like
-     PyTorch, Tensorflow and JAX, as well as NumPy arrays.
-     """
-
-     @property
-     def shape(self) -> tuple[int, ...]: ...
-     def __array__(self) -> Any: ...
-     def __getitem__(self, key: Any, /) -> Any: ...
-     def __iter__(self) -> Iterator[Any]: ...
-     def __len__(self) -> int: ...
-
-
- _T = TypeVar("_T")
- _T_co = TypeVar("_T_co", covariant=True)
- _T_cn = TypeVar("_T_cn", contravariant=True)
-
-
- class DatasetMetadata(TypedDict, total=False):
-     """
-     Dataset level metadata required for all `AnnotatedDataset` classes.
-
-     Attributes
-     ----------
-     id : Required[str]
-         A unique identifier for the dataset
-     index2label : NotRequired[dict[int, str]]
-         A lookup table converting label value to class name
-     """
-
-     id: Required[ReadOnly[str]]
-     index2label: NotRequired[ReadOnly[dict[int, str]]]
-
-
- class DatumMetadata(TypedDict, total=False):
-     """
-     Datum level metadata required for all `AnnotatedDataset` classes.
-
-     Attributes
-     ----------
-     id : Required[str]
-         A unique identifier for the datum
-     """
-
-     id: Required[ReadOnly[str]]
-
-
- @runtime_checkable
- class Dataset(Generic[_T_co], Protocol):
-     """
-     Protocol for a generic `Dataset`.
-
-     Methods
-     -------
-     __getitem__(index: int)
-         Returns datum at specified index.
-     __len__()
-         Returns dataset length.
-     """
-
-     def __getitem__(self, index: int, /) -> _T_co: ...
-     def __len__(self) -> int: ...
-
-
- @runtime_checkable
- class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
-     """
-     Protocol for a generic `AnnotatedDataset`.
-
-     Attributes
-     ----------
-     metadata : :class:`.DatasetMetadata` or derivatives.
-
-     Methods
-     -------
-     __getitem__(index: int)
-         Returns datum at specified index.
-     __len__()
-         Returns dataset length.
-
-     Notes
-     -----
-     Inherits from :class:`.Dataset`.
-     """
-
-     @property
-     def metadata(self) -> DatasetMetadata: ...
-
-
- # ========== IMAGE CLASSIFICATION DATASETS ==========
-
-
- ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, DatumMetadata]
- """
- Type alias for an image classification datum tuple.
-
- - :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
- - :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
- - dict[str, Any] - Datum level metadata.
- """
-
-
- ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
- """
- Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
- """
-
- # ========== OBJECT DETECTION DATASETS ==========
-
-
- @runtime_checkable
- class ObjectDetectionTarget(Protocol):
-     """
-     Protocol for targets in an Object Detection dataset.
-
-     Attributes
-     ----------
-     boxes : :class:`ArrayLike` of shape (N, 4)
-     labels : :class:`ArrayLike` of shape (N,)
-     scores : :class:`ArrayLike` of shape (N, M)
-     """
-
-     @property
-     def boxes(self) -> ArrayLike: ...
-
-     @property
-     def labels(self) -> ArrayLike: ...
-
-     @property
-     def scores(self) -> ArrayLike: ...
-
-
- ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, DatumMetadata]
- """
- Type alias for an object detection datum tuple.
-
- - :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
- - :class:`ObjectDetectionTarget` - Object detection target information for the image.
- - dict[str, Any] - Datum level metadata.
- """
-
-
- ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
- """
- Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
- """
-
-
- # ========== TRANSFORM ==========
-
-
- @runtime_checkable
- class Transform(Generic[_T], Protocol):
-     """
-     Protocol defining a transform function.
-
-     Requires a `__call__` method that returns transformed data.
-
-     Example
-     -------
-     >>> from typing import Any
-     >>> from numpy.typing import NDArray
-
-     >>> class MyTransform:
-     ...     def __init__(self, divisor: float) -> None:
-     ...         self.divisor = divisor
-     ...
-     ...     def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
-     ...         return data / self.divisor
-
-     >>> my_transform = MyTransform(divisor=255.0)
-     >>> isinstance(my_transform, Transform)
-     True
-     >>> my_transform(np.array([1, 2, 3]))
-     array([0.004, 0.008, 0.012])
-     """
-
-     def __call__(self, data: _T, /) -> _T: ...
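Editor's note: the runtime-checkable `Array` protocol deleted here survives in the new `maite_datasets/protocols.py` (its body is not shown in this diff, but the `_torch.py` hunk above imports it and calls `isinstance(target, Array)`). A small sketch of the structural check this enables, assuming the new protocol keeps the same members listed in the deleted module:

```python
# Sketch only: structural isinstance checks against the runtime-checkable Array
# protocol, assuming maite_datasets.protocols.Array still requires
# shape/__array__/__getitem__/__iter__/__len__ as in the deleted module above.
import numpy as np
import torch

from maite_datasets.protocols import Array

print(isinstance(np.zeros((2, 3)), Array))   # True: ndarray has all members
print(isinstance(torch.zeros(2, 3), Array))  # True: Tensor matches structurally
print(isinstance([1, 2, 3], Array))          # False: list lacks shape/__array__
```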
maite_datasets/_reader/__init__.py DELETED
@@ -1,6 +0,0 @@
- """
- Dataset readers for common computer vision dataset formats.
-
- This module provides standardized readers for loading datasets
- from directory structures.
- """
maite_datasets/_reader/_factory.py DELETED
@@ -1,64 +0,0 @@
- from __future__ import annotations
-
- import logging
- from pathlib import Path
-
- from maite_datasets._reader._base import BaseDatasetReader
- from maite_datasets._reader._yolo import YOLODatasetReader
- from maite_datasets._reader._coco import COCODatasetReader
-
- _logger = logging.getLogger(__name__)
-
-
- def create_dataset_reader(dataset_path: str | Path, format_hint: str | None = None) -> BaseDatasetReader:
-     """
-     Factory function to create appropriate dataset reader based on directory structure.
-
-     Parameters
-     ----------
-     dataset_path : str or Path
-         Root directory containing dataset files
-     format_hint : str or None, default None
-         Format hint ("coco" or "yolo"). If None, auto-detects based on file structure
-
-     Returns
-     -------
-     BaseDatasetReader
-         Appropriate reader instance for the detected format
-
-     Raises
-     ------
-     ValueError
-         If format cannot be determined or is unsupported
-     """
-     dataset_path = Path(dataset_path)
-
-     if format_hint:
-         format_hint = format_hint.lower()
-         if format_hint == "coco":
-             return COCODatasetReader(dataset_path)
-         elif format_hint == "yolo":
-             return YOLODatasetReader(dataset_path)
-         else:
-             raise ValueError(f"Unsupported format hint: {format_hint}")
-
-     # Auto-detect format
-     has_annotations_json = (dataset_path / "annotations.json").exists()
-     has_labels_dir = (dataset_path / "labels").exists()
-
-     if has_annotations_json and not has_labels_dir:
-         _logger.info(f"Detected COCO format for {dataset_path}")
-         return COCODatasetReader(dataset_path)
-     elif has_labels_dir and not has_annotations_json:
-         _logger.info(f"Detected YOLO format for {dataset_path}")
-         return YOLODatasetReader(dataset_path)
-     elif has_annotations_json and has_labels_dir:
-         raise ValueError(
-             f"Ambiguous format in {dataset_path}: both annotations.json and labels/ exist. "
-             "Use format_hint parameter to specify format."
-         )
-     else:
-         raise ValueError(
-             f"Cannot detect dataset format in {dataset_path}. "
-             "Expected either annotations.json (COCO) or labels/ directory (YOLO)."
-         )
maite_datasets/_types.py DELETED
@@ -1,50 +0,0 @@
- from __future__ import annotations
-
- __all__ = []
-
- from dataclasses import dataclass
- from typing import Generic, TypedDict, TypeVar
-
- from typing_extensions import NotRequired, Required
-
- _T_co = TypeVar("_T_co", covariant=True)
-
-
- class Dataset(Generic[_T_co]):
-     """Abstract generic base class for PyTorch style Dataset"""
-
-     def __getitem__(self, index: int) -> _T_co: ...
-     def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]: ...
-
-
- class DatasetMetadata(TypedDict):
-     id: Required[str]
-     index2label: NotRequired[dict[int, str]]
-     split: NotRequired[str]
-
-
- class DatumMetadata(TypedDict, total=False):
-     id: Required[str]
-
-
- _TDatum = TypeVar("_TDatum")
- _TArray = TypeVar("_TArray")
-
-
- class AnnotatedDataset(Dataset[_TDatum]):
-     metadata: DatasetMetadata
-
-     def __len__(self) -> int: ...
-
-
- class ImageClassificationDataset(AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]): ...
-
-
- @dataclass
- class ObjectDetectionTarget(Generic[_TArray]):
-     boxes: _TArray
-     labels: _TArray
-     scores: _TArray
-
-
- class ObjectDetectionDataset(AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]): ...