maite-datasets 0.0.1__tar.gz → 0.0.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/.gitignore +1 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/PKG-INFO +1 -1
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/pyproject.toml +14 -0
- maite_datasets-0.0.3/src/maite_datasets/__init__.py +14 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_base.py +8 -26
- maite_datasets-0.0.3/src/maite_datasets/_builder.py +275 -0
- maite_datasets-0.0.3/src/maite_datasets/_collate.py +112 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_fileio.py +9 -31
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_protocols.py +1 -3
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_types.py +2 -6
- maite_datasets-0.0.3/src/maite_datasets/_validate.py +169 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/_cifar10.py +5 -15
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/_mnist.py +6 -18
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/_ships.py +1 -3
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_antiuav.py +6 -18
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_milco.py +3 -9
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_seadrone.py +10 -30
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_voc.py +12 -36
- maite_datasets-0.0.1/src/maite_datasets/__init__.py +0 -1
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/LICENSE +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/README.md +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_mixin/__init__.py +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_mixin/_numpy.py +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/_mixin/_torch.py +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/__init__.py +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/__init__.py +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_voc_torch.py +0 -0
- {maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/py.typed +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: maite-datasets
|
3
|
-
Version: 0.0.1
|
3
|
+
Version: 0.0.3
|
4
4
|
Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
|
5
5
|
Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
|
6
6
|
License-Expression: MIT
|
@@ -97,6 +97,20 @@ fail_under = 90
|
|
97
97
|
[tool.codespell]
|
98
98
|
skip = './*env*,./output,uv.lock'
|
99
99
|
|
100
|
+
[tool.ruff]
|
101
|
+
exclude = [
|
102
|
+
".github",
|
103
|
+
".vscode",
|
104
|
+
"*env*",
|
105
|
+
".nox",
|
106
|
+
]
|
107
|
+
line-length = 120
|
108
|
+
indent-width = 4
|
109
|
+
target-version = "py39"
|
110
|
+
|
111
|
+
[tool.ruff.lint.isort]
|
112
|
+
known-first-party = ["maite_datasets"]
|
113
|
+
|
100
114
|
[tool.hatch.build.targets.wheel]
|
101
115
|
packages = ["src/maite_datasets"]
|
102
116
|
|
@@ -0,0 +1,14 @@
|
|
1
|
+
"""Module for MAITE compliant Computer Vision datasets."""
|
2
|
+
|
3
|
+
from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
|
4
|
+
from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
|
5
|
+
from maite_datasets._validate import validate_dataset
|
6
|
+
|
7
|
+
__all__ = [
|
8
|
+
"collate_as_list",
|
9
|
+
"collate_as_numpy",
|
10
|
+
"collate_as_torch",
|
11
|
+
"to_image_classification_dataset",
|
12
|
+
"to_object_detection_dataset",
|
13
|
+
"validate_dataset",
|
14
|
+
]
|
@@ -76,13 +76,9 @@ class BaseDataset(
|
|
76
76
|
download: bool = False,
|
77
77
|
verbose: bool = False,
|
78
78
|
) -> None:
|
79
|
-
self._root: Path = (
|
80
|
-
root.absolute() if isinstance(root, Path) else Path(root).absolute()
|
81
|
-
)
|
79
|
+
self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
|
82
80
|
transforms = transforms if transforms is not None else []
|
83
|
-
self.transforms: Sequence[Transform[_TArray]] = (
|
84
|
-
transforms if isinstance(transforms, Sequence) else [transforms]
|
85
|
-
)
|
81
|
+
self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
|
86
82
|
self.image_set = image_set
|
87
83
|
self._verbose = verbose
|
88
84
|
|
@@ -109,11 +105,7 @@ class BaseDataset(
|
|
109
105
|
nt = "\n "
|
110
106
|
title = f"{self.__class__.__name__} Dataset"
|
111
107
|
sep = "-" * len(title)
|
112
|
-
attrs = [
|
113
|
-
f"{k.capitalize()}: {v}"
|
114
|
-
for k, v in self.__dict__.items()
|
115
|
-
if not k.startswith("_")
|
116
|
-
]
|
108
|
+
attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
|
117
109
|
return f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
118
110
|
|
119
111
|
@property
|
@@ -149,9 +141,7 @@ class BaseDataset(
|
|
149
141
|
if self._verbose:
|
150
142
|
print("No download needed, loaded data successfully.")
|
151
143
|
except FileNotFoundError:
|
152
|
-
_ensure_exists(
|
153
|
-
*self._resource, self.path, self._root, self._download, self._verbose
|
154
|
-
)
|
144
|
+
_ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
|
155
145
|
result = self._load_data_inner()
|
156
146
|
return result
|
157
147
|
|
@@ -212,9 +202,7 @@ class BaseODDataset(
|
|
212
202
|
|
213
203
|
_bboxes_per_size: bool = False
|
214
204
|
|
215
|
-
def __getitem__(
|
216
|
-
self, index: int
|
217
|
-
) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
|
205
|
+
def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
|
218
206
|
"""
|
219
207
|
Args
|
220
208
|
----
|
@@ -235,13 +223,9 @@ class BaseODDataset(
|
|
235
223
|
img = self._transform(img)
|
236
224
|
# Adjust labels if necessary
|
237
225
|
if self._bboxes_per_size and boxes:
|
238
|
-
boxes = boxes * np.array(
|
239
|
-
[[img_size[1], img_size[2], img_size[1], img_size[2]]]
|
240
|
-
)
|
226
|
+
boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
|
241
227
|
# Create the Object Detection Target
|
242
|
-
target = ObjectDetectionTarget(
|
243
|
-
self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels)
|
244
|
-
)
|
228
|
+
target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
|
245
229
|
|
246
230
|
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
247
231
|
img_metadata = img_metadata | additional_metadata
|
@@ -249,6 +233,4 @@ class BaseODDataset(
|
|
249
233
|
return img, target, _to_datum_metadata(index, img_metadata)
|
250
234
|
|
251
235
|
@abstractmethod
|
252
|
-
def _read_annotations(
|
253
|
-
self, annotation: _TAnnotation
|
254
|
-
) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
236
|
+
def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
@@ -0,0 +1,275 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
__all__ = []
|
6
|
+
|
7
|
+
from typing import (
|
8
|
+
Any,
|
9
|
+
Generic,
|
10
|
+
Iterable,
|
11
|
+
Literal,
|
12
|
+
Sequence,
|
13
|
+
SupportsFloat,
|
14
|
+
SupportsInt,
|
15
|
+
TypeVar,
|
16
|
+
cast,
|
17
|
+
)
|
18
|
+
|
19
|
+
from maite_datasets._protocols import (
|
20
|
+
Array,
|
21
|
+
ArrayLike,
|
22
|
+
DatasetMetadata,
|
23
|
+
ImageClassificationDataset,
|
24
|
+
ObjectDetectionDataset,
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
def _ensure_id(index: int, metadata: dict[str, Any]) -> dict[str, Any]:
|
29
|
+
return {"id": index, **metadata} if "id" not in metadata else metadata
|
30
|
+
|
31
|
+
|
32
|
+
def _first_nonempty(entries: Any) -> Any:
    """Return the first entry of *entries* that is a non-empty sequence/array, or ``None``."""
    for entry in entries:
        if isinstance(entry, (Sequence, Array)) and len(entry) > 0:
            return entry
    return None


def _validate_data(
    datum_type: Literal["ic", "od"],
    images: Array | Sequence[Array],
    labels: Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]],
    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]] | None,
    metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
) -> None:
    """
    Validate the raw inputs passed to the dataset builder helpers.

    Lengths are checked for every input. Element types are spot-checked on the
    first image and on the first *non-empty* per-image label/box entry, so images
    without any detections are accepted for object detection.

    Parameters
    ----------
    datum_type : Literal["ic", "od"]
        "ic" for image classification, "od" for object detection.
    images : Array | Sequence[Array]
        One 3-dimensional (C, H, W) array per image.
    labels : Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]]
        One integer per image ("ic") or one sequence of integers per image ("od").
    bboxes : ... | None
        For "od", one sequence of 4-element boxes per image; ignored for "ic".
    metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
        One dict per image, or a dict mapping each key to one value per image.

    Raises
    ------
    ValueError
        If there are no images, the first image is not 3-dimensional, input
        lengths do not match, or ``datum_type`` is unknown.
    TypeError
        If labels or boxes do not have the nesting expected for ``datum_type``.
    """
    images_msg = "Images must be a sequence or array of 3 dimensional arrays (C, H, W)."
    if not isinstance(images, (Sequence, Array)):
        raise ValueError(images_msg)
    dataset_len = len(images)
    if dataset_len == 0:
        raise ValueError("Images must contain at least one image.")
    if len(getattr(images[0], "shape", ())) != 3:
        raise ValueError(images_msg)

    if len(labels) != dataset_len:
        raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
    if bboxes is not None and len(bboxes) != dataset_len:
        raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
    if isinstance(metadata, Sequence):
        if len(metadata) != dataset_len:
            raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")
    elif metadata is not None:
        # Column-oriented metadata: every column needs exactly one value per image.
        for key, values in metadata.items():
            if not isinstance(values, (Sequence, Array)) or len(values) != dataset_len:
                raise ValueError(f"Metadata '{key}' must be a sequence with one value per image ({dataset_len}).")

    if datum_type == "ic":
        if not isinstance(labels, (Sequence, Array)) or not isinstance(labels[0], (int, SupportsInt)):
            raise TypeError("Labels must be a sequence of integers for image classification.")
    elif datum_type == "od":
        labels_msg = "Labels must be a sequence of sequences of integers for object detection."
        boxes_msg = "Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection."
        if not isinstance(labels, (Sequence, Array)) or not isinstance(labels[0], (Sequence, Array)):
            raise TypeError(labels_msg)
        # Images may legitimately have zero detections, so sample the first non-empty entry.
        sample_labels = _first_nonempty(labels)
        if sample_labels is not None and not isinstance(sample_labels[0], (int, SupportsInt)):
            raise TypeError(labels_msg)
        if bboxes is None or not isinstance(bboxes, (Sequence, Array)) or not isinstance(bboxes[0], (Sequence, Array)):
            raise TypeError(boxes_msg)
        sample_boxes = _first_nonempty(bboxes)
        if sample_boxes is not None:
            box = sample_boxes[0]
            # Length is checked before indexing so an empty box cannot raise IndexError.
            if (
                not isinstance(box, (Sequence, Array))
                or len(box) != 4
                or not isinstance(box[0], (float, SupportsFloat))
            ):
                raise TypeError(boxes_msg)
    else:
        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
|
78
|
+
|
79
|
+
|
80
|
+
def _listify_metadata(
|
81
|
+
metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
|
82
|
+
) -> Sequence[dict[str, Any]] | None:
|
83
|
+
if isinstance(metadata, dict):
|
84
|
+
return [{k: v[i] for k, v in metadata.items()} for i in range(len(next(iter(metadata.values()))))]
|
85
|
+
return metadata
|
86
|
+
|
87
|
+
|
88
|
+
def _find_max(arr: ArrayLike) -> Any:
    """
    Recursively find the largest scalar inside a possibly nested array-like.

    Strings and bytes are treated as scalars. Returns ``None`` when no scalar is
    found, e.g. for an empty container or one holding only empty containers.
    """
    is_container = isinstance(arr, (Iterable, Sequence, Array)) and not isinstance(arr, (bytes, str))
    if not is_container:
        return arr
    candidates = []
    for element in arr:
        found = _find_max(element)
        if found is not None:
            candidates.append(found)
    return max(candidates) if candidates else None
|
93
|
+
|
94
|
+
|
95
|
+
_TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
|
96
|
+
|
97
|
+
|
98
|
+
class BaseAnnotatedDataset(Generic[_TLabels]):
    """
    Shared state for the in-memory datasets produced by the builder helpers.

    Parameters
    ----------
    datum_type : Literal["ic", "od"]
        Task tag; used only to build the default dataset id.
    images : Array | Sequence[Array]
        The images, indexed in parallel with ``labels`` and ``metadata``.
    labels : _TLabels
        Integer labels: flat for "ic", one sequence per image for "od".
    metadata : Sequence[dict[str, Any]] | None
        Optional per-datum metadata dicts.
    classes : Sequence[str] | None
        Class names indexed by label. When None, names "0" .. str(max label) are
        generated from the largest label found (none if there are no labels).
    name : str | None
        Dataset id; a descriptive id is generated when omitted.
    """

    def __init__(
        self,
        datum_type: Literal["ic", "od"],
        images: Array | Sequence[Array],
        labels: _TLabels,
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        if classes is None:
            max_label = _find_max(labels)
            # ``_find_max`` returns None when there are no labels at all (e.g. an object
            # detection dataset whose images have no detections); infer zero classes
            # instead of failing on ``None + 1``.
            classes = [] if max_label is None else [str(i) for i in range(max_label + 1)]
        self._classes = classes
        self._index2label = dict(enumerate(self._classes))
        self._images = images
        self._labels = labels
        self._metadata = metadata
        self._id = name or f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"

    @property
    def metadata(self) -> DatasetMetadata:
        """Dataset-level metadata: the dataset ``id`` and the ``index2label`` mapping."""
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self._images)
|
121
|
+
|
122
|
+
|
123
|
+
class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
    """
    In-memory image classification dataset built by ``to_image_classification_dataset``.

    Indexing returns ``(image, one_hot_target, metadata)``. The target is a float
    NumPy vector with one entry per class, and ``metadata`` always has an ``"id"``
    (the datum index unless one was supplied).
    """

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Array | Sequence[int],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        super().__init__(
            "ic",
            images,
            np.asarray(labels).tolist() if isinstance(labels, Array) else labels,
            metadata,
            classes,
        )
        if name is not None:
            self.__name__ = name
            # Rename through a per-instance subclass. Assigning to
            # ``self.__class__.__name__`` renamed the class shared by every instance,
            # so the most recently supplied name leaked into all other datasets.
            self.__class__ = type(type(self))(name, (type(self),), {"__qualname__": name})

    def __getitem__(self, idx: int, /) -> tuple[Array, Array, dict[str, Any]]:
        """Return ``(image, one_hot_target, metadata)`` for datum *idx*."""
        one_hot = [0.0] * len(self._index2label)
        one_hot[self._labels[idx]] = 1.0
        return (
            self._images[idx],
            np.asarray(one_hot),
            _ensure_id(idx, self._metadata[idx] if self._metadata is not None else {}),
        )
|
152
|
+
|
153
|
+
|
154
|
+
class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
    """
    In-memory object detection dataset built by ``to_object_detection_dataset``.

    Indexing returns ``(image, target, metadata)``. The target is an
    ``ObjectDetectionTarget`` holding the image's labels, boxes and one-hot
    scores, and ``metadata`` always has an ``"id"`` (the datum index unless one
    was supplied).
    """

    class ObjectDetectionTarget:
        """Per-image detection target exposing ``labels``, ``boxes`` and one-hot ``scores``."""

        def __init__(
            self,
            labels: Sequence[int],
            bboxes: Sequence[Sequence[float]],
            class_count: int,
        ) -> None:
            self._labels = labels
            self._bboxes = bboxes
            # Build an independent row per detection. ``[[0.0] * k] * n`` repeats a
            # single list object n times, so every row would receive a 1.0 for every
            # label present in the image.
            one_hot = [[0.0] * class_count for _ in labels]
            for row, label in zip(one_hot, labels):
                row[label] = 1.0
            self._scores = one_hot

        @property
        def labels(self) -> Sequence[int]:
            return self._labels

        @property
        def boxes(self) -> Sequence[Sequence[float]]:
            return self._bboxes

        @property
        def scores(self) -> Sequence[Sequence[float]]:
            return self._scores

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Array | Sequence[Array] | Sequence[Sequence[int]],
        bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        super().__init__(
            "od",
            images,
            [np.asarray(label).tolist() if isinstance(label, Array) else label for label in labels],
            metadata,
            classes,
        )
        if name is not None:
            self.__name__ = name
            # Rename through a per-instance subclass rather than mutating the shared
            # class, which would rename every other dataset of this type as well.
            self.__class__ = type(type(self))(name, (type(self),), {"__qualname__": name})
        # Convert any array-valued boxes to plain Python lists, image by image.
        self._bboxes = [
            [np.asarray(box).tolist() if isinstance(box, Array) else box for box in bbox] for bbox in bboxes
        ]

    def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
        """Return ``(image, target, metadata)`` for datum *idx*."""
        return (
            self._images[idx],
            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx], len(self._classes)),
            _ensure_id(idx, self._metadata[idx] if self._metadata is not None else {}),
        )
|
215
|
+
|
216
|
+
|
217
|
+
def to_image_classification_dataset(
    images: Array | Sequence[Array],
    labels: Array | Sequence[int],
    metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ImageClassificationDataset:
    """
    Create an in-memory ImageClassificationDataset from images and integer labels.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset, one 3-dimensional array per image.
    labels : Array | Sequence[int]
        One integer class label per image.
    metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
        Optional per-image metadata: either one dict per image, or a dict mapping
        each key to a sequence holding one value per image.
    classes : Sequence[str] | None
        Class names indexed by label. When None, names "0" .. str(max label) are
        generated from the labels.
    name : str | None, default None
        Optional name for the dataset.

    Returns
    -------
    ImageClassificationDataset

    Raises
    ------
    ValueError
        If input lengths do not match or the images are not 3-dimensional.
    TypeError
        If the labels are not integers.
    """
    _validate_data("ic", images, labels, None, metadata)
    return CustomImageClassificationDataset(images, labels, _listify_metadata(metadata), classes, name)
|
244
|
+
|
245
|
+
|
246
|
+
def to_object_detection_dataset(
    images: Array | Sequence[Array],
    labels: Array | Sequence[Array] | Sequence[Sequence[int]],
    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
    metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ObjectDetectionDataset:
    """
    Create an in-memory ObjectDetectionDataset from images, labels and boxes.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset, one 3-dimensional array per image.
    labels : Array | Sequence[Array] | Sequence[Sequence[int]]
        For each image, the integer class label of every detection.
    bboxes : Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]]
        For each image, one bounding box (x0, y0, x1, y1) per detection.
    metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
        Optional per-image metadata: either one dict per image, or a dict mapping
        each key to a sequence holding one value per image.
    classes : Sequence[str] | None
        Class names indexed by label. When None, names "0" .. str(max label) are
        generated from the labels.
    name : str | None, default None
        Optional name for the dataset.

    Returns
    -------
    ObjectDetectionDataset

    Raises
    ------
    ValueError
        If input lengths do not match or the images are not 3-dimensional.
    TypeError
        If labels or boxes do not have the expected nesting.
    """
    _validate_data("od", images, labels, bboxes, metadata)
    return CustomObjectDetectionDataset(images, labels, bboxes, _listify_metadata(metadata), classes, name)
|
@@ -0,0 +1,112 @@
|
|
1
|
+
"""
|
2
|
+
Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
__all__ = []
|
8
|
+
|
9
|
+
from collections.abc import Iterable, Sequence
|
10
|
+
from typing import Any, TypeVar, TYPE_CHECKING
|
11
|
+
|
12
|
+
import numpy as np
|
13
|
+
from numpy.typing import NDArray
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
import torch
|
17
|
+
|
18
|
+
from maite_datasets._protocols import ArrayLike
|
19
|
+
|
20
|
+
T_in = TypeVar("T_in")
|
21
|
+
T_tgt = TypeVar("T_tgt")
|
22
|
+
T_md = TypeVar("T_md")
|
23
|
+
|
24
|
+
|
25
|
+
def collate_as_list(
    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
    """
    Transpose a batch of ``(input, target, metadata)`` triples into three lists.

    Suitable as a ``collate_fn`` for ``torch.utils.data.DataLoader`` when targets
    and metadata are not tensors: nothing is stacked or converted.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[T_in, T_tgt, T_md]]
        The individual datums that make up the batch.

    Returns
    -------
    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
        Lists of inputs, targets and metadata, in batch order.
    """
    inputs: list[T_in] = []
    targets: list[T_tgt] = []
    metadatas: list[T_md] = []
    for datum in batch_data_as_singles:
        # Unpacking enforces exactly three fields per datum.
        datum_input, datum_target, datum_metadata = datum
        inputs.append(datum_input)
        targets.append(datum_target)
        metadatas.append(datum_metadata)
    return inputs, targets, metadatas
|
52
|
+
|
53
|
+
|
54
|
+
def collate_as_numpy(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` triples, stacking inputs into one NumPy array.

    Every input is converted with ``np.asarray`` and the results are stacked along
    a new leading batch axis, so all inputs must share a shape. Targets and
    metadata are returned as plain lists.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
        The individual datums that make up the batch.

    Returns
    -------
    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
        The stacked inputs (an empty 1-D array for an empty batch), then the
        targets and metadata lists.
    """
    arrays: list[NDArray[Any]] = []
    targets: list[T_tgt] = []
    metadatas: list[T_md] = []
    for datum_input, datum_target, datum_metadata in batch_data_as_singles:
        arrays.append(np.asarray(datum_input))
        targets.append(datum_target)
        metadatas.append(datum_metadata)
    if not arrays:
        return np.array([]), targets, metadatas
    return np.stack(arrays), targets, metadatas
|
80
|
+
|
81
|
+
|
82
|
+
def collate_as_torch(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` triples, stacking inputs into one torch Tensor.

    Every input is converted with ``torch.as_tensor`` and the results are stacked
    along a new leading batch axis, so all inputs must share a shape. Targets and
    metadata are returned as plain lists.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
        The individual datums that make up the batch.

    Returns
    -------
    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
        The stacked inputs (an empty tensor for an empty batch), then the targets
        and metadata lists.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    """
    try:
        import torch
    except ImportError as err:
        # Chain explicitly so the traceback shows the original import failure as the cause.
        raise ImportError("PyTorch is not installed. Please install it to use this function.") from err

    input_batch: list[torch.Tensor] = []
    target_batch: list[T_tgt] = []
    metadata_batch: list[T_md] = []
    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
        input_batch.append(torch.as_tensor(input_datum))
        target_batch.append(target_datum)
        metadata_batch.append(metadata_datum)

    return torch.stack(input_batch) if input_batch else torch.tensor([]), target_batch, metadata_batch
|
@@ -23,9 +23,7 @@ def _print(text: str, verbose: bool) -> None:
|
|
23
23
|
print(text)
|
24
24
|
|
25
25
|
|
26
|
-
def _validate_file(
|
27
|
-
fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535
|
28
|
-
) -> bool:
|
26
|
+
def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
|
29
27
|
hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
|
30
28
|
with open(fpath, "rb") as fpath_file:
|
31
29
|
while chunk := fpath_file.read(chunk_size):
|
@@ -33,28 +31,20 @@ def _validate_file(
|
|
33
31
|
return hasher.hexdigest() == file_md5
|
34
32
|
|
35
33
|
|
36
|
-
def _download_dataset(
|
37
|
-
url: str, file_path: Path, timeout: int = 60, verbose: bool = False
|
38
|
-
) -> None:
|
34
|
+
def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
|
39
35
|
"""Download a single resource from its URL to the `data_folder`."""
|
40
36
|
error_msg = "URL fetch failure on {}: {} -- {}"
|
41
37
|
try:
|
42
38
|
response = requests.get(url, stream=True, timeout=timeout)
|
43
39
|
response.raise_for_status()
|
44
40
|
except requests.exceptions.HTTPError as e:
|
45
|
-
raise RuntimeError(
|
46
|
-
f"{error_msg.format(url, e.response.status_code, e.response.reason)}"
|
47
|
-
) from e
|
41
|
+
raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
|
48
42
|
except requests.exceptions.RequestException as e:
|
49
43
|
raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
|
50
44
|
|
51
45
|
total_size = int(response.headers.get("content-length", 0))
|
52
46
|
block_size = 8192 # 8 KB
|
53
|
-
progress_bar = (
|
54
|
-
None
|
55
|
-
if tqdm is None
|
56
|
-
else tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
|
57
|
-
)
|
47
|
+
progress_bar = None if tqdm is None else tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
|
58
48
|
|
59
49
|
with open(file_path, "wb") as f:
|
60
50
|
for chunk in response.iter_content(block_size):
|
@@ -72,9 +62,7 @@ def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
|
|
72
62
|
zip_ref.extractall(extract_to) # noqa: S202
|
73
63
|
file_path.unlink()
|
74
64
|
except zipfile.BadZipFile:
|
75
|
-
raise FileNotFoundError(
|
76
|
-
f"{file_path.name} is not a valid zip file, skipping extraction."
|
77
|
-
)
|
65
|
+
raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
|
78
66
|
|
79
67
|
|
80
68
|
def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
|
@@ -84,9 +72,7 @@ def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
|
|
84
72
|
tar_ref.extractall(extract_to) # noqa: S202
|
85
73
|
file_path.unlink()
|
86
74
|
except tarfile.TarError:
|
87
|
-
raise FileNotFoundError(
|
88
|
-
f"{file_path.name} is not a valid tar file, skipping extraction."
|
89
|
-
)
|
75
|
+
raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
|
90
76
|
|
91
77
|
|
92
78
|
def _extract_archive(
|
@@ -135,11 +121,7 @@ def _ensure_exists(
|
|
135
121
|
file_ext = file_path.suffixes[0]
|
136
122
|
compression = True
|
137
123
|
|
138
|
-
check_path = (
|
139
|
-
alternate_path
|
140
|
-
if alternate_path.exists() and not file_path.exists()
|
141
|
-
else file_path
|
142
|
-
)
|
124
|
+
check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path
|
143
125
|
|
144
126
|
# Download file if it doesn't exist.
|
145
127
|
if not check_path.exists() and download:
|
@@ -147,9 +129,7 @@ def _ensure_exists(
|
|
147
129
|
_download_dataset(url, check_path, verbose=verbose)
|
148
130
|
|
149
131
|
if not _validate_file(check_path, checksum, md5):
|
150
|
-
raise Exception(
|
151
|
-
"File checksum mismatch. Remove current file and retry download."
|
152
|
-
)
|
132
|
+
raise Exception("File checksum mismatch. Remove current file and retry download.")
|
153
133
|
|
154
134
|
# If the file is a zip, tar or tgz extract it into the designated folder.
|
155
135
|
if file_ext in ARCHIVE_ENDINGS:
|
@@ -164,9 +144,7 @@ def _ensure_exists(
|
|
164
144
|
)
|
165
145
|
else:
|
166
146
|
if not _validate_file(check_path, checksum, md5):
|
167
|
-
raise Exception(
|
168
|
-
"File checksum mismatch. Remove current file and retry download."
|
169
|
-
)
|
147
|
+
raise Exception("File checksum mismatch. Remove current file and retry download.")
|
170
148
|
_print(f"{filename} already exists, skipping download.", verbose)
|
171
149
|
|
172
150
|
if file_ext in ARCHIVE_ENDINGS:
|
@@ -174,9 +174,7 @@ class ObjectDetectionTarget(Protocol):
|
|
174
174
|
def scores(self) -> ArrayLike: ...
|
175
175
|
|
176
176
|
|
177
|
-
ObjectDetectionDatum: TypeAlias = tuple[
|
178
|
-
ArrayLike, ObjectDetectionTarget, Mapping[str, Any]
|
179
|
-
]
|
177
|
+
ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, Mapping[str, Any]]
|
180
178
|
"""
|
181
179
|
Type alias for an object detection datum tuple.
|
182
180
|
|
@@ -37,9 +37,7 @@ class AnnotatedDataset(Dataset[_TDatum]):
|
|
37
37
|
def __len__(self) -> int: ...
|
38
38
|
|
39
39
|
|
40
|
-
class ImageClassificationDataset(
|
41
|
-
AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]
|
42
|
-
): ...
|
40
|
+
class ImageClassificationDataset(AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]): ...
|
43
41
|
|
44
42
|
|
45
43
|
@dataclass
|
@@ -49,6 +47,4 @@ class ObjectDetectionTarget(Generic[_TArray]):
|
|
49
47
|
scores: _TArray
|
50
48
|
|
51
49
|
|
52
|
-
class ObjectDetectionDataset(
|
53
|
-
AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]
|
54
|
-
): ...
|
50
|
+
class ObjectDetectionDataset(AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]): ...
|
@@ -0,0 +1,169 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
from collections.abc import Sequence, Sized
|
7
|
+
from typing import Any, Literal
|
8
|
+
|
9
|
+
from maite_datasets._protocols import Array, ObjectDetectionTarget
|
10
|
+
|
11
|
+
|
12
|
+
class ValidationMessages:
    """Issue messages reported by the dataset validation helpers."""

    # Dataset-level checks
    DATASET_SIZED = "Dataset must be sized."
    DATASET_INDEXABLE = "Dataset must be indexable."
    DATASET_NONEMPTY = "Dataset must be non-empty."
    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
    # Datum structure and image checks
    DATUM_TYPE = "Dataset datum must be a tuple."
    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
    # Target checks
    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must have 'boxes', 'labels' and 'scores'."
    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
    # Boxes are (x0, y0, x1, y1), i.e. xyxy order, matching the builder's box format.
    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xyxy format."
    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
    # Datum metadata checks
    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
|
32
|
+
|
33
|
+
|
34
|
+
def _validate_dataset_type(dataset: Any) -> list[str]:
|
35
|
+
issues = []
|
36
|
+
is_sized = isinstance(dataset, Sized)
|
37
|
+
is_indexable = hasattr(dataset, "__getitem__")
|
38
|
+
if not is_sized:
|
39
|
+
issues.append(ValidationMessages.DATASET_SIZED)
|
40
|
+
if not is_indexable:
|
41
|
+
issues.append(ValidationMessages.DATASET_INDEXABLE)
|
42
|
+
if is_sized and len(dataset) == 0:
|
43
|
+
issues.append(ValidationMessages.DATASET_NONEMPTY)
|
44
|
+
return issues
|
45
|
+
|
46
|
+
|
47
|
+
def _validate_dataset_metadata(dataset: Any) -> list[str]:
    """
    Check the dataset-level ``metadata`` attribute.

    The attribute must exist, be a ``dict`` and contain an ``'id'`` key.
    A missing or non-dict attribute fails both the type and the format check.

    Parameters
    ----------
    dataset : Any
        Dataset to inspect.

    Returns
    -------
    list[str]
        Messages for every failed check, in check order; empty when all pass.
    """
    problems: list[str] = []
    if not hasattr(dataset, "metadata"):
        problems.append(ValidationMessages.DATASET_METADATA)
    meta = getattr(dataset, "metadata", None)
    if isinstance(meta, dict):
        if "id" not in meta:
            problems.append(ValidationMessages.DATASET_METADATA_FORMAT)
    else:
        problems.extend(
            (
                ValidationMessages.DATASET_METADATA_TYPE,
                ValidationMessages.DATASET_METADATA_FORMAT,
            )
        )
    return problems
|
57
|
+
|
58
|
+
|
59
|
+
def _validate_datum_type(datum: Any) -> list[str]:
    """
    Check that a datum is a tuple of exactly three elements.

    ``None`` fails both checks. A sized non-tuple with three elements only
    fails the type check, and an unsized non-``None`` object only fails the
    type check.

    Parameters
    ----------
    datum : Any
        A single item retrieved from the dataset (or ``None`` if none could be).

    Returns
    -------
    list[str]
        Messages for every failed check, in check order; empty when all pass.
    """
    problems: list[str] = []
    if not isinstance(datum, tuple):
        problems.append(ValidationMessages.DATUM_TYPE)
    wrong_length = datum is None or (isinstance(datum, Sized) and len(datum) != 3)
    if wrong_length:
        problems.append(ValidationMessages.DATUM_FORMAT)
    return problems
|
66
|
+
|
67
|
+
|
68
|
+
def _validate_datum_image(image: Any) -> list[str]:
    """
    Check that an image is a 3-dimensional ``Array`` laid out channel-first.

    The channel-first check is a heuristic: the leading axis must not be larger
    than either of the two trailing axes. It is only applied to 3-dimensional
    arrays; anything that is not an ``Array`` fails both checks.

    Parameters
    ----------
    image : Any
        First element of a datum (or ``None`` if unavailable).

    Returns
    -------
    list[str]
        Messages for every failed check, in check order; empty when all pass.
    """
    if not isinstance(image, Array):
        # Nothing to inspect, so both the type and the layout checks fail.
        return [ValidationMessages.DATUM_IMAGE_TYPE, ValidationMessages.DATUM_IMAGE_FORMAT]

    problems: list[str] = []
    dims = image.shape
    if len(dims) != 3:
        problems.append(ValidationMessages.DATUM_IMAGE_TYPE)
    elif dims[0] > dims[1] or dims[0] > dims[2]:
        problems.append(ValidationMessages.DATUM_IMAGE_FORMAT)
    return problems
|
79
|
+
|
80
|
+
|
81
|
+
def _validate_datum_target_ic(target: Any) -> list[str]:
    """
    Validate a single image classification target.

    A valid target is a one-dimensional ``Array`` whose entries sum to 1
    (within 1e-6), i.e. a one-hot encoding or a vector of pseudo-probabilities.

    Parameters
    ----------
    target : Any
        Second element of a datum (or ``None`` if unavailable).

    Returns
    -------
    list[str]
        Messages for every failed check, in check order; empty when the target
        is valid. Malformed targets are reported, never raised.
    """
    issues = []
    if not isinstance(target, Array) or len(target.shape) != 1:
        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)

    # Reduce the target to a single scalar total. Built-in ``sum()`` over a
    # multi-dimensional array yields an array (whose comparison has an ambiguous
    # truth value), and over a scalar / non-numeric object it raises; either
    # would crash the validator instead of producing a validation message.
    # Targets that cannot be converted to numbers are treated as NaN.
    if target is None:
        total = float("nan")
    else:
        try:
            total = float(np.sum(np.asarray(target, dtype=np.float64)))
        except (TypeError, ValueError, OverflowError):
            total = float("nan")

    # Written as a negated range check so that NaN totals are flagged as well.
    if not (1 - 1e-6 <= total <= 1 + 1e-6):
        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
    return issues
|
88
|
+
|
89
|
+
|
90
|
+
def _validate_datum_target_od(target: Any) -> list[str]:
    """
    Validate a single object detection target.

    The target must be an ``ObjectDetectionTarget`` whose ``labels`` are 1-D,
    whose ``boxes`` are 2-D with 4 columns, and whose ``scores`` are 1-D or 2-D.
    A target of any other type fails every check.

    Parameters
    ----------
    target : Any
        Second element of a datum (or ``None`` if unavailable).

    Returns
    -------
    list[str]
        Messages for every failed check, in check order; empty when all pass.
    """
    if not isinstance(target, ObjectDetectionTarget):
        # No fields to inspect, so every object detection check fails.
        return [
            ValidationMessages.DATUM_TARGET_OD_TYPE,
            ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE,
            ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE,
            ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE,
        ]

    label_rank = np.asarray(target.labels).ndim
    box_shape = np.asarray(target.boxes).shape
    score_rank = np.asarray(target.scores).ndim

    problems: list[str] = []
    if label_rank != 1:
        problems.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
    if len(box_shape) != 2 or box_shape[1] != 4:
        problems.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
    if score_rank not in (1, 2):
        problems.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
    return problems
|
106
|
+
|
107
|
+
|
108
|
+
def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
    """
    Infer the task type from a target's type.

    ``Array`` targets are classified as ``"ic"`` (checked first), and
    ``ObjectDetectionTarget`` targets as ``"od"``. Anything else yields
    ``"auto"``, meaning the type could not be determined.
    """
    if isinstance(target, Array):
        return "ic"
    return "od" if isinstance(target, ObjectDetectionTarget) else "auto"
|
114
|
+
|
115
|
+
|
116
|
+
def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
    """
    Validate a target with the validator for its task type.

    Parameters
    ----------
    target : Any
        Second element of a datum (or ``None`` if unavailable).
    target_type : "ic", "od" or "auto"
        Task type; ``"auto"`` infers it from ``target`` via ``_detect_target_type``.

    Returns
    -------
    list[str]
        Messages from the chosen validator, or a single ``DATUM_TARGET_TYPE``
        message when no task type applies.
    """
    resolved = target_type if target_type != "auto" else _detect_target_type(target)
    if resolved == "ic":
        return _validate_datum_target_ic(target)
    if resolved == "od":
        return _validate_datum_target_od(target)
    return [ValidationMessages.DATUM_TARGET_TYPE]
|
126
|
+
|
127
|
+
|
128
|
+
def _validate_datum_metadata(metadata: Any) -> list[str]:
    """
    Check that datum metadata is a ``dict`` containing an ``'id'`` key.

    ``None`` fails both checks; any other non-dict value only fails the type
    check.

    Parameters
    ----------
    metadata : Any
        Third element of a datum (or ``None`` if unavailable).

    Returns
    -------
    list[str]
        Messages for every failed check, in check order; empty when all pass.
    """
    if metadata is None:
        return [ValidationMessages.DATUM_METADATA_TYPE, ValidationMessages.DATUM_METADATA_FORMAT]
    if not isinstance(metadata, dict):
        return [ValidationMessages.DATUM_METADATA_TYPE]
    return [] if "id" in metadata else [ValidationMessages.DATUM_METADATA_FORMAT]
|
135
|
+
|
136
|
+
|
137
|
+
def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
    """
    Validate a dataset for compliance with MAITE protocol.

    The dataset itself is checked first (sized, indexable, non-empty, with a
    ``metadata`` dict carrying an ``'id'`` key). Only when the container-level
    checks pass is ``dataset[0]`` fetched. Its image, target and metadata are
    then checked; datum parts that are unavailable are validated as ``None``.

    Parameters
    ----------
    dataset: Any
        Dataset to validate.
    dataset_type: "ic", "od", or "auto", default "auto"
        Dataset type, if known.

    Raises
    ------
    ValueError
        Raises exception if dataset is invalid with a list of validation issues.
    """
    problems: list[str] = _validate_dataset_type(dataset)
    # Index into the dataset only once it is known to be sized, indexable and non-empty.
    first = dataset[0] if not problems else None  # type: ignore
    problems += _validate_dataset_metadata(dataset)
    problems += _validate_datum_type(first)

    if isinstance(first, Sequence):
        part_count = len(first)
        image, target, metadata = (first[i] if i < part_count else None for i in range(3))
    else:
        image = target = metadata = None

    problems += _validate_datum_image(image)
    problems += _validate_datum_target(target, dataset_type)
    problems += _validate_datum_metadata(metadata)

    if problems:
        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(problems))
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/_cifar10.py
RENAMED
@@ -24,9 +24,7 @@ CIFARClassStringMap = Literal[
|
|
24
24
|
"ship",
|
25
25
|
"truck",
|
26
26
|
]
|
27
|
-
TCIFARClassMap = TypeVar(
|
28
|
-
"TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int]
|
29
|
-
)
|
27
|
+
TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
|
30
28
|
|
31
29
|
|
32
30
|
class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
@@ -91,9 +89,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
91
89
|
self,
|
92
90
|
root: str | Path,
|
93
91
|
image_set: Literal["train", "test", "base"] = "train",
|
94
|
-
transforms: Transform[NDArray[np.number[Any]]]
|
95
|
-
| Sequence[Transform[NDArray[np.number[Any]]]]
|
96
|
-
| None = None,
|
92
|
+
transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
|
97
93
|
download: bool = False,
|
98
94
|
verbose: bool = False,
|
99
95
|
) -> None:
|
@@ -105,9 +101,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
105
101
|
verbose,
|
106
102
|
)
|
107
103
|
|
108
|
-
def _load_bin_data(
|
109
|
-
self, data_folder: list[Path]
|
110
|
-
) -> tuple[list[str], list[int], dict[str, Any]]:
|
104
|
+
def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
|
111
105
|
batch_nums = np.zeros(60000, dtype=np.uint8)
|
112
106
|
all_labels = np.zeros(60000, dtype=np.uint8)
|
113
107
|
all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
|
@@ -115,9 +109,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
115
109
|
for batch_file in data_folder:
|
116
110
|
# Get batch parameters
|
117
111
|
batch_type = "test" if "test" in batch_file.stem else "train"
|
118
|
-
batch_num = (
|
119
|
-
5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
|
120
|
-
)
|
112
|
+
batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
|
121
113
|
|
122
114
|
# Load data
|
123
115
|
batch_images, batch_labels = self._unpack_batch_files(batch_file)
|
@@ -193,9 +185,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
193
185
|
{"batch_num": batch_nums.tolist()},
|
194
186
|
)
|
195
187
|
|
196
|
-
def _unpack_batch_files(
|
197
|
-
self, file_path: Path
|
198
|
-
) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
|
188
|
+
def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
|
199
189
|
# Load pickle data with latin1 encoding
|
200
190
|
with file_path.open("rb") as f:
|
201
191
|
buffer = np.frombuffer(f.read(), dtype=np.uint8)
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/_mnist.py
RENAMED
@@ -12,12 +12,8 @@ from maite_datasets._base import BaseICDataset, DataLocation
|
|
12
12
|
from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
13
13
|
from maite_datasets._protocols import Transform
|
14
14
|
|
15
|
-
MNISTClassStringMap = Literal[
|
16
|
-
|
17
|
-
]
|
18
|
-
TMNISTClassMap = TypeVar(
|
19
|
-
"TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int]
|
20
|
-
)
|
15
|
+
MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
|
16
|
+
TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
|
21
17
|
CorruptionStringMap = Literal[
|
22
18
|
"identity",
|
23
19
|
"shot_noise",
|
@@ -122,9 +118,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
122
118
|
root: str | Path,
|
123
119
|
image_set: Literal["train", "test", "base"] = "train",
|
124
120
|
corruption: CorruptionStringMap | None = None,
|
125
|
-
transforms: Transform[NDArray[np.number[Any]]]
|
126
|
-
| Sequence[Transform[NDArray[np.number[Any]]]]
|
127
|
-
| None = None,
|
121
|
+
transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
|
128
122
|
download: bool = False,
|
129
123
|
verbose: bool = False,
|
130
124
|
) -> None:
|
@@ -182,18 +176,12 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
182
176
|
|
183
177
|
return data, labels
|
184
178
|
|
185
|
-
def _grab_data(
|
186
|
-
self, path: Path
|
187
|
-
) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
|
179
|
+
def _grab_data(self, path: Path) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
|
188
180
|
"""Function to load in the data numpy array"""
|
189
181
|
with np.load(path, allow_pickle=True) as data_array:
|
190
182
|
if self.image_set == "base":
|
191
|
-
data = np.concatenate(
|
192
|
-
|
193
|
-
)
|
194
|
-
labels = np.concatenate(
|
195
|
-
[data_array["y_train"], data_array["y_test"]], axis=0
|
196
|
-
).astype(np.uintp)
|
183
|
+
data = np.concatenate([data_array["x_train"], data_array["x_test"]], axis=0)
|
184
|
+
labels = np.concatenate([data_array["y_train"], data_array["y_test"]], axis=0).astype(np.uintp)
|
197
185
|
else:
|
198
186
|
data, labels = (
|
199
187
|
data_array[f"x_{self.image_set}"],
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/_ships.py
RENAMED
@@ -76,9 +76,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
|
|
76
76
|
def __init__(
|
77
77
|
self,
|
78
78
|
root: str | Path,
|
79
|
-
transforms: Transform[NDArray[np.number[Any]]]
|
80
|
-
| Sequence[Transform[NDArray[np.number[Any]]]]
|
81
|
-
| None = None,
|
79
|
+
transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
|
82
80
|
download: bool = False,
|
83
81
|
verbose: bool = False,
|
84
82
|
) -> None:
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_antiuav.py
RENAMED
@@ -14,9 +14,7 @@ from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
|
14
14
|
from maite_datasets._protocols import Transform
|
15
15
|
|
16
16
|
|
17
|
-
class AntiUAVDetection(
|
18
|
-
BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin
|
19
|
-
):
|
17
|
+
class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
|
20
18
|
"""
|
21
19
|
A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
|
22
20
|
|
@@ -103,9 +101,7 @@ class AntiUAVDetection(
|
|
103
101
|
self,
|
104
102
|
root: str | Path,
|
105
103
|
image_set: Literal["train", "val", "test", "base"] = "train",
|
106
|
-
transforms: Transform[NDArray[np.number[Any]]]
|
107
|
-
| Sequence[Transform[NDArray[np.number[Any]]]]
|
108
|
-
| None = None,
|
104
|
+
transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
|
109
105
|
download: bool = False,
|
110
106
|
verbose: bool = False,
|
111
107
|
) -> None:
|
@@ -128,9 +124,7 @@ class AntiUAVDetection(
|
|
128
124
|
|
129
125
|
for resource in self._resources:
|
130
126
|
self._resource = resource
|
131
|
-
resource_filepaths, resource_targets, resource_metadata = (
|
132
|
-
super()._load_data()
|
133
|
-
)
|
127
|
+
resource_filepaths, resource_targets, resource_metadata = super()._load_data()
|
134
128
|
filepaths.extend(resource_filepaths)
|
135
129
|
targets.extend(resource_targets)
|
136
130
|
metadata_list.append(resource_metadata)
|
@@ -148,9 +142,7 @@ class AntiUAVDetection(
|
|
148
142
|
for resource in self._resources:
|
149
143
|
if self.image_set in resource.filename:
|
150
144
|
self._resource = resource
|
151
|
-
resource_filepaths, resource_targets, resource_metadata = (
|
152
|
-
super()._load_data()
|
153
|
-
)
|
145
|
+
resource_filepaths, resource_targets, resource_metadata = super()._load_data()
|
154
146
|
filepaths.extend(resource_filepaths)
|
155
147
|
targets.extend(resource_targets)
|
156
148
|
datum_metadata.update(resource_metadata)
|
@@ -164,17 +156,13 @@ class AntiUAVDetection(
|
|
164
156
|
if not data_folder:
|
165
157
|
raise FileNotFoundError
|
166
158
|
|
167
|
-
file_data = {
|
168
|
-
"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]
|
169
|
-
}
|
159
|
+
file_data = {"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]}
|
170
160
|
data = [str(entry) for entry in data_folder]
|
171
161
|
annotations = sorted(str(entry) for entry in (base_dir / "xml").glob("*.xml"))
|
172
162
|
|
173
163
|
return data, annotations, file_data
|
174
164
|
|
175
|
-
def _read_annotations(
|
176
|
-
self, annotation: str
|
177
|
-
) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
165
|
+
def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
178
166
|
"""Function for extracting the info for the label and boxes"""
|
179
167
|
boxes: list[list[float]] = []
|
180
168
|
labels = []
|
@@ -13,9 +13,7 @@ from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
|
|
13
13
|
from maite_datasets._protocols import Transform
|
14
14
|
|
15
15
|
|
16
|
-
class MILCO(
|
17
|
-
BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin
|
18
|
-
):
|
16
|
+
class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
|
19
17
|
"""
|
20
18
|
A side-scan sonar dataset focused on mine-like object detection.
|
21
19
|
|
@@ -118,9 +116,7 @@ class MILCO(
|
|
118
116
|
self,
|
119
117
|
root: str | Path,
|
120
118
|
image_set: Literal["train", "operational", "base"] = "train",
|
121
|
-
transforms: Transform[NDArray[np.number[Any]]]
|
122
|
-
| Sequence[Transform[NDArray[np.number[Any]]]]
|
123
|
-
| None = None,
|
119
|
+
transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
|
124
120
|
download: bool = False,
|
125
121
|
verbose: bool = False,
|
126
122
|
) -> None:
|
@@ -180,9 +176,7 @@ class MILCO(
|
|
180
176
|
|
181
177
|
return data, annotations, file_data
|
182
178
|
|
183
|
-
def _read_annotations(
|
184
|
-
self, annotation: str
|
185
|
-
) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
179
|
+
def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
186
180
|
"""Function for extracting the info out of the text files"""
|
187
181
|
labels: list[int] = []
|
188
182
|
boxes: list[list[float]] = []
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_seadrone.py
RENAMED
@@ -313,9 +313,7 @@ class SeaDrone(
|
|
313
313
|
self,
|
314
314
|
root: str | Path,
|
315
315
|
image_set: Literal["train", "val", "test", "base"] = "train",
|
316
|
-
transforms: Transform[NDArray[np.number[Any]]]
|
317
|
-
| Sequence[Transform[NDArray[np.number[Any]]]]
|
318
|
-
| None = None,
|
316
|
+
transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
|
319
317
|
download: bool = False,
|
320
318
|
verbose: bool = False,
|
321
319
|
) -> None:
|
@@ -365,9 +363,7 @@ class SeaDrone(
|
|
365
363
|
|
366
364
|
def _load_data(
|
367
365
|
self,
|
368
|
-
) -> tuple[
|
369
|
-
list[str], list[tuple[list[int], list[list[float]]]], dict[str, list[Any]]
|
370
|
-
]:
|
366
|
+
) -> tuple[list[str], list[tuple[list[int], list[list[float]]]], dict[str, list[Any]]]:
|
371
367
|
image_sets: dict[str, list[int]] = {
|
372
368
|
"train": list(range(20)),
|
373
369
|
"val": list(range(20, 24)),
|
@@ -390,9 +386,7 @@ class SeaDrone(
|
|
390
386
|
|
391
387
|
return filepaths, list(targets), datum_metadata
|
392
388
|
|
393
|
-
def _load_images(
|
394
|
-
self, data_folder: Path, file_data: dict[int, dict[str, Any]]
|
395
|
-
) -> dict[int, dict[str, Any]]:
|
389
|
+
def _load_images(self, data_folder: Path, file_data: dict[int, dict[str, Any]]) -> dict[int, dict[str, Any]]:
|
396
390
|
for entry in data_folder.iterdir():
|
397
391
|
if entry.is_file() and entry.suffix == ".jpg":
|
398
392
|
if int(entry.stem) not in file_data:
|
@@ -441,14 +435,10 @@ class SeaDrone(
|
|
441
435
|
current_file["storage"] = source.get("folder_name", "")
|
442
436
|
|
443
437
|
# Handle non-standard file metadata
|
444
|
-
current_file["date_time"] = (
|
445
|
-
file_meta.get("date_time") or meta.get("date_time") or ""
|
446
|
-
)
|
438
|
+
current_file["date_time"] = file_meta.get("date_time") or meta.get("date_time") or ""
|
447
439
|
if "frame" in file_meta:
|
448
440
|
frame = file_meta["frame"][:-4]
|
449
|
-
current_file["frame"] = (
|
450
|
-
int(frame.split("_")[-1]) if "IMG_" in frame else int(frame[3:])
|
451
|
-
)
|
441
|
+
current_file["frame"] = int(frame.split("_")[-1]) if "IMG_" in frame else int(frame[3:])
|
452
442
|
elif "frame_no" in source:
|
453
443
|
current_file["frame"] = source["frame_no"]
|
454
444
|
else:
|
@@ -456,9 +446,7 @@ class SeaDrone(
|
|
456
446
|
|
457
447
|
# Grab additional metadata if available
|
458
448
|
for output_key, (possible_keys, default) in mappings.items():
|
459
|
-
current_file[output_key] = next(
|
460
|
-
(meta.get(key) for key in possible_keys if key in meta), default
|
461
|
-
)
|
449
|
+
current_file[output_key] = next((meta.get(key) for key in possible_keys if key in meta), default)
|
462
450
|
|
463
451
|
# Retrieve the label and bounding box
|
464
452
|
for annotation in result["annotations"]:
|
@@ -482,9 +470,7 @@ class SeaDrone(
|
|
482
470
|
|
483
471
|
return file_data
|
484
472
|
|
485
|
-
def _restructure_file_data(
|
486
|
-
self, file_data: dict[int, dict[str, Any]]
|
487
|
-
) -> dict[str, list[Any]]:
|
473
|
+
def _restructure_file_data(self, file_data: dict[int, dict[str, Any]]) -> dict[str, list[Any]]:
|
488
474
|
"""Restructure file_data from dictionary of dictionaries to a dictionary of lists"""
|
489
475
|
# Get the keys from the dictionary
|
490
476
|
all_keys = set()
|
@@ -501,9 +487,7 @@ class SeaDrone(
|
|
501
487
|
# Create the lists
|
502
488
|
for file_id, file_dict in file_data.items():
|
503
489
|
restructured_data["image_id"].append(file_id)
|
504
|
-
restructured_data["label_box"].append(
|
505
|
-
(file_dict.get("label", []), file_dict.get("box", []))
|
506
|
-
)
|
490
|
+
restructured_data["label_box"].append((file_dict.get("label", []), file_dict.get("box", [])))
|
507
491
|
for key in all_keys:
|
508
492
|
restructured_data[key].append(file_dict.get(key, None))
|
509
493
|
|
@@ -528,12 +512,8 @@ class SeaDrone(
|
|
528
512
|
json_name = folder
|
529
513
|
if json_name == "test":
|
530
514
|
json_name += "_nogt"
|
531
|
-
annotation_file =
|
532
|
-
|
533
|
-
)
|
534
|
-
file_data = self._create_per_image_annotations(
|
535
|
-
annotation_file, file_data
|
536
|
-
)
|
515
|
+
annotation_file = self.path / "annotations" / f"instances_{json_name}.json"
|
516
|
+
file_data = self._create_per_image_annotations(annotation_file, file_data)
|
537
517
|
|
538
518
|
meta_data = self._restructure_file_data(file_data)
|
539
519
|
data = meta_data.pop("data_path")
|
@@ -45,9 +45,7 @@ VOCClassStringMap = Literal[
|
|
45
45
|
"train",
|
46
46
|
"tvmonitor",
|
47
47
|
]
|
48
|
-
TVOCClassMap = TypeVar(
|
49
|
-
"TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int]
|
50
|
-
)
|
48
|
+
TVOCClassMap = TypeVar("TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int])
|
51
49
|
|
52
50
|
|
53
51
|
class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
@@ -170,13 +168,9 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
170
168
|
base if base.stem == f"VOC{self.year}" else None,
|
171
169
|
base / f"VOC{self.year}" if base.stem == "VOCdevkit" else None,
|
172
170
|
base / "VOCdevkit" / f"VOC{self.year}",
|
173
|
-
base / "TrainVal" / "VOCdevkit" / f"VOC{self.year}"
|
174
|
-
if self.year == "2011"
|
175
|
-
else None,
|
171
|
+
base / "TrainVal" / "VOCdevkit" / f"VOC{self.year}" if self.year == "2011" else None,
|
176
172
|
dataset_dir / "VOCdevkit" / f"VOC{self.year}",
|
177
|
-
dataset_dir / "TrainVal" / "VOCdevkit" / f"VOC{self.year}"
|
178
|
-
if self.year == "2011"
|
179
|
-
else None,
|
173
|
+
dataset_dir / "TrainVal" / "VOCdevkit" / f"VOC{self.year}" if self.year == "2011" else None,
|
180
174
|
]
|
181
175
|
|
182
176
|
# Filter out None values and check each path
|
@@ -269,9 +263,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
269
263
|
|
270
264
|
for img_set in ["test", "base"]:
|
271
265
|
self.image_set = img_set
|
272
|
-
resource_filepaths, resource_targets, resource_metadata = (
|
273
|
-
self._load_data_inner()
|
274
|
-
)
|
266
|
+
resource_filepaths, resource_targets, resource_metadata = self._load_data_inner()
|
275
267
|
filepaths.extend(resource_filepaths)
|
276
268
|
targets.extend(resource_targets)
|
277
269
|
metadata_list.append(resource_metadata)
|
@@ -288,14 +280,10 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
288
280
|
self._resource = self._resources[resource_idx[1]]
|
289
281
|
|
290
282
|
if train_exists and not test_exists:
|
291
|
-
_ensure_exists(
|
292
|
-
*self._resource, tmp_path, self._root, self._download, self._verbose
|
293
|
-
)
|
283
|
+
_ensure_exists(*self._resource, tmp_path, self._root, self._download, self._verbose)
|
294
284
|
self._merge_voc_directories(tmp_path)
|
295
285
|
|
296
|
-
resource_filepaths, resource_targets, resource_metadata = (
|
297
|
-
self._load_try_and_update()
|
298
|
-
)
|
286
|
+
resource_filepaths, resource_targets, resource_metadata = self._load_try_and_update()
|
299
287
|
filepaths.extend(resource_filepaths)
|
300
288
|
targets.extend(resource_targets)
|
301
289
|
datum_metadata.update(resource_metadata)
|
@@ -341,9 +329,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
341
329
|
if self._verbose:
|
342
330
|
print("No download needed, loaded data successfully.")
|
343
331
|
except FileNotFoundError:
|
344
|
-
_ensure_exists(
|
345
|
-
*self._resource, self.path, self._root, self._download, self._verbose
|
346
|
-
)
|
332
|
+
_ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
|
347
333
|
self._update_path()
|
348
334
|
result = self._load_data_inner()
|
349
335
|
return result
|
@@ -364,9 +350,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
364
350
|
def _get_image_sets(self) -> dict[str, list[str]]:
|
365
351
|
"""Function to create the list of images in each image set"""
|
366
352
|
image_folder = self.path / "JPEGImages"
|
367
|
-
image_set_list =
|
368
|
-
["train", "val", "trainval"] if self.image_set != "test" else ["test"]
|
369
|
-
)
|
353
|
+
image_set_list = ["train", "val", "trainval"] if self.image_set != "test" else ["test"]
|
370
354
|
image_sets = {}
|
371
355
|
for image_set in image_set_list:
|
372
356
|
text_file = self.path / "ImageSets" / "Main" / (image_set + ".txt")
|
@@ -408,9 +392,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
408
392
|
|
409
393
|
return data, annotations, file_meta
|
410
394
|
|
411
|
-
def _read_annotations(
|
412
|
-
self, annotation: str
|
413
|
-
) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
395
|
+
def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
414
396
|
boxes: list[list[float]] = []
|
415
397
|
label_str = []
|
416
398
|
if not Path(annotation).exists():
|
@@ -435,12 +417,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
435
417
|
for obj in root.findall("object"):
|
436
418
|
label_str.append(obj.findtext("name", default=""))
|
437
419
|
additional_meta["pose"].append(obj.findtext("pose", default=""))
|
438
|
-
additional_meta["truncated"].append(
|
439
|
-
|
440
|
-
)
|
441
|
-
additional_meta["difficult"].append(
|
442
|
-
int(obj.findtext("difficult", default="-1"))
|
443
|
-
)
|
420
|
+
additional_meta["truncated"].append(int(obj.findtext("truncated", default="-1")))
|
421
|
+
additional_meta["difficult"].append(int(obj.findtext("difficult", default="-1")))
|
444
422
|
boxes.append(
|
445
423
|
[
|
446
424
|
float(obj.findtext("bndbox/xmin", default="0")),
|
@@ -454,9 +432,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
|
|
454
432
|
|
455
433
|
|
456
434
|
class VOCDetection(
|
457
|
-
BaseVOCDataset[
|
458
|
-
NDArray[np.number[Any]], ObjectDetectionTarget[NDArray[np.number[Any]]]
|
459
|
-
],
|
435
|
+
BaseVOCDataset[NDArray[np.number[Any]], ObjectDetectionTarget[NDArray[np.number[Any]]]],
|
460
436
|
BaseODDataset[NDArray[np.number[Any]], list[str], str],
|
461
437
|
BaseDatasetNumpyMixin,
|
462
438
|
):
|
@@ -1 +0,0 @@
|
|
1
|
-
"""Module for MAITE compliant Computer Vision datasets."""
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/image_classification/__init__.py
RENAMED
File without changes
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/__init__.py
RENAMED
File without changes
|
{maite_datasets-0.0.1 → maite_datasets-0.0.3}/src/maite_datasets/object_detection/_voc_torch.py
RENAMED
File without changes
|
File without changes
|