maite-datasets 0.0.5__tar.gz → 0.0.6__tar.gz

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (35)
  1. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/.gitignore +5 -1
  2. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/PKG-INFO +56 -3
  3. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/README.md +54 -2
  4. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/pyproject.toml +15 -4
  5. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/__init__.py +2 -6
  6. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_base.py +169 -51
  7. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_builder.py +46 -55
  8. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_collate.py +2 -3
  9. maite_datasets-0.0.5/src/maite_datasets/_reader/_base.py → maite_datasets-0.0.6/src/maite_datasets/_reader.py +62 -36
  10. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_validate.py +4 -2
  11. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/_cifar10.py +12 -7
  12. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/_mnist.py +15 -10
  13. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/_ships.py +12 -8
  14. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/__init__.py +4 -7
  15. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_antiuav.py +11 -8
  16. {maite_datasets-0.0.5/src/maite_datasets/_reader → maite_datasets-0.0.6/src/maite_datasets/object_detection}/_coco.py +29 -27
  17. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_milco.py +11 -9
  18. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_seadrone.py +11 -9
  19. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_voc.py +11 -13
  20. {maite_datasets-0.0.5/src/maite_datasets/_reader → maite_datasets-0.0.6/src/maite_datasets/object_detection}/_yolo.py +26 -21
  21. maite_datasets-0.0.6/src/maite_datasets/protocols.py +23 -0
  22. maite_datasets-0.0.6/src/maite_datasets/wrappers/__init__.py +8 -0
  23. maite_datasets-0.0.6/src/maite_datasets/wrappers/_torch.py +111 -0
  24. maite_datasets-0.0.5/src/maite_datasets/_mixin/__init__.py +0 -0
  25. maite_datasets-0.0.5/src/maite_datasets/_mixin/_numpy.py +0 -28
  26. maite_datasets-0.0.5/src/maite_datasets/_mixin/_torch.py +0 -28
  27. maite_datasets-0.0.5/src/maite_datasets/_protocols.py +0 -217
  28. maite_datasets-0.0.5/src/maite_datasets/_reader/__init__.py +0 -6
  29. maite_datasets-0.0.5/src/maite_datasets/_reader/_factory.py +0 -64
  30. maite_datasets-0.0.5/src/maite_datasets/_types.py +0 -50
  31. maite_datasets-0.0.5/src/maite_datasets/object_detection/_voc_torch.py +0 -65
  32. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/LICENSE +0 -0
  33. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_fileio.py +0 -0
  34. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/__init__.py +0 -0
  35. {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/py.typed +0 -0
@@ -9,6 +9,10 @@ wheels/
 # Virtual environments
 .venv
 
+# Downloaded data
+.data
+
 # Test output
 .nox/
-output/
+output/
+.coverage
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: maite-datasets
-Version: 0.0.5
+Version: 0.0.6
 Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
 Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
 License-Expression: MIT
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Requires-Python: >=3.9
 Requires-Dist: defusedxml>=0.7.1
+Requires-Dist: maite<0.9,>=0.7
 Requires-Dist: numpy>=1.24.2
 Requires-Dist: pillow>=10.3.0
 Requires-Dist: requests>=2.32.3
@@ -42,7 +43,7 @@ For status bar indicators when downloading, you can include the extra `tqdm` whe
 pip install maite-datasets[tqdm]
 ```
 
-## Available Datasets
+## Available Downloadable Datasets
 
 | Task | Dataset | Description |
 |----------------|------------------|---------------------------------------------------------------------|
@@ -54,7 +55,7 @@ pip install maite-datasets[tqdm]
 | Detection | Seadrone | A UAV dataset focused on open water object detection. |
 | Detection | VOCDetection | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset. |
 
-## Usage
+### Usage
 
 Here is an example of how to import MNIST for usage with your workflow.
 
@@ -76,6 +77,58 @@ MNIST Dataset
 tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
 ```
 
+## Dataset Wrappers
+
+Wrappers convert datasets for use with tools from specific backend frameworks.
+
+`TorchvisionWrapper` is a convenience class that wraps any of the datasets and supports applying
+`torchvision` transforms to the dataset.
+
+**NOTE:** `TorchvisionWrapper` requires _torch_ and _torchvision_ to be installed.
+
+```python
+>>> from maite_datasets.object_detection import MILCO
+
+>>> milco = MILCO(root="data", download=True)
+>>> print(milco)
+MILCO Dataset
+-------------
+Transforms: []
+Image Set: train
+Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+Path: /home/user/maite-datasets/data/milco
+Size: 261
+
+>>> print(f"type={milco[0][0].__class__.__name__}, shape={milco[0][0].shape}")
+type=ndarray, shape=(3, 1024, 1024)
+
+>>> print(milco[0][1].boxes[0])
+[ 75. 217. 130. 247.]
+
+>>> from maite_datasets.wrappers import TorchvisionWrapper
+>>> from torchvision.transforms.v2 import Resize
+
+>>> milco_torch = TorchvisionWrapper(milco, transforms=Resize(224))
+>>> print(milco_torch)
+Torchvision Wrapped MILCO Dataset
+---------------------------
+Transforms: Resize(size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
+
+MILCO Dataset
+-------------
+Transforms: []
+Image Set: train
+Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+Path: /home/user/maite-datasets/data/milco
+Size: 261
+
+>>> print(f"type={milco_torch[0][0].__class__.__name__}, shape={milco_torch[0][0].shape}")
+type=Image, shape=torch.Size([3, 224, 224])
+
+>>> print(milco_torch[0][1].boxes[0])
+tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
+```
+
 ## Additional Information
 
 For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
@@ -16,7 +16,7 @@ For status bar indicators when downloading, you can include the extra `tqdm` whe
 pip install maite-datasets[tqdm]
 ```
 
-## Available Datasets
+## Available Downloadable Datasets
 
 | Task | Dataset | Description |
 |----------------|------------------|---------------------------------------------------------------------|
@@ -28,7 +28,7 @@ pip install maite-datasets[tqdm]
 | Detection | Seadrone | A UAV dataset focused on open water object detection. |
 | Detection | VOCDetection | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset. |
 
-## Usage
+### Usage
 
 Here is an example of how to import MNIST for usage with your workflow.
 
@@ -50,6 +50,58 @@ MNIST Dataset
 tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
 ```
 
+## Dataset Wrappers
+
+Wrappers convert datasets for use with tools from specific backend frameworks.
+
+`TorchvisionWrapper` is a convenience class that wraps any of the datasets and supports applying
+`torchvision` transforms to the dataset.
+
+**NOTE:** `TorchvisionWrapper` requires _torch_ and _torchvision_ to be installed.
+
+```python
+>>> from maite_datasets.object_detection import MILCO
+
+>>> milco = MILCO(root="data", download=True)
+>>> print(milco)
+MILCO Dataset
+-------------
+Transforms: []
+Image Set: train
+Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+Path: /home/user/maite-datasets/data/milco
+Size: 261
+
+>>> print(f"type={milco[0][0].__class__.__name__}, shape={milco[0][0].shape}")
+type=ndarray, shape=(3, 1024, 1024)
+
+>>> print(milco[0][1].boxes[0])
+[ 75. 217. 130. 247.]
+
+>>> from maite_datasets.wrappers import TorchvisionWrapper
+>>> from torchvision.transforms.v2 import Resize
+
+>>> milco_torch = TorchvisionWrapper(milco, transforms=Resize(224))
+>>> print(milco_torch)
+Torchvision Wrapped MILCO Dataset
+---------------------------
+Transforms: Resize(size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
+
+MILCO Dataset
+-------------
+Transforms: []
+Image Set: train
+Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
+Path: /home/user/maite-datasets/data/milco
+Size: 261
+
+>>> print(f"type={milco_torch[0][0].__class__.__name__}, shape={milco_torch[0][0].shape}")
+type=Image, shape=torch.Size([3, 224, 224])
+
+>>> print(milco_torch[0][1].boxes[0])
+tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
+```
+
 ## Additional Information
 
 For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
@@ -6,6 +6,7 @@ requires-python = ">=3.9"
 dynamic = ["version"]
 dependencies = [
     "defusedxml>=0.7.1",
+    "maite>=0.7,<0.9",
     "numpy>=1.24.2",
     "pillow>=10.3.0",
     "requests>=2.32.3",
@@ -38,7 +39,9 @@ tqdm = [
 base = [
     "nox[uv]>=2025.5.1",
     "torch>=2.2.0",
-    "uv>=0.7.8",
+    "torchvision>=0.17.0",
+    "tqdm>=4.66",
+    "uv>=0.8.0",
 ]
 lint = [
     "ruff>=0.11",
@@ -59,12 +62,12 @@ dev = [
     { include-group = "lint" },
     { include-group = "test" },
     { include-group = "type" },
+    "ipykernel>=6.30.0",
 ]
 
 [tool.uv.sources]
-torch = [
-    { index = "pytorch-cpu" },
-]
+torch = [{ index = "pytorch-cpu" }]
+torchvision = [{ index = "pytorch-cpu" }]
 
 [[tool.uv.index]]
 name = "pytorch-cpu"
@@ -108,6 +111,14 @@ line-length = 120
 indent-width = 4
 target-version = "py39"
 
+[tool.ruff.lint]
+select = ["A", "ANN", "C4", "C90", "E", "F", "I", "NPY", "S", "SIM", "RET", "RUF100", "UP"]
+ignore = ["ANN401", "NPY002"]
+fixable = ["ALL"]
+unfixable = []
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+per-file-ignores = { "!src/*" = ["ANN", "S", "RET"]}
+
 [tool.ruff.lint.isort]
 known-first-party = ["maite_datasets"]
 
@@ -1,11 +1,9 @@
 """Module for MAITE compliant Computer Vision datasets."""
 
 from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
-from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
+from maite_datasets._collate import collate_as_list, collate_as_numpy, collate_as_torch
+from maite_datasets._reader import create_dataset_reader
 from maite_datasets._validate import validate_dataset
-from maite_datasets._reader._factory import create_dataset_reader
-from maite_datasets._reader._coco import COCODatasetReader
-from maite_datasets._reader._yolo import YOLODatasetReader
 
 __all__ = [
     "collate_as_list",
@@ -15,6 +13,4 @@ __all__ = [
     "to_image_classification_dataset",
     "to_object_detection_dataset",
     "validate_dataset",
-    "COCODatasetReader",
-    "YOLODatasetReader",
 ]
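
The reader machinery is consolidated here: the `_reader` package becomes a single `_reader` module, and the concrete `COCODatasetReader`/`YOLODatasetReader` classes leave the public API in favor of the `create_dataset_reader` factory. A sketch of the resulting 0.0.6 import surface (the commented-out call is hypothetical; the factory's parameters are not visible in this diff):

```python
# Removed in 0.0.6:
# from maite_datasets import COCODatasetReader, YOLODatasetReader

# The 0.0.6 top-level API, per the updated imports and __all__:
from maite_datasets import (
    collate_as_list,
    collate_as_numpy,
    collate_as_torch,
    create_dataset_reader,
    to_image_classification_dataset,
    to_object_detection_dataset,
    validate_dataset,
)

# Hypothetical call -- the factory's signature is not shown in this diff:
# reader = create_dataset_reader("path/to/dataset")
```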
@@ -2,23 +2,24 @@ from __future__ import annotations
 
 __all__ = []
 
+import inspect
+import warnings
 from abc import abstractmethod
+from collections import namedtuple
+from collections.abc import Iterator, Sequence
 from pathlib import Path
-from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
+from typing import Any, Callable, Generic, Literal, NamedTuple, TypeVar, cast
 
 import numpy as np
+from maite.protocols import DatasetMetadata, DatumMetadata
+from numpy.typing import NDArray
+from PIL import Image
 
 from maite_datasets._fileio import _ensure_exists
-from maite_datasets._protocols import Array, Transform
-from maite_datasets._types import (
-    AnnotatedDataset,
-    DatasetMetadata,
-    DatumMetadata,
-    ImageClassificationDataset,
-    ObjectDetectionDataset,
-    ObjectDetectionTarget,
-)
+from maite_datasets.protocols import Array
 
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
 _TArray = TypeVar("_TArray", bound=Array)
 _TTarget = TypeVar("_TTarget")
 _TRawTarget = TypeVar(
@@ -30,16 +31,7 @@ _TRawTarget = TypeVar(
 _TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
 
 
-def _to_datum_metadata(index: int, metadata: dict[str, Any]) -> DatumMetadata:
-    _id = metadata.pop("id", index)
-    return DatumMetadata(id=_id, **metadata)
-
-
-class DataLocation(NamedTuple):
-    url: str
-    filename: str
-    md5: bool
-    checksum: str
+ObjectDetectionTarget = namedtuple("ObjectDetectionTarget", ["boxes", "labels", "scores"])
 
 
 class BaseDatasetMixin(Generic[_TArray]):
@@ -50,8 +42,99 @@ class BaseDatasetMixin(Generic[_TArray]):
     def _read_file(self, path: str) -> _TArray: ...
 
 
-class BaseDataset(
-    AnnotatedDataset[tuple[_TArray, _TTarget, DatumMetadata]],
+class Dataset(Generic[_T_co]):
+    """Abstract generic base class for PyTorch style Dataset"""
+
+    def __getitem__(self, index: int) -> _T_co: ...
+    def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]: ...
+
+
+class BaseDataset(Dataset[tuple[_TArray, _TTarget, DatumMetadata]]):
+    metadata: DatasetMetadata
+
+    def __init__(
+        self,
+        transforms: Callable[[_TArray], _TArray]
+        | Callable[
+            [tuple[_TArray, _TTarget, DatumMetadata]],
+            tuple[_TArray, _TTarget, DatumMetadata],
+        ]
+        | Sequence[
+            Callable[[_TArray], _TArray]
+            | Callable[
+                [tuple[_TArray, _TTarget, DatumMetadata]],
+                tuple[_TArray, _TTarget, DatumMetadata],
+            ]
+        ]
+        | None,
+    ) -> None:
+        self.transforms: Sequence[
+            Callable[
+                [tuple[_TArray, _TTarget, DatumMetadata]],
+                tuple[_TArray, _TTarget, DatumMetadata],
+            ]
+        ] = []
+        transforms = transforms if isinstance(transforms, Sequence) else [transforms] if transforms else []
+        for transform in transforms:
+            sig = inspect.signature(transform)
+            if len(sig.parameters) != 1:
+                warnings.warn(f"Dropping unrecognized transform: {str(transform)}")
+            elif "tuple" in str(sig.parameters.values()):
+                transform = cast(
+                    Callable[
+                        [tuple[_TArray, _TTarget, DatumMetadata]],
+                        tuple[_TArray, _TTarget, DatumMetadata],
+                    ],
+                    transform,
+                )
+                self.transforms.append(transform)
+            else:
+                transform = cast(Callable[[_TArray], _TArray], transform)
+                self.transforms.append(self._wrap_transform(transform))
+
+    def _wrap_transform(
+        self, transform: Callable[[_TArray], _TArray]
+    ) -> Callable[
+        [tuple[_TArray, _TTarget, DatumMetadata]],
+        tuple[_TArray, _TTarget, DatumMetadata],
+    ]:
+        def wrapper(
+            datum: tuple[_TArray, _TTarget, DatumMetadata],
+        ) -> tuple[_TArray, _TTarget, DatumMetadata]:
+            image, target, metadata = datum
+            return (transform(image), target, metadata)
+
+        return wrapper
+
+    def _transform(self, datum: tuple[_TArray, _TTarget, DatumMetadata]) -> tuple[_TArray, _TTarget, DatumMetadata]:
+        """Function to transform the image prior to returning based on parameters passed in."""
+        for transform in self.transforms:
+            datum = transform(datum)
+        return datum
+
+    def __len__(self) -> int: ...
+
+    def __str__(self) -> str:
+        nt = "\n "
+        title = f"{self.__class__.__name__.replace('Dataset', '')} Dataset"
+        sep = "-" * len(title)
+        attrs = [
+            f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
+            for k, v in self.__dict__.items()
+            if not k.startswith("_")
+        ]
+        return f"{title}\n{sep}{nt}{nt.join(attrs)}"
+
+
+class DataLocation(NamedTuple):
+    url: str
+    filename: str
+    md5: bool
+    checksum: str
+
+
+class BaseDownloadedDataset(
+    BaseDataset[_TArray, _TTarget],
     Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
 ):
     """
@@ -72,13 +155,24 @@ class BaseDataset(
         self,
         root: str | Path,
         image_set: Literal["train", "val", "test", "operational", "base"] = "train",
-        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+        transforms: Callable[[_TArray], _TArray]
+        | Callable[
+            [tuple[_TArray, _TTarget, DatumMetadata]],
+            tuple[_TArray, _TTarget, DatumMetadata],
+        ]
+        | Sequence[
+            Callable[[_TArray], _TArray]
+            | Callable[
+                [tuple[_TArray, _TTarget, DatumMetadata]],
+                tuple[_TArray, _TTarget, DatumMetadata],
+            ]
+        ]
+        | None = None,
         download: bool = False,
         verbose: bool = False,
     ) -> None:
+        super().__init__(transforms)
         self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
-        transforms = transforms if transforms is not None else []
-        self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
         self.image_set = image_set
         self._verbose = verbose
 
@@ -91,9 +185,11 @@ class BaseDataset(
         self._label2index = {v: k for k, v in self.index2label.items()}
 
         self.metadata: DatasetMetadata = DatasetMetadata(
-            id=self._unique_id(),
-            index2label=self.index2label,
-            split=self.image_set,
+            **{
+                "id": self._unique_id(),
+                "index2label": self.index2label,
+                "split": self.image_set,
+            }
         )
 
         # Load the data
@@ -101,13 +197,6 @@ class BaseDataset(
         self._filepaths, self._targets, self._datum_metadata = self._load_data()
         self.size: int = len(self._filepaths)
 
-    def __str__(self) -> str:
-        nt = "\n "
-        title = f"{self.__class__.__name__} Dataset"
-        sep = "-" * len(title)
-        attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
-        return f"{title}\n{sep}{nt}{nt.join(attrs)}"
-
     @property
     def label2index(self) -> dict[str, int]:
         return self._label2index
@@ -148,20 +237,18 @@ class BaseDataset(
     @abstractmethod
     def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
 
-    def _transform(self, image: _TArray) -> _TArray:
-        """Function to transform the image prior to returning based on parameters passed in."""
-        for transform in self.transforms:
-            image = transform(image)
-        return image
+    def _to_datum_metadata(self, index: int, metadata: dict[str, Any]) -> DatumMetadata:
+        _id = metadata.pop("id", index)
+        return DatumMetadata(id=_id, **metadata)
 
     def __len__(self) -> int:
        return self.size
 
 
 class BaseICDataset(
-    BaseDataset[_TArray, _TArray, list[int], int],
+    BaseDownloadedDataset[_TArray, _TArray, list[int], int],
     BaseDatasetMixin[_TArray],
-    ImageClassificationDataset[_TArray],
+    BaseDataset[_TArray, _TArray],
 ):
     """
     Base class for image classification datasets.
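
`_to_datum_metadata` moves from a module-level function onto the dataset class but keeps its behavior: a datum's own `id` wins if present, otherwise the dataset index is used. A plain-dict sketch of that fallback (a dict stands in for maite's `DatumMetadata`):

```python
from typing import Any

def to_datum_metadata(index: int, metadata: dict[str, Any]) -> dict[str, Any]:
    # Pop "id" if the datum carries one, otherwise fall back to the index.
    return {"id": metadata.pop("id", index), **metadata}

print(to_datum_metadata(7, {"pose": "front"}))   # {'id': 7, 'pose': 'front'}
print(to_datum_metadata(7, {"id": "img_0007"}))  # {'id': 'img_0007'}
```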
@@ -184,17 +271,16 @@ class BaseICDataset(
         score = self._one_hot_encode(label)
         # Get the image
         img = self._read_file(self._filepaths[index])
-        img = self._transform(img)
 
         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
 
-        return img, score, _to_datum_metadata(index, img_metadata)
+        return self._transform((img, score, self._to_datum_metadata(index, img_metadata)))
 
 
 class BaseODDataset(
-    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
+    BaseDownloadedDataset[_TArray, ObjectDetectionTarget, _TRawTarget, _TAnnotation],
     BaseDatasetMixin[_TArray],
-    ObjectDetectionDataset[_TArray],
+    BaseDataset[_TArray, ObjectDetectionTarget],
 ):
     """
     Base class for object detection datasets.
@@ -202,7 +288,7 @@ class BaseODDataset(
 
     _bboxes_per_size: bool = False
 
-    def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
+    def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget, DatumMetadata]:
         """
         Args
         ----
@@ -211,7 +297,7 @@ class BaseODDataset(
 
         Returns
         -------
-        tuple[TArray, ObjectDetectionTarget[TArray], DatumMetadata]
+        tuple[TArray, ObjectDetectionTarget, DatumMetadata]
            Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
         """
         # Grab the bounding boxes and labels from the annotations
@@ -220,17 +306,49 @@ class BaseODDataset(
         # Get the image
         img = self._read_file(self._filepaths[index])
         img_size = img.shape
-        img = self._transform(img)
         # Adjust labels if necessary
         if self._bboxes_per_size and boxes:
-            boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
+            boxes = boxes * np.asarray([[img_size[1], img_size[2], img_size[1], img_size[2]]])
         # Create the Object Detection Target
         target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
 
         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
         img_metadata = img_metadata | additional_metadata
 
-        return img, target, _to_datum_metadata(index, img_metadata)
+        return self._transform((img, target, self._to_datum_metadata(index, img_metadata)))
 
     @abstractmethod
     def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
+
+
+NumpyArray = NDArray[np.floating[Any]] | NDArray[np.integer[Any]]
+
+
+class BaseDatasetNumpyMixin(BaseDatasetMixin[NumpyArray]):
+    def _as_array(self, raw: list[Any]) -> NumpyArray:
+        return np.asarray(raw)
+
+    def _one_hot_encode(self, value: int | list[int]) -> NumpyArray:
+        if isinstance(value, int):
+            encoded = np.zeros(len(self.index2label))
+            encoded[value] = 1
+        else:
+            encoded = np.zeros((len(value), len(self.index2label)))
+            encoded[np.arange(len(value)), value] = 1
+        return encoded
+
+    def _read_file(self, path: str) -> NumpyArray:
+        return np.array(Image.open(path)).transpose(2, 0, 1)
+
+
+NumpyImageTransform = Callable[[NumpyArray], NumpyArray]
+NumpyImageClassificationDatumTransform = Callable[
+    [tuple[NumpyArray, NumpyArray, DatumMetadata]],
+    tuple[NumpyArray, NumpyArray, DatumMetadata],
+]
+NumpyObjectDetectionDatumTransform = Callable[
+    [tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]],
+    tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata],
+]
+NumpyImageClassificationTransform = NumpyImageTransform | NumpyImageClassificationDatumTransform
+NumpyObjectDetectionTransform = NumpyImageTransform | NumpyObjectDetectionDatumTransform
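
The new `BaseDatasetNumpyMixin` replaces the removed `_mixin/_numpy.py` module. Its `_one_hot_encode` accepts either a single label or a list of labels; a self-contained sketch of the same logic, with `num_classes` standing in for `len(self.index2label)`:

```python
from __future__ import annotations

import numpy as np

def one_hot_encode(value: int | list[int], num_classes: int) -> np.ndarray:
    # A single int yields a (num_classes,) vector; a list of ints yields
    # an (n, num_classes) matrix, mirroring the mixin above.
    if isinstance(value, int):
        encoded = np.zeros(num_classes)
        encoded[value] = 1
    else:
        encoded = np.zeros((len(value), num_classes))
        encoded[np.arange(len(value)), value] = 1
    return encoded

print(one_hot_encode(1, 3))       # [0. 1. 0.]
print(one_hot_encode([0, 2], 3))  # [[1. 0. 0.] [0. 0. 1.]]
```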