maite-datasets 0.0.5__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/.gitignore +5 -1
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/PKG-INFO +56 -3
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/README.md +54 -2
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/pyproject.toml +15 -4
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/__init__.py +2 -6
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_base.py +169 -51
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_builder.py +46 -55
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_collate.py +2 -3
- maite_datasets-0.0.5/src/maite_datasets/_reader/_base.py → maite_datasets-0.0.6/src/maite_datasets/_reader.py +62 -36
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_validate.py +4 -2
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/_cifar10.py +12 -7
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/_mnist.py +15 -10
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/_ships.py +12 -8
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/__init__.py +4 -7
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_antiuav.py +11 -8
- {maite_datasets-0.0.5/src/maite_datasets/_reader → maite_datasets-0.0.6/src/maite_datasets/object_detection}/_coco.py +29 -27
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_milco.py +11 -9
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_seadrone.py +11 -9
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/object_detection/_voc.py +11 -13
- {maite_datasets-0.0.5/src/maite_datasets/_reader → maite_datasets-0.0.6/src/maite_datasets/object_detection}/_yolo.py +26 -21
- maite_datasets-0.0.6/src/maite_datasets/protocols.py +23 -0
- maite_datasets-0.0.6/src/maite_datasets/wrappers/__init__.py +8 -0
- maite_datasets-0.0.6/src/maite_datasets/wrappers/_torch.py +111 -0
- maite_datasets-0.0.5/src/maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets-0.0.5/src/maite_datasets/_mixin/_numpy.py +0 -28
- maite_datasets-0.0.5/src/maite_datasets/_mixin/_torch.py +0 -28
- maite_datasets-0.0.5/src/maite_datasets/_protocols.py +0 -217
- maite_datasets-0.0.5/src/maite_datasets/_reader/__init__.py +0 -6
- maite_datasets-0.0.5/src/maite_datasets/_reader/_factory.py +0 -64
- maite_datasets-0.0.5/src/maite_datasets/_types.py +0 -50
- maite_datasets-0.0.5/src/maite_datasets/object_detection/_voc_torch.py +0 -65
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/LICENSE +0 -0
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/_fileio.py +0 -0
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/image_classification/__init__.py +0 -0
- {maite_datasets-0.0.5 → maite_datasets-0.0.6}/src/maite_datasets/py.typed +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: maite-datasets
|
3
|
-
Version: 0.0.5
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
|
5
5
|
Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
|
6
6
|
License-Expression: MIT
|
@@ -16,6 +16,7 @@ Classifier: Programming Language :: Python :: 3.11
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.12
|
17
17
|
Requires-Python: >=3.9
|
18
18
|
Requires-Dist: defusedxml>=0.7.1
|
19
|
+
Requires-Dist: maite<0.9,>=0.7
|
19
20
|
Requires-Dist: numpy>=1.24.2
|
20
21
|
Requires-Dist: pillow>=10.3.0
|
21
22
|
Requires-Dist: requests>=2.32.3
|
@@ -42,7 +43,7 @@ For status bar indicators when downloading, you can include the extra `tqdm` whe
|
|
42
43
|
pip install maite-datasets[tqdm]
|
43
44
|
```
|
44
45
|
|
45
|
-
## Available Datasets
|
46
|
+
## Available Downloadable Datasets
|
46
47
|
|
47
48
|
| Task | Dataset | Description |
|
48
49
|
|----------------|------------------|---------------------------------------------------------------------|
|
@@ -54,7 +55,7 @@ pip install maite-datasets[tqdm]
|
|
54
55
|
| Detection | Seadrone | A UAV dataset focused on open water object detection. |
|
55
56
|
| Detection | VOCDetection | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset. |
|
56
57
|
|
57
|
-
|
58
|
+
### Usage
|
58
59
|
|
59
60
|
Here is an example of how to import MNIST for usage with your workflow.
|
60
61
|
|
@@ -76,6 +77,58 @@ MNIST Dataset
|
|
76
77
|
tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
|
77
78
|
```
|
78
79
|
|
80
|
+
## Dataset Wrappers
|
81
|
+
|
82
|
+
Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
|
83
|
+
|
84
|
+
`TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
|
85
|
+
`torchvision` transforms to the dataset.
|
86
|
+
|
87
|
+
**NOTE:** `TorchvisionWrapper` requires _torch_ and _torchvision_ to be installed.
|
88
|
+
|
89
|
+
```python
|
90
|
+
>>> from maite_datasets.object_detection import MILCO
|
91
|
+
|
92
|
+
>>> milco = MILCO(root="data", download=True)
|
93
|
+
>>> print(milco)
|
94
|
+
MILCO Dataset
|
95
|
+
-------------
|
96
|
+
Transforms: []
|
97
|
+
Image Set: train
|
98
|
+
Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
|
99
|
+
Path: /home/user/maite-datasets/data/milco
|
100
|
+
Size: 261
|
101
|
+
|
102
|
+
>>> print(f"type={milco[0][0].__class__.__name__}, shape={milco[0][0].shape}")
|
103
|
+
type=ndarray, shape=(3, 1024, 1024)
|
104
|
+
|
105
|
+
>>> print(milco[0][1].boxes[0])
|
106
|
+
[ 75. 217. 130. 247.]
|
107
|
+
|
108
|
+
>>> from maite_datasets.wrappers import TorchvisionWrapper
|
109
|
+
>>> from torchvision.transforms.v2 import Resize
|
110
|
+
|
111
|
+
>>> milco_torch = TorchvisionWrapper(milco, transforms=Resize(224))
|
112
|
+
>>> print(milco_torch)
|
113
|
+
Torchvision Wrapped MILCO Dataset
|
114
|
+
---------------------------
|
115
|
+
Transforms: Resize(size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
|
116
|
+
|
117
|
+
MILCO Dataset
|
118
|
+
-------------
|
119
|
+
Transforms: []
|
120
|
+
Image Set: train
|
121
|
+
Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
|
122
|
+
Path: /home/user/maite-datasets/data/milco
|
123
|
+
Size: 261
|
124
|
+
|
125
|
+
>>> print(f"type={milco_torch[0][0].__class__.__name__}, shape={milco_torch[0][0].shape}")
|
126
|
+
type=Image, shape=torch.Size([3, 224, 224])
|
127
|
+
|
128
|
+
>>> print(milco_torch[0][1].boxes[0])
|
129
|
+
tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
|
130
|
+
```
|
131
|
+
|
79
132
|
## Additional Information
|
80
133
|
|
81
134
|
For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
|
@@ -16,7 +16,7 @@ For status bar indicators when downloading, you can include the extra `tqdm` whe
|
|
16
16
|
pip install maite-datasets[tqdm]
|
17
17
|
```
|
18
18
|
|
19
|
-
## Available Datasets
|
19
|
+
## Available Downloadable Datasets
|
20
20
|
|
21
21
|
| Task | Dataset | Description |
|
22
22
|
|----------------|------------------|---------------------------------------------------------------------|
|
@@ -28,7 +28,7 @@ pip install maite-datasets[tqdm]
|
|
28
28
|
| Detection | Seadrone | A UAV dataset focused on open water object detection. |
|
29
29
|
| Detection | VOCDetection | [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/) dataset. |
|
30
30
|
|
31
|
-
|
31
|
+
### Usage
|
32
32
|
|
33
33
|
Here is an example of how to import MNIST for usage with your workflow.
|
34
34
|
|
@@ -50,6 +50,58 @@ MNIST Dataset
|
|
50
50
|
tuple(<class 'numpy.ndarray'>, <class 'numpy.ndarray'>, <class 'dict'>)
|
51
51
|
```
|
52
52
|
|
53
|
+
## Dataset Wrappers
|
54
|
+
|
55
|
+
Wrappers provide a way to convert datasets to allow usage of tools within specific backend frameworks.
|
56
|
+
|
57
|
+
`TorchvisionWrapper` is a convenience class that wraps any of the datasets and provides the capability to apply
|
58
|
+
`torchvision` transforms to the dataset.
|
59
|
+
|
60
|
+
**NOTE:** `TorchvisionWrapper` requires _torch_ and _torchvision_ to be installed.
|
61
|
+
|
62
|
+
```python
|
63
|
+
>>> from maite_datasets.object_detection import MILCO
|
64
|
+
|
65
|
+
>>> milco = MILCO(root="data", download=True)
|
66
|
+
>>> print(milco)
|
67
|
+
MILCO Dataset
|
68
|
+
-------------
|
69
|
+
Transforms: []
|
70
|
+
Image Set: train
|
71
|
+
Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
|
72
|
+
Path: /home/user/maite-datasets/data/milco
|
73
|
+
Size: 261
|
74
|
+
|
75
|
+
>>> print(f"type={milco[0][0].__class__.__name__}, shape={milco[0][0].shape}")
|
76
|
+
type=ndarray, shape=(3, 1024, 1024)
|
77
|
+
|
78
|
+
>>> print(milco[0][1].boxes[0])
|
79
|
+
[ 75. 217. 130. 247.]
|
80
|
+
|
81
|
+
>>> from maite_datasets.wrappers import TorchvisionWrapper
|
82
|
+
>>> from torchvision.transforms.v2 import Resize
|
83
|
+
|
84
|
+
>>> milco_torch = TorchvisionWrapper(milco, transforms=Resize(224))
|
85
|
+
>>> print(milco_torch)
|
86
|
+
Torchvision Wrapped MILCO Dataset
|
87
|
+
---------------------------
|
88
|
+
Transforms: Resize(size=[224], interpolation=InterpolationMode.BILINEAR, antialias=True)
|
89
|
+
|
90
|
+
MILCO Dataset
|
91
|
+
-------------
|
92
|
+
Transforms: []
|
93
|
+
Image Set: train
|
94
|
+
Metadata: {'id': 'MILCO_train', 'index2label': {0: 'MILCO', 1: 'NOMBO'}, 'split': 'train'}
|
95
|
+
Path: /home/user/maite-datasets/data/milco
|
96
|
+
Size: 261
|
97
|
+
|
98
|
+
>>> print(f"type={milco_torch[0][0].__class__.__name__}, shape={milco_torch[0][0].shape}")
|
99
|
+
type=Image, shape=torch.Size([3, 224, 224])
|
100
|
+
|
101
|
+
>>> print(milco_torch[0][1].boxes[0])
|
102
|
+
tensor([16.4062, 47.4688, 28.4375, 54.0312], dtype=torch.float64)
|
103
|
+
```
|
104
|
+
|
53
105
|
## Additional Information
|
54
106
|
|
55
107
|
For more information on the MAITE protocol, check out their [documentation](https://mit-ll-ai-technology.github.io/maite/).
|
@@ -6,6 +6,7 @@ requires-python = ">=3.9"
|
|
6
6
|
dynamic = ["version"]
|
7
7
|
dependencies = [
|
8
8
|
"defusedxml>=0.7.1",
|
9
|
+
"maite>=0.7,<0.9",
|
9
10
|
"numpy>=1.24.2",
|
10
11
|
"pillow>=10.3.0",
|
11
12
|
"requests>=2.32.3",
|
@@ -38,7 +39,9 @@ tqdm = [
|
|
38
39
|
base = [
|
39
40
|
"nox[uv]>=2025.5.1",
|
40
41
|
"torch>=2.2.0",
|
41
|
-
"
|
42
|
+
"torchvision>=0.17.0",
|
43
|
+
"tqdm>=4.66",
|
44
|
+
"uv>=0.8.0",
|
42
45
|
]
|
43
46
|
lint = [
|
44
47
|
"ruff>=0.11",
|
@@ -59,12 +62,12 @@ dev = [
|
|
59
62
|
{ include-group = "lint" },
|
60
63
|
{ include-group = "test" },
|
61
64
|
{ include-group = "type" },
|
65
|
+
"ipykernel>=6.30.0",
|
62
66
|
]
|
63
67
|
|
64
68
|
[tool.uv.sources]
|
65
|
-
torch = [
|
66
|
-
|
67
|
-
]
|
69
|
+
torch = [{ index = "pytorch-cpu" }]
|
70
|
+
torchvision = [{ index = "pytorch-cpu" }]
|
68
71
|
|
69
72
|
[[tool.uv.index]]
|
70
73
|
name = "pytorch-cpu"
|
@@ -108,6 +111,14 @@ line-length = 120
|
|
108
111
|
indent-width = 4
|
109
112
|
target-version = "py39"
|
110
113
|
|
114
|
+
[tool.ruff.lint]
|
115
|
+
select = ["A", "ANN", "C4", "C90", "E", "F", "I", "NPY", "S", "SIM", "RET", "RUF100", "UP"]
|
116
|
+
ignore = ["ANN401", "NPY002"]
|
117
|
+
fixable = ["ALL"]
|
118
|
+
unfixable = []
|
119
|
+
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
120
|
+
per-file-ignores = { "!src/*" = ["ANN", "S", "RET"]}
|
121
|
+
|
111
122
|
[tool.ruff.lint.isort]
|
112
123
|
known-first-party = ["maite_datasets"]
|
113
124
|
|
@@ -1,11 +1,9 @@
|
|
1
1
|
"""Module for MAITE compliant Computer Vision datasets."""
|
2
2
|
|
3
3
|
from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
|
4
|
-
from maite_datasets._collate import
|
4
|
+
from maite_datasets._collate import collate_as_list, collate_as_numpy, collate_as_torch
|
5
|
+
from maite_datasets._reader import create_dataset_reader
|
5
6
|
from maite_datasets._validate import validate_dataset
|
6
|
-
from maite_datasets._reader._factory import create_dataset_reader
|
7
|
-
from maite_datasets._reader._coco import COCODatasetReader
|
8
|
-
from maite_datasets._reader._yolo import YOLODatasetReader
|
9
7
|
|
10
8
|
__all__ = [
|
11
9
|
"collate_as_list",
|
@@ -15,6 +13,4 @@ __all__ = [
|
|
15
13
|
"to_image_classification_dataset",
|
16
14
|
"to_object_detection_dataset",
|
17
15
|
"validate_dataset",
|
18
|
-
"COCODatasetReader",
|
19
|
-
"YOLODatasetReader",
|
20
16
|
]
|
@@ -2,23 +2,24 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
import inspect
|
6
|
+
import warnings
|
5
7
|
from abc import abstractmethod
|
8
|
+
from collections import namedtuple
|
9
|
+
from collections.abc import Iterator, Sequence
|
6
10
|
from pathlib import Path
|
7
|
-
from typing import Any,
|
11
|
+
from typing import Any, Callable, Generic, Literal, NamedTuple, TypeVar, cast
|
8
12
|
|
9
13
|
import numpy as np
|
14
|
+
from maite.protocols import DatasetMetadata, DatumMetadata
|
15
|
+
from numpy.typing import NDArray
|
16
|
+
from PIL import Image
|
10
17
|
|
11
18
|
from maite_datasets._fileio import _ensure_exists
|
12
|
-
from maite_datasets.
|
13
|
-
from maite_datasets._types import (
|
14
|
-
AnnotatedDataset,
|
15
|
-
DatasetMetadata,
|
16
|
-
DatumMetadata,
|
17
|
-
ImageClassificationDataset,
|
18
|
-
ObjectDetectionDataset,
|
19
|
-
ObjectDetectionTarget,
|
20
|
-
)
|
19
|
+
from maite_datasets.protocols import Array
|
21
20
|
|
21
|
+
_T = TypeVar("_T")
|
22
|
+
_T_co = TypeVar("_T_co", covariant=True)
|
22
23
|
_TArray = TypeVar("_TArray", bound=Array)
|
23
24
|
_TTarget = TypeVar("_TTarget")
|
24
25
|
_TRawTarget = TypeVar(
|
@@ -30,16 +31,7 @@ _TRawTarget = TypeVar(
|
|
30
31
|
_TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
|
31
32
|
|
32
33
|
|
33
|
-
|
34
|
-
_id = metadata.pop("id", index)
|
35
|
-
return DatumMetadata(id=_id, **metadata)
|
36
|
-
|
37
|
-
|
38
|
-
class DataLocation(NamedTuple):
|
39
|
-
url: str
|
40
|
-
filename: str
|
41
|
-
md5: bool
|
42
|
-
checksum: str
|
34
|
+
ObjectDetectionTarget = namedtuple("ObjectDetectionTarget", ["boxes", "labels", "scores"])
|
43
35
|
|
44
36
|
|
45
37
|
class BaseDatasetMixin(Generic[_TArray]):
|
@@ -50,8 +42,99 @@ class BaseDatasetMixin(Generic[_TArray]):
|
|
50
42
|
def _read_file(self, path: str) -> _TArray: ...
|
51
43
|
|
52
44
|
|
53
|
-
class
|
54
|
-
|
45
|
+
class Dataset(Generic[_T_co]):
|
46
|
+
"""Abstract generic base class for PyTorch style Dataset"""
|
47
|
+
|
48
|
+
def __getitem__(self, index: int) -> _T_co: ...
|
49
|
+
def __add__(self, other: Dataset[_T_co]) -> Dataset[_T_co]: ...
|
50
|
+
|
51
|
+
|
52
|
+
class BaseDataset(Dataset[tuple[_TArray, _TTarget, DatumMetadata]]):
|
53
|
+
metadata: DatasetMetadata
|
54
|
+
|
55
|
+
def __init__(
|
56
|
+
self,
|
57
|
+
transforms: Callable[[_TArray], _TArray]
|
58
|
+
| Callable[
|
59
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
60
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
61
|
+
]
|
62
|
+
| Sequence[
|
63
|
+
Callable[[_TArray], _TArray]
|
64
|
+
| Callable[
|
65
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
66
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
67
|
+
]
|
68
|
+
]
|
69
|
+
| None,
|
70
|
+
) -> None:
|
71
|
+
self.transforms: Sequence[
|
72
|
+
Callable[
|
73
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
74
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
75
|
+
]
|
76
|
+
] = []
|
77
|
+
transforms = transforms if isinstance(transforms, Sequence) else [transforms] if transforms else []
|
78
|
+
for transform in transforms:
|
79
|
+
sig = inspect.signature(transform)
|
80
|
+
if len(sig.parameters) != 1:
|
81
|
+
warnings.warn(f"Dropping unrecognized transform: {str(transform)}")
|
82
|
+
elif "tuple" in str(sig.parameters.values()):
|
83
|
+
transform = cast(
|
84
|
+
Callable[
|
85
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
86
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
87
|
+
],
|
88
|
+
transform,
|
89
|
+
)
|
90
|
+
self.transforms.append(transform)
|
91
|
+
else:
|
92
|
+
transform = cast(Callable[[_TArray], _TArray], transform)
|
93
|
+
self.transforms.append(self._wrap_transform(transform))
|
94
|
+
|
95
|
+
def _wrap_transform(
|
96
|
+
self, transform: Callable[[_TArray], _TArray]
|
97
|
+
) -> Callable[
|
98
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
99
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
100
|
+
]:
|
101
|
+
def wrapper(
|
102
|
+
datum: tuple[_TArray, _TTarget, DatumMetadata],
|
103
|
+
) -> tuple[_TArray, _TTarget, DatumMetadata]:
|
104
|
+
image, target, metadata = datum
|
105
|
+
return (transform(image), target, metadata)
|
106
|
+
|
107
|
+
return wrapper
|
108
|
+
|
109
|
+
def _transform(self, datum: tuple[_TArray, _TTarget, DatumMetadata]) -> tuple[_TArray, _TTarget, DatumMetadata]:
|
110
|
+
"""Function to transform the image prior to returning based on parameters passed in."""
|
111
|
+
for transform in self.transforms:
|
112
|
+
datum = transform(datum)
|
113
|
+
return datum
|
114
|
+
|
115
|
+
def __len__(self) -> int: ...
|
116
|
+
|
117
|
+
def __str__(self) -> str:
|
118
|
+
nt = "\n "
|
119
|
+
title = f"{self.__class__.__name__.replace('Dataset', '')} Dataset"
|
120
|
+
sep = "-" * len(title)
|
121
|
+
attrs = [
|
122
|
+
f"{' '.join(w.capitalize() for w in k.split('_'))}: {v}"
|
123
|
+
for k, v in self.__dict__.items()
|
124
|
+
if not k.startswith("_")
|
125
|
+
]
|
126
|
+
return f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
127
|
+
|
128
|
+
|
129
|
+
class DataLocation(NamedTuple):
|
130
|
+
url: str
|
131
|
+
filename: str
|
132
|
+
md5: bool
|
133
|
+
checksum: str
|
134
|
+
|
135
|
+
|
136
|
+
class BaseDownloadedDataset(
|
137
|
+
BaseDataset[_TArray, _TTarget],
|
55
138
|
Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
|
56
139
|
):
|
57
140
|
"""
|
@@ -72,13 +155,24 @@ class BaseDataset(
|
|
72
155
|
self,
|
73
156
|
root: str | Path,
|
74
157
|
image_set: Literal["train", "val", "test", "operational", "base"] = "train",
|
75
|
-
transforms:
|
158
|
+
transforms: Callable[[_TArray], _TArray]
|
159
|
+
| Callable[
|
160
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
161
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
162
|
+
]
|
163
|
+
| Sequence[
|
164
|
+
Callable[[_TArray], _TArray]
|
165
|
+
| Callable[
|
166
|
+
[tuple[_TArray, _TTarget, DatumMetadata]],
|
167
|
+
tuple[_TArray, _TTarget, DatumMetadata],
|
168
|
+
]
|
169
|
+
]
|
170
|
+
| None = None,
|
76
171
|
download: bool = False,
|
77
172
|
verbose: bool = False,
|
78
173
|
) -> None:
|
174
|
+
super().__init__(transforms)
|
79
175
|
self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
|
80
|
-
transforms = transforms if transforms is not None else []
|
81
|
-
self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
|
82
176
|
self.image_set = image_set
|
83
177
|
self._verbose = verbose
|
84
178
|
|
@@ -91,9 +185,11 @@ class BaseDataset(
|
|
91
185
|
self._label2index = {v: k for k, v in self.index2label.items()}
|
92
186
|
|
93
187
|
self.metadata: DatasetMetadata = DatasetMetadata(
|
94
|
-
|
95
|
-
|
96
|
-
|
188
|
+
**{
|
189
|
+
"id": self._unique_id(),
|
190
|
+
"index2label": self.index2label,
|
191
|
+
"split": self.image_set,
|
192
|
+
}
|
97
193
|
)
|
98
194
|
|
99
195
|
# Load the data
|
@@ -101,13 +197,6 @@ class BaseDataset(
|
|
101
197
|
self._filepaths, self._targets, self._datum_metadata = self._load_data()
|
102
198
|
self.size: int = len(self._filepaths)
|
103
199
|
|
104
|
-
def __str__(self) -> str:
|
105
|
-
nt = "\n "
|
106
|
-
title = f"{self.__class__.__name__} Dataset"
|
107
|
-
sep = "-" * len(title)
|
108
|
-
attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
|
109
|
-
return f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
110
|
-
|
111
200
|
@property
|
112
201
|
def label2index(self) -> dict[str, int]:
|
113
202
|
return self._label2index
|
@@ -148,20 +237,18 @@ class BaseDataset(
|
|
148
237
|
@abstractmethod
|
149
238
|
def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
|
150
239
|
|
151
|
-
def
|
152
|
-
|
153
|
-
|
154
|
-
image = transform(image)
|
155
|
-
return image
|
240
|
+
def _to_datum_metadata(self, index: int, metadata: dict[str, Any]) -> DatumMetadata:
|
241
|
+
_id = metadata.pop("id", index)
|
242
|
+
return DatumMetadata(id=_id, **metadata)
|
156
243
|
|
157
244
|
def __len__(self) -> int:
|
158
245
|
return self.size
|
159
246
|
|
160
247
|
|
161
248
|
class BaseICDataset(
|
162
|
-
|
249
|
+
BaseDownloadedDataset[_TArray, _TArray, list[int], int],
|
163
250
|
BaseDatasetMixin[_TArray],
|
164
|
-
|
251
|
+
BaseDataset[_TArray, _TArray],
|
165
252
|
):
|
166
253
|
"""
|
167
254
|
Base class for image classification datasets.
|
@@ -184,17 +271,16 @@ class BaseICDataset(
|
|
184
271
|
score = self._one_hot_encode(label)
|
185
272
|
# Get the image
|
186
273
|
img = self._read_file(self._filepaths[index])
|
187
|
-
img = self._transform(img)
|
188
274
|
|
189
275
|
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
190
276
|
|
191
|
-
return img, score, _to_datum_metadata(index, img_metadata)
|
277
|
+
return self._transform((img, score, self._to_datum_metadata(index, img_metadata)))
|
192
278
|
|
193
279
|
|
194
280
|
class BaseODDataset(
|
195
|
-
|
281
|
+
BaseDownloadedDataset[_TArray, ObjectDetectionTarget, _TRawTarget, _TAnnotation],
|
196
282
|
BaseDatasetMixin[_TArray],
|
197
|
-
|
283
|
+
BaseDataset[_TArray, ObjectDetectionTarget],
|
198
284
|
):
|
199
285
|
"""
|
200
286
|
Base class for object detection datasets.
|
@@ -202,7 +288,7 @@ class BaseODDataset(
|
|
202
288
|
|
203
289
|
_bboxes_per_size: bool = False
|
204
290
|
|
205
|
-
def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget
|
291
|
+
def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget, DatumMetadata]:
|
206
292
|
"""
|
207
293
|
Args
|
208
294
|
----
|
@@ -211,7 +297,7 @@ class BaseODDataset(
|
|
211
297
|
|
212
298
|
Returns
|
213
299
|
-------
|
214
|
-
tuple[TArray, ObjectDetectionTarget
|
300
|
+
tuple[TArray, ObjectDetectionTarget, DatumMetadata]
|
215
301
|
Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
|
216
302
|
"""
|
217
303
|
# Grab the bounding boxes and labels from the annotations
|
@@ -220,17 +306,49 @@ class BaseODDataset(
|
|
220
306
|
# Get the image
|
221
307
|
img = self._read_file(self._filepaths[index])
|
222
308
|
img_size = img.shape
|
223
|
-
img = self._transform(img)
|
224
309
|
# Adjust labels if necessary
|
225
310
|
if self._bboxes_per_size and boxes:
|
226
|
-
boxes = boxes * np.
|
311
|
+
boxes = boxes * np.asarray([[img_size[1], img_size[2], img_size[1], img_size[2]]])
|
227
312
|
# Create the Object Detection Target
|
228
313
|
target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
|
229
314
|
|
230
315
|
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
231
316
|
img_metadata = img_metadata | additional_metadata
|
232
317
|
|
233
|
-
return img, target, _to_datum_metadata(index, img_metadata)
|
318
|
+
return self._transform((img, target, self._to_datum_metadata(index, img_metadata)))
|
234
319
|
|
235
320
|
@abstractmethod
|
236
321
|
def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
322
|
+
|
323
|
+
|
324
|
+
NumpyArray = NDArray[np.floating[Any]] | NDArray[np.integer[Any]]
|
325
|
+
|
326
|
+
|
327
|
+
class BaseDatasetNumpyMixin(BaseDatasetMixin[NumpyArray]):
|
328
|
+
def _as_array(self, raw: list[Any]) -> NumpyArray:
|
329
|
+
return np.asarray(raw)
|
330
|
+
|
331
|
+
def _one_hot_encode(self, value: int | list[int]) -> NumpyArray:
|
332
|
+
if isinstance(value, int):
|
333
|
+
encoded = np.zeros(len(self.index2label))
|
334
|
+
encoded[value] = 1
|
335
|
+
else:
|
336
|
+
encoded = np.zeros((len(value), len(self.index2label)))
|
337
|
+
encoded[np.arange(len(value)), value] = 1
|
338
|
+
return encoded
|
339
|
+
|
340
|
+
def _read_file(self, path: str) -> NumpyArray:
|
341
|
+
return np.array(Image.open(path)).transpose(2, 0, 1)
|
342
|
+
|
343
|
+
|
344
|
+
NumpyImageTransform = Callable[[NumpyArray], NumpyArray]
|
345
|
+
NumpyImageClassificationDatumTransform = Callable[
|
346
|
+
[tuple[NumpyArray, NumpyArray, DatumMetadata]],
|
347
|
+
tuple[NumpyArray, NumpyArray, DatumMetadata],
|
348
|
+
]
|
349
|
+
NumpyObjectDetectionDatumTransform = Callable[
|
350
|
+
[tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata]],
|
351
|
+
tuple[NumpyArray, ObjectDetectionTarget, DatumMetadata],
|
352
|
+
]
|
353
|
+
NumpyImageClassificationTransform = NumpyImageTransform | NumpyImageClassificationDatumTransform
|
354
|
+
NumpyObjectDetectionTransform = NumpyImageTransform | NumpyObjectDetectionDatumTransform
|