maite-datasets 0.0.2 → 0.0.3 (py3-none-any.whl)

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
maite_datasets/__init__.py
@@ -1,5 +1,14 @@
 """Module for MAITE compliant Computer Vision datasets."""
 
 from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
+from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
+from maite_datasets._validate import validate_dataset
 
-__all__ = ["to_image_classification_dataset", "to_object_detection_dataset"]
+__all__ = [
+    "collate_as_list",
+    "collate_as_numpy",
+    "collate_as_torch",
+    "to_image_classification_dataset",
+    "to_object_detection_dataset",
+    "validate_dataset",
+]
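
The practical effect of this hunk is that the new collate helpers and the validator are exported from the package root alongside the existing builders. A minimal sketch, assuming maite-datasets 0.0.3 is installed:

# All six names in 0.0.3's __all__, imported from the package root.
from maite_datasets import (
    collate_as_list,
    collate_as_numpy,
    collate_as_torch,
    to_image_classification_dataset,
    to_object_detection_dataset,
    validate_dataset,
)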
maite_datasets/_collate.py (new file)
@@ -0,0 +1,112 @@
+"""
+Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
+"""
+
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Iterable, Sequence
+from typing import Any, TypeVar, TYPE_CHECKING
+
+import numpy as np
+from numpy.typing import NDArray
+
+if TYPE_CHECKING:
+    import torch
+
+from maite_datasets._protocols import ArrayLike
+
+T_in = TypeVar("T_in")
+T_tgt = TypeVar("T_tgt")
+T_md = TypeVar("T_md")
+
+
+def collate_as_list(
+    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
+) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns three lists: the input batch, the target batch,
+    and the metadata batch. This is useful for loading data with torch.utils.data.DataLoader
+    when the target and metadata are not tensors.
+
+    Parameters
+    ----------
+    batch_data_as_singles : Iterable[tuple[T_in, T_tgt, T_md]]
+        An iterable of (input, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
+        A tuple of three lists: the input batch, the target batch, and the metadata batch.
+    """
+    input_batch: list[T_in] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(input_datum)
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return input_batch, target_batch, metadata_batch
+
+
+def collate_as_numpy(
+    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
+) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns the batched input as a single NumPy array with two
+    lists: the target batch and the metadata batch. The inputs must be homogeneous arrays.
+
+    Parameters
+    ----------
+    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
+        An iterable of (ArrayLike, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
+        A tuple of the input batch as a NumPy array, the target batch list, and the metadata batch list.
+    """
+    input_batch: list[NDArray[Any]] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(np.asarray(input_datum))
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return np.stack(input_batch) if input_batch else np.array([]), target_batch, metadata_batch
+
+
+def collate_as_torch(
+    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
+) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns the batched input as a single torch Tensor with two
+    lists: the target batch and the metadata batch. The inputs must be homogeneous arrays.
+
+    Parameters
+    ----------
+    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
+        An iterable of (ArrayLike, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
+        A tuple of the input batch as a torch Tensor, the target batch list, and the metadata batch list.
+    """
+    try:
+        import torch
+    except ImportError:
+        raise ImportError("PyTorch is not installed. Please install it to use this function.")
+
+    input_batch: list[torch.Tensor] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(torch.as_tensor(input_datum))
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return torch.stack(input_batch) if input_batch else torch.tensor([]), target_batch, metadata_batch
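
These helpers are designed to be passed as the collate_fn of a torch.utils.data.DataLoader. A usage sketch under stated assumptions: ToyDataset below is hypothetical (not part of the package), PyTorch and NumPy are installed, and each datum is a MAITE-style (image, target, metadata) tuple:

import numpy as np
from torch.utils.data import DataLoader

from maite_datasets import collate_as_numpy


class ToyDataset:
    """Hypothetical map-style dataset yielding (image, target, metadata) tuples."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int):
        image = np.zeros((3, 32, 32), dtype=np.float32)  # CHW image
        target = np.array([0.0, 1.0], dtype=np.float32)  # one-hot label
        metadata = {"id": index}
        return image, target, metadata


# collate_as_numpy stacks the images into one array and keeps the
# targets and metadata as plain lists.
loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_as_numpy)
images, targets, metadata = next(iter(loader))
print(images.shape)  # (4, 3, 32, 32)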
maite_datasets/_validate.py (new file)
@@ -0,0 +1,169 @@
+from __future__ import annotations
+
+__all__ = []
+
+import numpy as np
+from collections.abc import Sequence, Sized
+from typing import Any, Literal
+
+from maite_datasets._protocols import Array, ObjectDetectionTarget
+
+
+class ValidationMessages:
+    DATASET_SIZED = "Dataset must be sized."
+    DATASET_INDEXABLE = "Dataset must be indexable."
+    DATASET_NONEMPTY = "Dataset must be non-empty."
+    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
+    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
+    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
+    DATUM_TYPE = "Dataset datum must be a tuple."
+    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
+    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
+    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
+    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
+    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
+    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must have 'boxes', 'labels' and 'scores'."
+    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
+    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xyxy format."
+    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
+    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
+    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
+    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
+
+
+def _validate_dataset_type(dataset: Any) -> list[str]:
+    issues = []
+    is_sized = isinstance(dataset, Sized)
+    is_indexable = hasattr(dataset, "__getitem__")
+    if not is_sized:
+        issues.append(ValidationMessages.DATASET_SIZED)
+    if not is_indexable:
+        issues.append(ValidationMessages.DATASET_INDEXABLE)
+    if is_sized and len(dataset) == 0:
+        issues.append(ValidationMessages.DATASET_NONEMPTY)
+    return issues
+
+
+def _validate_dataset_metadata(dataset: Any) -> list[str]:
+    issues = []
+    if not hasattr(dataset, "metadata"):
+        issues.append(ValidationMessages.DATASET_METADATA)
+    metadata = getattr(dataset, "metadata", None)
+    if not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATASET_METADATA_TYPE)
+    if not isinstance(metadata, dict) or "id" not in metadata:
+        issues.append(ValidationMessages.DATASET_METADATA_FORMAT)
+    return issues
+
+
+def _validate_datum_type(datum: Any) -> list[str]:
+    issues = []
+    if not isinstance(datum, tuple):
+        issues.append(ValidationMessages.DATUM_TYPE)
+    if datum is None or isinstance(datum, Sized) and len(datum) != 3:
+        issues.append(ValidationMessages.DATUM_FORMAT)
+    return issues
+
+
+def _validate_datum_image(image: Any) -> list[str]:
+    issues = []
+    if not isinstance(image, Array) or len(image.shape) != 3:
+        issues.append(ValidationMessages.DATUM_IMAGE_TYPE)
+    if (
+        not isinstance(image, Array)
+        or len(image.shape) == 3
+        and (image.shape[0] > image.shape[1] or image.shape[0] > image.shape[2])
+    ):
+        issues.append(ValidationMessages.DATUM_IMAGE_FORMAT)
+    return issues
+
+
+def _validate_datum_target_ic(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, Array) or len(target.shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)
+    if target is None or sum(target) > 1 + 1e-6 or sum(target) < 1 - 1e-6:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
+    return issues
+
+
+def _validate_datum_target_od(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, ObjectDetectionTarget):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_TYPE)
+    od_target: ObjectDetectionTarget | None = target if isinstance(target, ObjectDetectionTarget) else None
+    if od_target is None or len(np.asarray(od_target.labels).shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
+    if (
+        od_target is None
+        or len(np.asarray(od_target.boxes).shape) != 2
+        or (len(np.asarray(od_target.boxes).shape) == 2 and np.asarray(od_target.boxes).shape[1] != 4)
+    ):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
+    if od_target is None or len(np.asarray(od_target.scores).shape) not in (1, 2):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
+    return issues
+
+
+def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
+    if isinstance(target, Array):
+        return "ic"
+    if isinstance(target, ObjectDetectionTarget):
+        return "od"
+    return "auto"
+
+
+def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
+    issues = []
+    target_type = _detect_target_type(target) if target_type == "auto" else target_type
+    if target_type == "ic":
+        issues.extend(_validate_datum_target_ic(target))
+    elif target_type == "od":
+        issues.extend(_validate_datum_target_od(target))
+    else:
+        issues.append(ValidationMessages.DATUM_TARGET_TYPE)
+    return issues
+
+
+def _validate_datum_metadata(metadata: Any) -> list[str]:
+    issues = []
+    if metadata is None or not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATUM_METADATA_TYPE)
+    if metadata is None or isinstance(metadata, dict) and "id" not in metadata:
+        issues.append(ValidationMessages.DATUM_METADATA_FORMAT)
+    return issues
+
+
+def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
+    """
+    Validate a dataset for compliance with the MAITE protocol.
+
+    Parameters
+    ----------
+    dataset : Any
+        Dataset to validate.
+    dataset_type : "ic", "od", or "auto", default "auto"
+        Dataset type, if known.
+
+    Raises
+    ------
+    ValueError
+        Raised if the dataset is invalid, with a list of the validation issues found.
+    """
+    issues = []
+    issues.extend(_validate_dataset_type(dataset))
+    datum = None if issues else dataset[0]  # type: ignore
+    issues.extend(_validate_dataset_metadata(dataset))
+    issues.extend(_validate_datum_type(datum))
+
+    is_seq = isinstance(datum, Sequence)
+    datum_len = len(datum) if is_seq else 0
+    image = datum[0] if is_seq and datum_len > 0 else None
+    target = datum[1] if is_seq and datum_len > 1 else None
+    metadata = datum[2] if is_seq and datum_len > 2 else None
+    issues.extend(_validate_datum_image(image))
+    issues.extend(_validate_datum_target(target, dataset_type))
+    issues.extend(_validate_datum_metadata(metadata))
+
+    if issues:
+        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(issues))
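
As a quick illustration of how validate_dataset reports problems, the sketch below feeds it a deliberately malformed dataset. BadDataset is hypothetical and not part of the package; it only assumes NumPy is installed:

import numpy as np

from maite_datasets import validate_dataset


class BadDataset:
    """Hypothetical dataset: no 'metadata' attribute, 2-element datums."""

    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int):
        # Missing the per-datum metadata dict, so the datum has 2 elements.
        return np.zeros((3, 16, 16), dtype=np.float32), np.array([1.0])


try:
    validate_dataset(BadDataset())
except ValueError as err:
    # Prints "Dataset validation issues found:" followed by one line per
    # issue, e.g. the missing dataset metadata, the 2-element datum, and
    # the missing datum metadata.
    print(err)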
maite_datasets-0.0.3.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: maite-datasets
-Version: 0.0.2
+Version: 0.0.3
 Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
 Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
 License-Expression: MIT
maite_datasets-0.0.3.dist-info/RECORD
@@ -1,9 +1,11 @@
-maite_datasets/__init__.py,sha256=aM16hWPYR5WF0nx2AqTYHbGmibNTBCrYilcDKUs_yPo,235
+maite_datasets/__init__.py,sha256=81LNxx03O7FzWNZQbIrSovDrdpO_x74WkLPKBJy91gU,483
 maite_datasets/_base.py,sha256=BiWB_xvL4AtV0jxVjzpcZHuRTb52dTD0CQtu08DzoXA,8195
 maite_datasets/_builder.py,sha256=URhRCedvuqsy88N4lzQrwI-uL1kS1_kavP9fS402sPw,10036
+maite_datasets/_collate.py,sha256=-XuKeeMmOnSB0RgQbz8BjsoqQar9Tsf_qALZxijQ498,4063
 maite_datasets/_fileio.py,sha256=7S-hF3xU60AdcsPsfYR7rjbeGZUlv3JjGEZhGJOxGYU,5622
 maite_datasets/_protocols.py,sha256=uwnI2P-zJnpEHJ0eOJ7dO_7KehwHEtEqR4pYcJiEXNk,5312
 maite_datasets/_types.py,sha256=S5DMyiUrkUjV9uM0ysKqxVoi7z5P7B3EPiLI4Fyq9Jc,1147
+maite_datasets/_validate.py,sha256=sP-5lYXkmkiTadJcy_LtEMiZ0m82xR0yELoxWORrZDQ,6904
 maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 maite_datasets/_mixin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 maite_datasets/_mixin/_numpy.py,sha256=GEuRyeprH-STh-_zktAp0Tg6NNyMdh1ThyhjW558NOo,860
@@ -18,7 +20,7 @@ maite_datasets/object_detection/_milco.py,sha256=KEU4JFvCxfyMAb4RFMnxTMk_MggdEAV
 maite_datasets/object_detection/_seadrone.py,sha256=w_pSojLzgwdKrUSxaz8r7dPJVKGND6JSYl0S_BKOLH0,271282
 maite_datasets/object_detection/_voc.py,sha256=VuokKaOzI1wSfgG5DC7ufMbRDlG-b6Se3hg4eQzNQbE,19731
 maite_datasets/object_detection/_voc_torch.py,sha256=bjeawnNit7Llcf_cZY_9lcJYoUoAU-Wen6MMT-7QX3k,2917
-maite_datasets-0.0.2.dist-info/METADATA,sha256=O3RGToBWSFhEyi_iAdnc8pqYSVzNRXo_XjIQBOEIEWA,3747
-maite_datasets-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-maite_datasets-0.0.2.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
-maite_datasets-0.0.2.dist-info/RECORD,,
+maite_datasets-0.0.3.dist-info/METADATA,sha256=hoOvbKjGriS10siM8HsRvepA3nfi-QgUcrpjGsHr1lM,3747
+maite_datasets-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+maite_datasets-0.0.3.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
+maite_datasets-0.0.3.dist-info/RECORD,,