maite-datasets 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- maite_datasets/__init__.py +10 -1
- maite_datasets/_collate.py +112 -0
- maite_datasets/_validate.py +169 -0
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/METADATA +1 -1
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/RECORD +7 -5
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/WHEEL +0 -0
- {maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/licenses/LICENSE +0 -0
maite_datasets/__init__.py
CHANGED
@@ -1,5 +1,14 @@
 """Module for MAITE compliant Computer Vision datasets."""
 
 from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
+from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
+from maite_datasets._validate import validate_dataset
 
-__all__ = [
+__all__ = [
+    "collate_as_list",
+    "collate_as_numpy",
+    "collate_as_torch",
+    "to_image_classification_dataset",
+    "to_object_detection_dataset",
+    "validate_dataset",
+]
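In practical terms, 0.0.3 widens the public API from two names to six. A minimal sketch of the imports this change enables; the names below are exactly those in the new __all__:

from maite_datasets import (
    collate_as_list,
    collate_as_numpy,
    collate_as_torch,
    to_image_classification_dataset,
    to_object_detection_dataset,
    validate_dataset,
)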
maite_datasets/_collate.py
ADDED
@@ -0,0 +1,112 @@
+"""
+Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
+"""
+
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Iterable, Sequence
+from typing import Any, TypeVar, TYPE_CHECKING
+
+import numpy as np
+from numpy.typing import NDArray
+
+if TYPE_CHECKING:
+    import torch
+
+from maite_datasets._protocols import ArrayLike
+
+T_in = TypeVar("T_in")
+T_tgt = TypeVar("T_tgt")
+T_md = TypeVar("T_md")
+
+
+def collate_as_list(
+    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
+) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns three lists: the input batch, the target batch,
+    and the metadata batch. This is useful for loading data with torch.utils.data.DataLoader
+    when the target and metadata are not tensors.
+
+    Parameters
+    ----------
+    batch_data_as_singles : An iterable of (input, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
+        A tuple of three lists: the input batch, the target batch, and the metadata batch.
+    """
+    input_batch: list[T_in] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(input_datum)
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return input_batch, target_batch, metadata_batch
+
+
+def collate_as_numpy(
+    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
+) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns the batched input as a single NumPy array with two
+    lists: the target batch, and the metadata batch. The inputs must be homogeneous arrays.
+
+    Parameters
+    ----------
+    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
+        A tuple of a NumPy array and two lists: the input batch, the target batch, and the metadata batch.
+    """
+    input_batch: list[NDArray[Any]] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(np.asarray(input_datum))
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return np.stack(input_batch) if input_batch else np.array([]), target_batch, metadata_batch
+
+
+def collate_as_torch(
+    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
+) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns the batched input as a single torch Tensor with two
+    lists: the target batch, and the metadata batch. The inputs must be homogeneous arrays.
+
+    Parameters
+    ----------
+    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
+        A tuple of a torch Tensor and two lists: the input batch, the target batch, and the metadata batch.
+    """
+    try:
+        import torch
+    except ImportError:
+        raise ImportError("PyTorch is not installed. Please install it to use this function.")
+
+    input_batch: list[torch.Tensor] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(torch.as_tensor(input_datum))
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return torch.stack(input_batch) if input_batch else torch.tensor([]), target_batch, metadata_batch
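To illustrate how these collate functions are meant to be wired up, here is a minimal, hypothetical sketch pairing collate_as_numpy with torch.utils.data.DataLoader. The ToyDataset class is an assumption made for this example and is not part of the package; it only follows the (input, target, metadata) datum convention the module documents:

import numpy as np
from torch.utils.data import DataLoader

from maite_datasets import collate_as_numpy


class ToyDataset:  # hypothetical stand-in for a MAITE-style dataset
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> tuple:
        image = np.zeros((3, 16, 16), dtype=np.float32)  # CHW image input
        target = np.array([0.0, 1.0])  # one-hot classification target
        metadata = {"id": idx}  # per-datum metadata
        return image, target, metadata


# The collate_fn receives the batch as a list of (input, target, metadata) tuples.
loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_as_numpy)
images, targets, metadata = next(iter(loader))
print(images.shape)  # (4, 3, 16, 16): inputs stacked into a single NumPy array

collate_as_torch would be used the same way and returns the inputs as one stacked torch.Tensor instead, while collate_as_list leaves all three batch components as plain lists.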
maite_datasets/_validate.py
ADDED
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+__all__ = []
+
+import numpy as np
+from collections.abc import Sequence, Sized
+from typing import Any, Literal
+
+from maite_datasets._protocols import Array, ObjectDetectionTarget
+
+
+class ValidationMessages:
+    DATASET_SIZED = "Dataset must be sized."
+    DATASET_INDEXABLE = "Dataset must be indexable."
+    DATASET_NONEMPTY = "Dataset must be non-empty."
+    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
+    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
+    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
+    DATUM_TYPE = "Dataset datum must be a tuple."
+    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
+    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
+    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
+    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
+    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
+    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must have 'boxes', 'labels' and 'scores'."
+    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
+    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xxyy format."
+    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
+    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
+    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
+    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
+
+
+def _validate_dataset_type(dataset: Any) -> list[str]:
+    issues = []
+    is_sized = isinstance(dataset, Sized)
+    is_indexable = hasattr(dataset, "__getitem__")
+    if not is_sized:
+        issues.append(ValidationMessages.DATASET_SIZED)
+    if not is_indexable:
+        issues.append(ValidationMessages.DATASET_INDEXABLE)
+    if is_sized and len(dataset) == 0:
+        issues.append(ValidationMessages.DATASET_NONEMPTY)
+    return issues
+
+
+def _validate_dataset_metadata(dataset: Any) -> list[str]:
+    issues = []
+    if not hasattr(dataset, "metadata"):
+        issues.append(ValidationMessages.DATASET_METADATA)
+    metadata = getattr(dataset, "metadata", None)
+    if not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATASET_METADATA_TYPE)
+    if not isinstance(metadata, dict) or "id" not in metadata:
+        issues.append(ValidationMessages.DATASET_METADATA_FORMAT)
+    return issues
+
+
+def _validate_datum_type(datum: Any) -> list[str]:
+    issues = []
+    if not isinstance(datum, tuple):
+        issues.append(ValidationMessages.DATUM_TYPE)
+    if datum is None or isinstance(datum, Sized) and len(datum) != 3:
+        issues.append(ValidationMessages.DATUM_FORMAT)
+    return issues
+
+
+def _validate_datum_image(image: Any) -> list[str]:
+    issues = []
+    if not isinstance(image, Array) or len(image.shape) != 3:
+        issues.append(ValidationMessages.DATUM_IMAGE_TYPE)
+    if (
+        not isinstance(image, Array)
+        or len(image.shape) == 3
+        and (image.shape[0] > image.shape[1] or image.shape[0] > image.shape[2])
+    ):
+        issues.append(ValidationMessages.DATUM_IMAGE_FORMAT)
+    return issues
+
+
+def _validate_datum_target_ic(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, Array) or len(target.shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)
+    if target is None or sum(target) > 1 + 1e-6 or sum(target) < 1 - 1e-6:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
+    return issues
+
+
+def _validate_datum_target_od(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, ObjectDetectionTarget):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_TYPE)
+    od_target: ObjectDetectionTarget | None = target if isinstance(target, ObjectDetectionTarget) else None
+    if od_target is None or len(np.asarray(od_target.labels).shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
+    if (
+        od_target is None
+        or len(np.asarray(od_target.boxes).shape) != 2
+        or (len(np.asarray(od_target.boxes).shape) == 2 and np.asarray(od_target.boxes).shape[1] != 4)
+    ):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
+    if od_target is None or len(np.asarray(od_target.scores).shape) not in (1, 2):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
+    return issues
+
+
+def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
+    if isinstance(target, Array):
+        return "ic"
+    if isinstance(target, ObjectDetectionTarget):
+        return "od"
+    return "auto"
+
+
+def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
+    issues = []
+    target_type = _detect_target_type(target) if target_type == "auto" else target_type
+    if target_type == "ic":
+        issues.extend(_validate_datum_target_ic(target))
+    elif target_type == "od":
+        issues.extend(_validate_datum_target_od(target))
+    else:
+        issues.append(ValidationMessages.DATUM_TARGET_TYPE)
+    return issues
+
+
+def _validate_datum_metadata(metadata: Any) -> list[str]:
+    issues = []
+    if metadata is None or not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATUM_METADATA_TYPE)
+    if metadata is None or isinstance(metadata, dict) and "id" not in metadata:
+        issues.append(ValidationMessages.DATUM_METADATA_FORMAT)
+    return issues
+
+
+def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
+    """
+    Validate a dataset for compliance with MAITE protocol.
+
+    Parameters
+    ----------
+    dataset: Any
+        Dataset to validate.
+    dataset_type: "ic", "od", or "auto", default "auto"
+        Dataset type, if known.
+
+    Raises
+    ------
+    ValueError
+        Raises exception if dataset is invalid with a list of validation issues.
+    """
+    issues = []
+    issues.extend(_validate_dataset_type(dataset))
+    datum = None if issues else dataset[0]  # type: ignore
+    issues.extend(_validate_dataset_metadata(dataset))
+    issues.extend(_validate_datum_type(datum))
+
+    is_seq = isinstance(datum, Sequence)
+    datum_len = len(datum) if is_seq else 0
+    image = datum[0] if is_seq and datum_len > 0 else None
+    target = datum[1] if is_seq and datum_len > 1 else None
+    metadata = datum[2] if is_seq and datum_len > 2 else None
+    issues.extend(_validate_datum_image(image))
+    issues.extend(_validate_datum_target(target, dataset_type))
+    issues.extend(_validate_datum_metadata(metadata))
+
+    if issues:
+        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(issues))
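A short, hypothetical sketch of the new validator in use; BadDataset is invented for this example and is deliberately non-compliant (no metadata attribute, a 2-tuple datum) so that several of the messages above surface in the ValueError:

from maite_datasets import validate_dataset


class BadDataset:  # hypothetical: sized and indexable, but otherwise non-compliant
    def __len__(self) -> int:
        return 1

    def __getitem__(self, idx: int) -> tuple:
        return ("not-an-image", None)  # should be a 3-tuple: (image, target, metadata)


try:
    validate_dataset(BadDataset(), dataset_type="ic")
except ValueError as err:
    print(err)  # "Dataset validation issues found:" followed by one line per issue

Note that only the first datum (dataset[0]) is inspected, so validation is a cheap structural check rather than a full scan of the dataset.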
{maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: maite-datasets
-Version: 0.0.2
+Version: 0.0.3
 Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
 Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
 License-Expression: MIT
{maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/RECORD
CHANGED
@@ -1,9 +1,11 @@
-maite_datasets/__init__.py,sha256=
+maite_datasets/__init__.py,sha256=81LNxx03O7FzWNZQbIrSovDrdpO_x74WkLPKBJy91gU,483
 maite_datasets/_base.py,sha256=BiWB_xvL4AtV0jxVjzpcZHuRTb52dTD0CQtu08DzoXA,8195
 maite_datasets/_builder.py,sha256=URhRCedvuqsy88N4lzQrwI-uL1kS1_kavP9fS402sPw,10036
+maite_datasets/_collate.py,sha256=-XuKeeMmOnSB0RgQbz8BjsoqQar9Tsf_qALZxijQ498,4063
 maite_datasets/_fileio.py,sha256=7S-hF3xU60AdcsPsfYR7rjbeGZUlv3JjGEZhGJOxGYU,5622
 maite_datasets/_protocols.py,sha256=uwnI2P-zJnpEHJ0eOJ7dO_7KehwHEtEqR4pYcJiEXNk,5312
 maite_datasets/_types.py,sha256=S5DMyiUrkUjV9uM0ysKqxVoi7z5P7B3EPiLI4Fyq9Jc,1147
+maite_datasets/_validate.py,sha256=sP-5lYXkmkiTadJcy_LtEMiZ0m82xR0yELoxWORrZDQ,6904
 maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 maite_datasets/_mixin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 maite_datasets/_mixin/_numpy.py,sha256=GEuRyeprH-STh-_zktAp0Tg6NNyMdh1ThyhjW558NOo,860
@@ -18,7 +20,7 @@ maite_datasets/object_detection/_milco.py,sha256=KEU4JFvCxfyMAb4RFMnxTMk_MggdEAV
 maite_datasets/object_detection/_seadrone.py,sha256=w_pSojLzgwdKrUSxaz8r7dPJVKGND6JSYl0S_BKOLH0,271282
 maite_datasets/object_detection/_voc.py,sha256=VuokKaOzI1wSfgG5DC7ufMbRDlG-b6Se3hg4eQzNQbE,19731
 maite_datasets/object_detection/_voc_torch.py,sha256=bjeawnNit7Llcf_cZY_9lcJYoUoAU-Wen6MMT-7QX3k,2917
-maite_datasets-0.0.
-maite_datasets-0.0.
-maite_datasets-0.0.
-maite_datasets-0.0.
+maite_datasets-0.0.3.dist-info/METADATA,sha256=hoOvbKjGriS10siM8HsRvepA3nfi-QgUcrpjGsHr1lM,3747
+maite_datasets-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+maite_datasets-0.0.3.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
+maite_datasets-0.0.3.dist-info/RECORD,,
{maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/WHEEL
File without changes
{maite_datasets-0.0.2.dist-info → maite_datasets-0.0.3.dist-info}/licenses/LICENSE
File without changes