maite-datasets 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1 +1,14 @@
1
1
  """Module for MAITE compliant Computer Vision datasets."""
2
+
3
+ from maite_datasets._builder import to_image_classification_dataset, to_object_detection_dataset
4
+ from maite_datasets._collate import collate_as_torch, collate_as_numpy, collate_as_list
5
+ from maite_datasets._validate import validate_dataset
6
+
7
+ __all__ = [
8
+ "collate_as_list",
9
+ "collate_as_numpy",
10
+ "collate_as_torch",
11
+ "to_image_classification_dataset",
12
+ "to_object_detection_dataset",
13
+ "validate_dataset",
14
+ ]
maite_datasets/_base.py CHANGED
@@ -76,13 +76,9 @@ class BaseDataset(
76
76
  download: bool = False,
77
77
  verbose: bool = False,
78
78
  ) -> None:
79
- self._root: Path = (
80
- root.absolute() if isinstance(root, Path) else Path(root).absolute()
81
- )
79
+ self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
82
80
  transforms = transforms if transforms is not None else []
83
- self.transforms: Sequence[Transform[_TArray]] = (
84
- transforms if isinstance(transforms, Sequence) else [transforms]
85
- )
81
+ self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
86
82
  self.image_set = image_set
87
83
  self._verbose = verbose
88
84
 
@@ -109,11 +105,7 @@ class BaseDataset(
109
105
  nt = "\n "
110
106
  title = f"{self.__class__.__name__} Dataset"
111
107
  sep = "-" * len(title)
112
- attrs = [
113
- f"{k.capitalize()}: {v}"
114
- for k, v in self.__dict__.items()
115
- if not k.startswith("_")
116
- ]
108
+ attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
117
109
  return f"{title}\n{sep}{nt}{nt.join(attrs)}"
118
110
 
119
111
  @property
@@ -149,9 +141,7 @@ class BaseDataset(
149
141
  if self._verbose:
150
142
  print("No download needed, loaded data successfully.")
151
143
  except FileNotFoundError:
152
- _ensure_exists(
153
- *self._resource, self.path, self._root, self._download, self._verbose
154
- )
144
+ _ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
155
145
  result = self._load_data_inner()
156
146
  return result
157
147
 
@@ -212,9 +202,7 @@ class BaseODDataset(
212
202
 
213
203
  _bboxes_per_size: bool = False
214
204
 
215
- def __getitem__(
216
- self, index: int
217
- ) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
205
+ def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
218
206
  """
219
207
  Args
220
208
  ----
@@ -235,13 +223,9 @@ class BaseODDataset(
235
223
  img = self._transform(img)
236
224
  # Adjust labels if necessary
237
225
  if self._bboxes_per_size and boxes:
238
- boxes = boxes * np.array(
239
- [[img_size[1], img_size[2], img_size[1], img_size[2]]]
240
- )
226
+ boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
241
227
  # Create the Object Detection Target
242
- target = ObjectDetectionTarget(
243
- self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels)
244
- )
228
+ target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
245
229
 
246
230
  img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
247
231
  img_metadata = img_metadata | additional_metadata
@@ -249,6 +233,4 @@ class BaseODDataset(
249
233
  return img, target, _to_datum_metadata(index, img_metadata)
250
234
 
251
235
  @abstractmethod
252
- def _read_annotations(
253
- self, annotation: _TAnnotation
254
- ) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
236
+ def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
@@ -0,0 +1,275 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ __all__ = []
6
+
7
+ from typing import (
8
+ Any,
9
+ Generic,
10
+ Iterable,
11
+ Literal,
12
+ Sequence,
13
+ SupportsFloat,
14
+ SupportsInt,
15
+ TypeVar,
16
+ cast,
17
+ )
18
+
19
+ from maite_datasets._protocols import (
20
+ Array,
21
+ ArrayLike,
22
+ DatasetMetadata,
23
+ ImageClassificationDataset,
24
+ ObjectDetectionDataset,
25
+ )
26
+
27
+
28
+ def _ensure_id(index: int, metadata: dict[str, Any]) -> dict[str, Any]:
29
+ return {"id": index, **metadata} if "id" not in metadata else metadata
30
+
31
+
32
def _validate_data(
    datum_type: Literal["ic", "od"],
    images: Array | Sequence[Array],
    labels: Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]],
    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]] | None,
    metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
) -> None:
    """
    Check that raw dataset inputs are consistent with each other.

    Only the first element of ``images``, ``labels`` and ``bboxes`` is
    inspected for type and shape; lengths are checked for every input.

    Parameters
    ----------
    datum_type : "ic" or "od"
        Image classification or object detection.
    images : Array | Sequence[Array]
        Images; the first one must expose a 3-dimensional ``shape``.
    labels : Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]]
        One label ("ic") or one sequence of labels ("od") per image.
    bboxes : Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]] | None
        One sequence of 4-element boxes per image; required for "od".
    metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
        Either one dict per image, or a dict of per-image sequences.

    Raises
    ------
    ValueError
        If images are empty or not 3-dimensional, if any input length differs
        from the number of images, or if ``datum_type`` is unknown.
    TypeError
        If labels or boxes do not have the layout required by ``datum_type``.
    """
    # Check the container before indexing it, so that an empty or unsupported
    # input is reported as a ValueError instead of leaking an IndexError.
    # NOTE(review): the message says (H, W, C) while the dataset validator in
    # this package checks for CHW -- confirm which layout is intended.
    if not isinstance(images, (Sequence, Array)) or len(images) == 0 or len(images[0].shape) != 3:
        raise ValueError("Images must be a sequence or array of 3 dimensional arrays (H, W, C).")
    dataset_len = len(images)

    if len(labels) != dataset_len:
        raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
    if bboxes is not None and len(bboxes) != dataset_len:
        raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
    if metadata is not None:
        if isinstance(metadata, Sequence):
            if len(metadata) != dataset_len:
                raise ValueError(
                    f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len})."
                )
        else:
            # Column-style metadata: every column needs one entry per image.
            # Report the offending column rather than the number of columns.
            for key, column in metadata.items():
                if not isinstance(column, Sequence) or len(column) != dataset_len:
                    raise ValueError(
                        f"Metadata for key '{key}' must be a sequence with one entry per image ({dataset_len})."
                    )

    if datum_type == "ic":
        if not isinstance(labels, (Sequence, Array)) or not isinstance(labels[0], (int, SupportsInt)):
            raise TypeError("Labels must be a sequence of integers for image classification.")
    elif datum_type == "od":
        # NOTE(review): these checks index the first label/box of the first
        # image, so they assume that image has at least one object -- confirm.
        if (
            not isinstance(labels, (Sequence, Array))
            or not isinstance(labels[0], (Sequence, Array))
            or not isinstance(cast(Sequence[Any], labels[0])[0], (int, SupportsInt))
        ):
            raise TypeError("Labels must be a sequence of sequences of integers for object detection.")
        if (
            bboxes is None
            or not isinstance(bboxes, (Sequence, Array))
            or not isinstance(bboxes[0], (Sequence, Array))
            or not isinstance(bboxes[0][0], (Sequence, Array))
            or not isinstance(bboxes[0][0][0], (float, SupportsFloat))
            or not len(bboxes[0][0]) == 4
        ):
            raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
    else:
        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
78
+
79
+
80
+ def _listify_metadata(
81
+ metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
82
+ ) -> Sequence[dict[str, Any]] | None:
83
+ if isinstance(metadata, dict):
84
+ return [{k: v[i] for k, v in metadata.items()} for i in range(len(next(iter(metadata.values()))))]
85
+ return metadata
86
+
87
+
88
def _find_max(arr: ArrayLike) -> Any:
    """
    Recursively find the largest leaf value in a nested iterable.

    Strings and bytes are treated as leaves and returned unchanged, as is any
    value that is not an Iterable, Sequence or Array.

    Returns
    -------
    Any
        The maximum leaf value, or None when the iterable yields no non-None
        leaves.
    """
    if isinstance(arr, (bytes, str)) or not isinstance(arr, (Iterable, Sequence, Array)):
        return arr

    candidates = []
    for element in arr:
        best = _find_max(element)
        if best is not None:
            candidates.append(best)
    if not candidates:
        return None
    return max(candidates)
93
+
94
+
95
+ _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
96
+
97
+
98
class BaseAnnotatedDataset(Generic[_TLabels]):
    """
    Shared storage and dataset-level metadata for the custom datasets.

    Parameters
    ----------
    datum_type : "ic" or "od"
        Used only to build the default dataset id.
    images : Array | Sequence[Array]
        Images backing the dataset; its length is the dataset length.
    labels : Sequence[int] | Sequence[Sequence[int]]
        Labels for each image.
    metadata : Sequence[dict[str, Any]] | None
        Optional per-datum metadata.
    classes : Sequence[str] | None
        Class names; when None, names "0" through the largest label are used.
    name : str | None, default None
        Dataset id; when falsy an id is generated from the image count, the
        class count and ``datum_type``.
    """

    def __init__(
        self,
        datum_type: Literal["ic", "od"],
        images: Array | Sequence[Array],
        labels: _TLabels,
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        if classes is None:
            # No names supplied: derive them from the largest label value.
            classes = [str(i) for i in range(_find_max(labels) + 1)]
        self._classes = classes
        self._index2label = dict(enumerate(self._classes))
        self._images = images
        self._labels = labels
        self._metadata = metadata
        if name:
            self._id = name
        else:
            self._id = f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"

    @property
    def metadata(self) -> DatasetMetadata:
        """Dataset-level metadata holding the id and the index-to-label map."""
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __len__(self) -> int:
        """Number of images in the dataset."""
        return len(self._images)
121
+
122
+
123
class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
    """
    Image classification dataset built from in-memory images and labels.

    Each item is ``(image, one-hot target, metadata)``, where the metadata
    always contains an ``"id"`` key.
    """

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Array | Sequence[int],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        # Array labels are converted to plain Python lists before storage.
        label_list = np.asarray(labels).tolist() if isinstance(labels, Array) else labels
        # NOTE(review): ``name`` is not forwarded to the base class, so the
        # dataset id is always the generated one -- confirm this is intended.
        super().__init__("ic", images, label_list, metadata, classes)
        if name is None:
            return
        # Renaming goes through ``self.__class__``, i.e. it changes the class
        # object and not only this instance.
        self.__name__ = name
        self.__class__.__name__ = name
        self.__class__.__qualname__ = name

    def __getitem__(self, idx: int, /) -> tuple[Array, Array, dict[str, Any]]:
        """Return the image, its one-hot encoded label and its metadata."""
        scores = [0.0] * len(self._index2label)
        scores[self._labels[idx]] = 1.0
        image = self._images[idx]
        target = np.asarray(scores)
        datum_metadata = {} if self._metadata is None else self._metadata[idx]
        return image, target, _ensure_id(idx, datum_metadata)
152
+
153
+
154
+ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
155
class ObjectDetectionTarget:
    """
    Labels, boxes and one-hot scores for the objects of a single image.

    Parameters
    ----------
    labels : Sequence[int]
        Class index of each object.
    bboxes : Sequence[Sequence[float]]
        Bounding box of each object.
    class_count : int
        Width of each one-hot score row.
    """

    def __init__(
        self,
        labels: Sequence[int],
        bboxes: Sequence[Sequence[float]],
        class_count: int,
    ) -> None:
        self._labels = labels
        self._bboxes = bboxes
        # Build an independent row per object. ``[[0.0] * n] * m`` would repeat
        # a reference to ONE row, so every row would receive every label's 1.0.
        one_hot = [[0.0] * class_count for _ in range(len(labels))]
        for i, label in enumerate(labels):
            one_hot[i][label] = 1.0
        self._scores = one_hot

    @property
    def labels(self) -> Sequence[int]:
        """Class index of each object."""
        return self._labels

    @property
    def boxes(self) -> Sequence[Sequence[float]]:
        """Bounding box of each object."""
        return self._bboxes

    @property
    def scores(self) -> Sequence[Sequence[float]]:
        """One-hot score row of each object."""
        return self._scores
180
+
181
def __init__(
    self,
    images: Array | Sequence[Array],
    labels: Array | Sequence[Array] | Sequence[Sequence[int]],
    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
    metadata: Sequence[dict[str, Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> None:
    """
    Store images, per-image labels and per-image boxes as plain Python lists.

    Array-typed label entries and boxes are converted with ``tolist()``;
    anything else is stored as given.
    """
    per_image_labels = [np.asarray(label).tolist() if isinstance(label, Array) else label for label in labels]
    # NOTE(review): ``name`` is not forwarded to the base class, so the dataset
    # id is always the generated one -- confirm this is intended.
    super().__init__("od", images, per_image_labels, metadata, classes)
    if name is not None:
        # Renaming goes through ``self.__class__``, i.e. it changes the class
        # object and not only this instance.
        self.__name__ = name
        self.__class__.__name__ = name
        self.__class__.__qualname__ = name
    per_image_boxes = []
    for image_boxes in bboxes:
        per_image_boxes.append([np.asarray(box).tolist() if isinstance(box, Array) else box for box in image_boxes])
    self._bboxes = per_image_boxes
204
+
205
@property
def metadata(self) -> DatasetMetadata:
    """Dataset-level metadata holding the id and the index-to-label map."""
    dataset_id = self._id
    label_names = self._index2label
    return DatasetMetadata(id=dataset_id, index2label=label_names)
208
+
209
def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
    """Return the image, its detection target and its metadata (always with an ``"id"`` key)."""
    image = self._images[idx]
    target = self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx], len(self._classes))
    datum_metadata = {} if self._metadata is None else self._metadata[idx]
    return image, target, _ensure_id(idx, datum_metadata)
215
+
216
+
217
def to_image_classification_dataset(
    images: Array | Sequence[Array],
    labels: Array | Sequence[int],
    metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ImageClassificationDataset:
    """
    Build an ImageClassificationDataset from in-memory data.

    The inputs are validated first; dict-of-sequences metadata is converted to
    one dict per image before the dataset is created.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Array | Sequence[int]
        The label of each image.
    metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
        The metadata to use in the dataset.
    classes : Sequence[str] | None
        The class names to use in the dataset.
    name : str | None, default None
        Optional name applied to the class of the returned dataset.

    Returns
    -------
    ImageClassificationDataset
    """
    _validate_data("ic", images, labels, None, metadata)
    per_datum_metadata = _listify_metadata(metadata)
    return CustomImageClassificationDataset(images, labels, per_datum_metadata, classes, name)
244
+
245
+
246
def to_object_detection_dataset(
    images: Array | Sequence[Array],
    labels: Array | Sequence[Array] | Sequence[Sequence[int]],
    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
    metadata: Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ObjectDetectionDataset:
    """
    Build an ObjectDetectionDataset from in-memory data.

    The inputs are validated first; dict-of-sequences metadata is converted to
    one dict per image before the dataset is created.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Array | Sequence[Array] | Sequence[Sequence[int]]
        The labels of the objects in each image.
    bboxes : Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]]
        The bounding boxes (x0, y0, x1, y1) of the objects in each image.
    metadata : Sequence[dict[str, Any]] | dict[str, Sequence[Any]] | None
        The metadata to use in the dataset.
    classes : Sequence[str] | None
        The class names to use in the dataset.
    name : str | None, default None
        Optional name applied to the class of the returned dataset.

    Returns
    -------
    ObjectDetectionDataset
    """
    _validate_data("od", images, labels, bboxes, metadata)
    per_datum_metadata = _listify_metadata(metadata)
    return CustomObjectDetectionDataset(images, labels, bboxes, per_datum_metadata, classes, name)
@@ -0,0 +1,112 @@
1
+ """
2
+ Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = []
8
+
9
+ from collections.abc import Iterable, Sequence
10
+ from typing import Any, TypeVar, TYPE_CHECKING
11
+
12
+ import numpy as np
13
+ from numpy.typing import NDArray
14
+
15
+ if TYPE_CHECKING:
16
+ import torch
17
+
18
+ from maite_datasets._protocols import ArrayLike
19
+
20
+ T_in = TypeVar("T_in")
21
+ T_tgt = TypeVar("T_tgt")
22
+ T_md = TypeVar("T_md")
23
+
24
+
25
def collate_as_list(
    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
    """
    Split a batch of (input, target, metadata) tuples into three parallel lists.

    Useful as a ``collate_fn`` for torch.utils.data.DataLoader when targets and
    metadata are not tensors.

    Parameters
    ----------
    batch_data_as_singles : An iterable of (input, target, metadata) tuples.

    Returns
    -------
    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
        The input batch, the target batch and the metadata batch, in that order.
    """
    # Unpacking in the comprehension still rejects items that are not 3-tuples.
    samples = [(datum, target, metadata) for datum, target, metadata in batch_data_as_singles]
    inputs = [sample[0] for sample in samples]
    targets = [sample[1] for sample in samples]
    metadatas = [sample[2] for sample in samples]
    return inputs, targets, metadatas
52
+
53
+
54
def collate_as_numpy(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
    """
    Stack the inputs of a batch into one NumPy array and list the rest.

    Each input is converted with ``np.asarray`` and the results are stacked
    along a new first axis, so the inputs must be homogeneous arrays. An empty
    batch yields an empty array.

    Parameters
    ----------
    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.

    Returns
    -------
    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
        The stacked input array, the target list and the metadata list.
    """
    samples = [(np.asarray(datum), target, metadata) for datum, target, metadata in batch_data_as_singles]
    arrays = [sample[0] for sample in samples]
    targets = [sample[1] for sample in samples]
    metadatas = [sample[2] for sample in samples]
    if arrays:
        stacked = np.stack(arrays)
    else:
        stacked = np.array([])
    return stacked, targets, metadatas
80
+
81
+
82
def collate_as_torch(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
    """
    Stack the inputs of a batch into one torch Tensor and list the rest.

    Each input is converted with ``torch.as_tensor`` and the results are
    stacked along a new first axis, so the inputs must be homogeneous arrays.
    An empty batch yields an empty tensor.

    Parameters
    ----------
    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.

    Returns
    -------
    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
        The stacked input tensor, the target list and the metadata list.

    Raises
    ------
    ImportError
        If PyTorch is not installed.
    """
    # torch is imported lazily so the module can be used without it.
    try:
        import torch
    except ImportError as err:
        # Chain explicitly so the original import failure stays attached as the cause.
        raise ImportError("PyTorch is not installed. Please install it to use this function.") from err

    input_batch: list[torch.Tensor] = []
    target_batch: list[T_tgt] = []
    metadata_batch: list[T_md] = []
    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
        input_batch.append(torch.as_tensor(input_datum))
        target_batch.append(target_datum)
        metadata_batch.append(metadata_datum)

    stacked = torch.stack(input_batch) if input_batch else torch.tensor([])
    return stacked, target_batch, metadata_batch
maite_datasets/_fileio.py CHANGED
@@ -23,9 +23,7 @@ def _print(text: str, verbose: bool) -> None:
23
23
  print(text)
24
24
 
25
25
 
26
- def _validate_file(
27
- fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535
28
- ) -> bool:
26
+ def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
29
27
  hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
30
28
  with open(fpath, "rb") as fpath_file:
31
29
  while chunk := fpath_file.read(chunk_size):
@@ -33,28 +31,20 @@ def _validate_file(
33
31
  return hasher.hexdigest() == file_md5
34
32
 
35
33
 
36
- def _download_dataset(
37
- url: str, file_path: Path, timeout: int = 60, verbose: bool = False
38
- ) -> None:
34
+ def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
39
35
  """Download a single resource from its URL to the `data_folder`."""
40
36
  error_msg = "URL fetch failure on {}: {} -- {}"
41
37
  try:
42
38
  response = requests.get(url, stream=True, timeout=timeout)
43
39
  response.raise_for_status()
44
40
  except requests.exceptions.HTTPError as e:
45
- raise RuntimeError(
46
- f"{error_msg.format(url, e.response.status_code, e.response.reason)}"
47
- ) from e
41
+ raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
48
42
  except requests.exceptions.RequestException as e:
49
43
  raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
50
44
 
51
45
  total_size = int(response.headers.get("content-length", 0))
52
46
  block_size = 8192 # 8 KB
53
- progress_bar = (
54
- None
55
- if tqdm is None
56
- else tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
57
- )
47
+ progress_bar = None if tqdm is None else tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
58
48
 
59
49
  with open(file_path, "wb") as f:
60
50
  for chunk in response.iter_content(block_size):
@@ -72,9 +62,7 @@ def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
72
62
  zip_ref.extractall(extract_to) # noqa: S202
73
63
  file_path.unlink()
74
64
  except zipfile.BadZipFile:
75
- raise FileNotFoundError(
76
- f"{file_path.name} is not a valid zip file, skipping extraction."
77
- )
65
+ raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
78
66
 
79
67
 
80
68
  def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
@@ -84,9 +72,7 @@ def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
84
72
  tar_ref.extractall(extract_to) # noqa: S202
85
73
  file_path.unlink()
86
74
  except tarfile.TarError:
87
- raise FileNotFoundError(
88
- f"{file_path.name} is not a valid tar file, skipping extraction."
89
- )
75
+ raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
90
76
 
91
77
 
92
78
  def _extract_archive(
@@ -135,11 +121,7 @@ def _ensure_exists(
135
121
  file_ext = file_path.suffixes[0]
136
122
  compression = True
137
123
 
138
- check_path = (
139
- alternate_path
140
- if alternate_path.exists() and not file_path.exists()
141
- else file_path
142
- )
124
+ check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path
143
125
 
144
126
  # Download file if it doesn't exist.
145
127
  if not check_path.exists() and download:
@@ -147,9 +129,7 @@ def _ensure_exists(
147
129
  _download_dataset(url, check_path, verbose=verbose)
148
130
 
149
131
  if not _validate_file(check_path, checksum, md5):
150
- raise Exception(
151
- "File checksum mismatch. Remove current file and retry download."
152
- )
132
+ raise Exception("File checksum mismatch. Remove current file and retry download.")
153
133
 
154
134
  # If the file is a zip, tar or tgz extract it into the designated folder.
155
135
  if file_ext in ARCHIVE_ENDINGS:
@@ -164,9 +144,7 @@ def _ensure_exists(
164
144
  )
165
145
  else:
166
146
  if not _validate_file(check_path, checksum, md5):
167
- raise Exception(
168
- "File checksum mismatch. Remove current file and retry download."
169
- )
147
+ raise Exception("File checksum mismatch. Remove current file and retry download.")
170
148
  _print(f"{filename} already exists, skipping download.", verbose)
171
149
 
172
150
  if file_ext in ARCHIVE_ENDINGS:
@@ -174,9 +174,7 @@ class ObjectDetectionTarget(Protocol):
174
174
  def scores(self) -> ArrayLike: ...
175
175
 
176
176
 
177
- ObjectDetectionDatum: TypeAlias = tuple[
178
- ArrayLike, ObjectDetectionTarget, Mapping[str, Any]
179
- ]
177
+ ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, Mapping[str, Any]]
180
178
  """
181
179
  Type alias for an object detection datum tuple.
182
180
 
maite_datasets/_types.py CHANGED
@@ -37,9 +37,7 @@ class AnnotatedDataset(Dataset[_TDatum]):
37
37
  def __len__(self) -> int: ...
38
38
 
39
39
 
40
- class ImageClassificationDataset(
41
- AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]
42
- ): ...
40
+ class ImageClassificationDataset(AnnotatedDataset[tuple[_TArray, _TArray, DatumMetadata]]): ...
43
41
 
44
42
 
45
43
  @dataclass
@@ -49,6 +47,4 @@ class ObjectDetectionTarget(Generic[_TArray]):
49
47
  scores: _TArray
50
48
 
51
49
 
52
- class ObjectDetectionDataset(
53
- AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]
54
- ): ...
50
+ class ObjectDetectionDataset(AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]]): ...
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ import numpy as np
6
+ from collections.abc import Sequence, Sized
7
+ from typing import Any, Literal
8
+
9
+ from maite_datasets._protocols import Array, ObjectDetectionTarget
10
+
11
+
12
class ValidationMessages:
    """Human-readable issue messages reported by the validators in this module."""

    # Dataset-level issues
    DATASET_SIZED = "Dataset must be sized."
    DATASET_INDEXABLE = "Dataset must be indexable."
    DATASET_NONEMPTY = "Dataset must be non-empty."
    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
    # Datum-level issues
    DATUM_TYPE = "Dataset datum must be a tuple."
    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
    # Wording fixed: previously read "must be have".
    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must have 'boxes', 'labels' and 'scores'."
    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
    # Boxes are (x0, y0, x1, y1), i.e. "xyxy"; the message previously said "xxyy".
    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xyxy format."
    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
32
+
33
+
34
def _validate_dataset_type(dataset: Any) -> list[str]:
    """
    Check that the dataset is sized, indexable and non-empty.

    Returns
    -------
    list[str]
        One message per failed check; empty when all checks pass.
    """
    found: list[str] = []
    has_len = isinstance(dataset, Sized)
    if not has_len:
        found.append(ValidationMessages.DATASET_SIZED)
    if not hasattr(dataset, "__getitem__"):
        found.append(ValidationMessages.DATASET_INDEXABLE)
    # Emptiness can only be judged when the dataset has a length.
    if has_len and len(dataset) == 0:
        found.append(ValidationMessages.DATASET_NONEMPTY)
    return found
45
+
46
+
47
def _validate_dataset_metadata(dataset: Any) -> list[str]:
    """
    Check that the dataset exposes ``metadata`` as a dict with an ``"id"`` key.

    Returns
    -------
    list[str]
        One message per failed check; a missing or non-dict ``metadata`` also
        reports the format issue.
    """
    found: list[str] = []
    if not hasattr(dataset, "metadata"):
        found.append(ValidationMessages.DATASET_METADATA)
    dataset_metadata = getattr(dataset, "metadata", None)
    if isinstance(dataset_metadata, dict):
        if "id" not in dataset_metadata:
            found.append(ValidationMessages.DATASET_METADATA_FORMAT)
    else:
        # Not a dict: both the type and the format requirement are unmet.
        found.append(ValidationMessages.DATASET_METADATA_TYPE)
        found.append(ValidationMessages.DATASET_METADATA_FORMAT)
    return found
57
+
58
+
59
def _validate_datum_type(datum: Any) -> list[str]:
    """
    Check that a datum is a tuple holding exactly three elements.

    Returns
    -------
    list[str]
        The type issue when ``datum`` is not a tuple, and the format issue
        when it is None or is sized with a length other than 3.
    """
    found: list[str] = []
    if not isinstance(datum, tuple):
        found.append(ValidationMessages.DATUM_TYPE)
    wrong_length = isinstance(datum, Sized) and len(datum) != 3
    if datum is None or wrong_length:
        found.append(ValidationMessages.DATUM_FORMAT)
    return found
66
+
67
+
68
def _validate_datum_image(image: Any) -> list[str]:
    """
    Check that an image is a 3-dimensional Array whose first axis is the smallest.

    Returns
    -------
    list[str]
        The type issue when ``image`` is not a 3-dimensional Array; the format
        issue when it is not an Array, or is 3-dimensional with a first axis
        larger than either of the other two.
    """
    if not isinstance(image, Array):
        # Nothing can be inspected, so both requirements are reported.
        return [ValidationMessages.DATUM_IMAGE_TYPE, ValidationMessages.DATUM_IMAGE_FORMAT]

    found: list[str] = []
    if len(image.shape) != 3:
        found.append(ValidationMessages.DATUM_IMAGE_TYPE)
    if len(image.shape) == 3 and (image.shape[0] > image.shape[1] or image.shape[0] > image.shape[2]):
        found.append(ValidationMessages.DATUM_IMAGE_FORMAT)
    return found
79
+
80
+
81
def _validate_datum_target_ic(target: Any) -> list[str]:
    """
    Check that an image classification target is a 1-D Array summing to 1.

    Returns
    -------
    list[str]
        The type issue when ``target`` is not a one-dimensional Array; the
        format issue when it is None, when its elements do not sum to 1
        (within 1e-6), or when that sum cannot be evaluated at all.
    """
    issues = []
    if not isinstance(target, Array) or len(target.shape) != 1:
        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)

    if target is None:
        bad_format = True
    else:
        try:
            total = sum(target)
            bad_format = bool(total > 1 + 1e-6 or total < 1 - 1e-6)
        except (TypeError, ValueError, RuntimeError):
            # Non-iterable or non-numeric targets make sum() fail, and a
            # multi-dimensional array makes the comparison ambiguous. Report
            # the issue instead of letting the validator crash.
            bad_format = True
    if bad_format:
        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
    return issues
88
+
89
+
90
def _validate_datum_target_od(target: Any) -> list[str]:
    """
    Check the shapes of an object detection target's labels, boxes and scores.

    Returns
    -------
    list[str]
        All four object detection issues when ``target`` is not an
        ObjectDetectionTarget; otherwise one issue per attribute whose array
        shape is not (N,) for labels, (N, 4) for boxes, or 1-D/2-D for scores.
    """
    if not isinstance(target, ObjectDetectionTarget):
        # Without the expected attributes every attribute check fails too.
        return [
            ValidationMessages.DATUM_TARGET_OD_TYPE,
            ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE,
            ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE,
            ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE,
        ]

    found: list[str] = []
    if len(np.asarray(target.labels).shape) != 1:
        found.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
    boxes_shape = np.asarray(target.boxes).shape
    if len(boxes_shape) != 2 or boxes_shape[1] != 4:
        found.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
    if len(np.asarray(target.scores).shape) not in (1, 2):
        found.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
    return found
106
+
107
+
108
def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
    """
    Infer the target type from the target's own type.

    Returns
    -------
    "ic", "od" or "auto"
        "ic" for an Array, "od" for an ObjectDetectionTarget, and "auto" when
        the target is neither. The Array check takes precedence.
    """
    detected: Literal["ic", "od", "auto"] = "auto"
    if isinstance(target, Array):
        detected = "ic"
    elif isinstance(target, ObjectDetectionTarget):
        detected = "od"
    return detected
114
+
115
+
116
def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
    """
    Validate a target as image classification or object detection.

    Parameters
    ----------
    target : Any
        The target element of a datum.
    target_type : "ic", "od" or "auto"
        With "auto" the type is detected from ``target``.

    Returns
    -------
    list[str]
        Issues from the matching validator, or a single issue when the type
        could not be determined.
    """
    resolved = target_type if target_type != "auto" else _detect_target_type(target)
    if resolved == "ic":
        return [*_validate_datum_target_ic(target)]
    if resolved == "od":
        return [*_validate_datum_target_od(target)]
    return [ValidationMessages.DATUM_TARGET_TYPE]
126
+
127
+
128
def _validate_datum_metadata(metadata: Any) -> list[str]:
    """
    Check that datum metadata is a dict containing an ``"id"`` key.

    Returns
    -------
    list[str]
        The type issue when ``metadata`` is not a dict (None included), and
        the format issue when it is None or is a dict without ``"id"``.
    """
    found: list[str] = []
    is_dict = isinstance(metadata, dict)
    if not is_dict:
        found.append(ValidationMessages.DATUM_METADATA_TYPE)
    if metadata is None or (is_dict and "id" not in metadata):
        found.append(ValidationMessages.DATUM_METADATA_FORMAT)
    return found
135
+
136
+
137
def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
    """
    Validate a dataset for compliance with MAITE protocol.

    The dataset itself and its first datum (image, target and metadata) are
    checked; all issues found are collected and reported together.

    Parameters
    ----------
    dataset: Any
        Dataset to validate.
    dataset_type: "ic", "od", or "auto", default "auto"
        Dataset type, if known.

    Raises
    ------
    ValueError
        Raises exception if dataset is invalid with a list of validation issues.
    """
    found: list[str] = list(_validate_dataset_type(dataset))
    # The first datum is only fetched when the dataset passed the basic checks.
    datum = dataset[0] if not found else None  # type: ignore
    found += _validate_dataset_metadata(dataset)
    found += _validate_datum_type(datum)

    # Elements missing from the datum are validated as None.
    image = target = metadata = None
    if isinstance(datum, Sequence):
        element_count = len(datum)
        if element_count > 0:
            image = datum[0]
        if element_count > 1:
            target = datum[1]
        if element_count > 2:
            metadata = datum[2]
    found += _validate_datum_image(image)
    found += _validate_datum_target(target, dataset_type)
    found += _validate_datum_metadata(metadata)

    if found:
        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(found))
@@ -24,9 +24,7 @@ CIFARClassStringMap = Literal[
24
24
  "ship",
25
25
  "truck",
26
26
  ]
27
- TCIFARClassMap = TypeVar(
28
- "TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int]
29
- )
27
+ TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
30
28
 
31
29
 
32
30
  class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
@@ -91,9 +89,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
91
89
  self,
92
90
  root: str | Path,
93
91
  image_set: Literal["train", "test", "base"] = "train",
94
- transforms: Transform[NDArray[np.number[Any]]]
95
- | Sequence[Transform[NDArray[np.number[Any]]]]
96
- | None = None,
92
+ transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
97
93
  download: bool = False,
98
94
  verbose: bool = False,
99
95
  ) -> None:
@@ -105,9 +101,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
105
101
  verbose,
106
102
  )
107
103
 
108
- def _load_bin_data(
109
- self, data_folder: list[Path]
110
- ) -> tuple[list[str], list[int], dict[str, Any]]:
104
+ def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
111
105
  batch_nums = np.zeros(60000, dtype=np.uint8)
112
106
  all_labels = np.zeros(60000, dtype=np.uint8)
113
107
  all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
@@ -115,9 +109,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
115
109
  for batch_file in data_folder:
116
110
  # Get batch parameters
117
111
  batch_type = "test" if "test" in batch_file.stem else "train"
118
- batch_num = (
119
- 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
120
- )
112
+ batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
121
113
 
122
114
  # Load data
123
115
  batch_images, batch_labels = self._unpack_batch_files(batch_file)
@@ -193,9 +185,7 @@ class CIFAR10(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
193
185
  {"batch_num": batch_nums.tolist()},
194
186
  )
195
187
 
196
- def _unpack_batch_files(
197
- self, file_path: Path
198
- ) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
188
+ def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
199
189
  # Load pickle data with latin1 encoding
200
190
  with file_path.open("rb") as f:
201
191
  buffer = np.frombuffer(f.read(), dtype=np.uint8)
@@ -12,12 +12,8 @@ from maite_datasets._base import BaseICDataset, DataLocation
12
12
  from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
13
13
  from maite_datasets._protocols import Transform
14
14
 
15
- MNISTClassStringMap = Literal[
16
- "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"
17
- ]
18
- TMNISTClassMap = TypeVar(
19
- "TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int]
20
- )
15
+ MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
16
+ TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])
21
17
  CorruptionStringMap = Literal[
22
18
  "identity",
23
19
  "shot_noise",
@@ -122,9 +118,7 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
122
118
  root: str | Path,
123
119
  image_set: Literal["train", "test", "base"] = "train",
124
120
  corruption: CorruptionStringMap | None = None,
125
- transforms: Transform[NDArray[np.number[Any]]]
126
- | Sequence[Transform[NDArray[np.number[Any]]]]
127
- | None = None,
121
+ transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
128
122
  download: bool = False,
129
123
  verbose: bool = False,
130
124
  ) -> None:
@@ -182,18 +176,12 @@ class MNIST(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
182
176
 
183
177
  return data, labels
184
178
 
185
- def _grab_data(
186
- self, path: Path
187
- ) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
179
+ def _grab_data(self, path: Path) -> tuple[NDArray[np.number[Any]], NDArray[np.uintp]]:
188
180
  """Function to load in the data numpy array"""
189
181
  with np.load(path, allow_pickle=True) as data_array:
190
182
  if self.image_set == "base":
191
- data = np.concatenate(
192
- [data_array["x_train"], data_array["x_test"]], axis=0
193
- )
194
- labels = np.concatenate(
195
- [data_array["y_train"], data_array["y_test"]], axis=0
196
- ).astype(np.uintp)
183
+ data = np.concatenate([data_array["x_train"], data_array["x_test"]], axis=0)
184
+ labels = np.concatenate([data_array["y_train"], data_array["y_test"]], axis=0).astype(np.uintp)
197
185
  else:
198
186
  data, labels = (
199
187
  data_array[f"x_{self.image_set}"],
@@ -76,9 +76,7 @@ class Ships(BaseICDataset[NDArray[np.number[Any]]], BaseDatasetNumpyMixin):
76
76
  def __init__(
77
77
  self,
78
78
  root: str | Path,
79
- transforms: Transform[NDArray[np.number[Any]]]
80
- | Sequence[Transform[NDArray[np.number[Any]]]]
81
- | None = None,
79
+ transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
82
80
  download: bool = False,
83
81
  verbose: bool = False,
84
82
  ) -> None:
@@ -14,9 +14,7 @@ from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
14
14
  from maite_datasets._protocols import Transform
15
15
 
16
16
 
17
- class AntiUAVDetection(
18
- BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin
19
- ):
17
+ class AntiUAVDetection(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
20
18
  """
21
19
  A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
22
20
 
@@ -103,9 +101,7 @@ class AntiUAVDetection(
103
101
  self,
104
102
  root: str | Path,
105
103
  image_set: Literal["train", "val", "test", "base"] = "train",
106
- transforms: Transform[NDArray[np.number[Any]]]
107
- | Sequence[Transform[NDArray[np.number[Any]]]]
108
- | None = None,
104
+ transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
109
105
  download: bool = False,
110
106
  verbose: bool = False,
111
107
  ) -> None:
@@ -128,9 +124,7 @@ class AntiUAVDetection(
128
124
 
129
125
  for resource in self._resources:
130
126
  self._resource = resource
131
- resource_filepaths, resource_targets, resource_metadata = (
132
- super()._load_data()
133
- )
127
+ resource_filepaths, resource_targets, resource_metadata = super()._load_data()
134
128
  filepaths.extend(resource_filepaths)
135
129
  targets.extend(resource_targets)
136
130
  metadata_list.append(resource_metadata)
@@ -148,9 +142,7 @@ class AntiUAVDetection(
148
142
  for resource in self._resources:
149
143
  if self.image_set in resource.filename:
150
144
  self._resource = resource
151
- resource_filepaths, resource_targets, resource_metadata = (
152
- super()._load_data()
153
- )
145
+ resource_filepaths, resource_targets, resource_metadata = super()._load_data()
154
146
  filepaths.extend(resource_filepaths)
155
147
  targets.extend(resource_targets)
156
148
  datum_metadata.update(resource_metadata)
@@ -164,17 +156,13 @@ class AntiUAVDetection(
164
156
  if not data_folder:
165
157
  raise FileNotFoundError
166
158
 
167
- file_data = {
168
- "image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]
169
- }
159
+ file_data = {"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]}
170
160
  data = [str(entry) for entry in data_folder]
171
161
  annotations = sorted(str(entry) for entry in (base_dir / "xml").glob("*.xml"))
172
162
 
173
163
  return data, annotations, file_data
174
164
 
175
- def _read_annotations(
176
- self, annotation: str
177
- ) -> tuple[list[list[float]], list[int], dict[str, Any]]:
165
+ def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
178
166
  """Function for extracting the info for the label and boxes"""
179
167
  boxes: list[list[float]] = []
180
168
  labels = []
@@ -13,9 +13,7 @@ from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin
13
13
  from maite_datasets._protocols import Transform
14
14
 
15
15
 
16
- class MILCO(
17
- BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin
18
- ):
16
+ class MILCO(BaseODDataset[NDArray[np.number[Any]], list[str], str], BaseDatasetNumpyMixin):
19
17
  """
20
18
  A side-scan sonar dataset focused on mine-like object detection.
21
19
 
@@ -118,9 +116,7 @@ class MILCO(
118
116
  self,
119
117
  root: str | Path,
120
118
  image_set: Literal["train", "operational", "base"] = "train",
121
- transforms: Transform[NDArray[np.number[Any]]]
122
- | Sequence[Transform[NDArray[np.number[Any]]]]
123
- | None = None,
119
+ transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
124
120
  download: bool = False,
125
121
  verbose: bool = False,
126
122
  ) -> None:
@@ -180,9 +176,7 @@ class MILCO(
180
176
 
181
177
  return data, annotations, file_data
182
178
 
183
- def _read_annotations(
184
- self, annotation: str
185
- ) -> tuple[list[list[float]], list[int], dict[str, Any]]:
179
+ def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
186
180
  """Function for extracting the info out of the text files"""
187
181
  labels: list[int] = []
188
182
  boxes: list[list[float]] = []
@@ -313,9 +313,7 @@ class SeaDrone(
313
313
  self,
314
314
  root: str | Path,
315
315
  image_set: Literal["train", "val", "test", "base"] = "train",
316
- transforms: Transform[NDArray[np.number[Any]]]
317
- | Sequence[Transform[NDArray[np.number[Any]]]]
318
- | None = None,
316
+ transforms: Transform[NDArray[np.number[Any]]] | Sequence[Transform[NDArray[np.number[Any]]]] | None = None,
319
317
  download: bool = False,
320
318
  verbose: bool = False,
321
319
  ) -> None:
@@ -365,9 +363,7 @@ class SeaDrone(
365
363
 
366
364
  def _load_data(
367
365
  self,
368
- ) -> tuple[
369
- list[str], list[tuple[list[int], list[list[float]]]], dict[str, list[Any]]
370
- ]:
366
+ ) -> tuple[list[str], list[tuple[list[int], list[list[float]]]], dict[str, list[Any]]]:
371
367
  image_sets: dict[str, list[int]] = {
372
368
  "train": list(range(20)),
373
369
  "val": list(range(20, 24)),
@@ -390,9 +386,7 @@ class SeaDrone(
390
386
 
391
387
  return filepaths, list(targets), datum_metadata
392
388
 
393
- def _load_images(
394
- self, data_folder: Path, file_data: dict[int, dict[str, Any]]
395
- ) -> dict[int, dict[str, Any]]:
389
+ def _load_images(self, data_folder: Path, file_data: dict[int, dict[str, Any]]) -> dict[int, dict[str, Any]]:
396
390
  for entry in data_folder.iterdir():
397
391
  if entry.is_file() and entry.suffix == ".jpg":
398
392
  if int(entry.stem) not in file_data:
@@ -441,14 +435,10 @@ class SeaDrone(
441
435
  current_file["storage"] = source.get("folder_name", "")
442
436
 
443
437
  # Handle non-standard file metadata
444
- current_file["date_time"] = (
445
- file_meta.get("date_time") or meta.get("date_time") or ""
446
- )
438
+ current_file["date_time"] = file_meta.get("date_time") or meta.get("date_time") or ""
447
439
  if "frame" in file_meta:
448
440
  frame = file_meta["frame"][:-4]
449
- current_file["frame"] = (
450
- int(frame.split("_")[-1]) if "IMG_" in frame else int(frame[3:])
451
- )
441
+ current_file["frame"] = int(frame.split("_")[-1]) if "IMG_" in frame else int(frame[3:])
452
442
  elif "frame_no" in source:
453
443
  current_file["frame"] = source["frame_no"]
454
444
  else:
@@ -456,9 +446,7 @@ class SeaDrone(
456
446
 
457
447
  # Grab additional metadata if available
458
448
  for output_key, (possible_keys, default) in mappings.items():
459
- current_file[output_key] = next(
460
- (meta.get(key) for key in possible_keys if key in meta), default
461
- )
449
+ current_file[output_key] = next((meta.get(key) for key in possible_keys if key in meta), default)
462
450
 
463
451
  # Retrieve the label and bounding box
464
452
  for annotation in result["annotations"]:
@@ -482,9 +470,7 @@ class SeaDrone(
482
470
 
483
471
  return file_data
484
472
 
485
- def _restructure_file_data(
486
- self, file_data: dict[int, dict[str, Any]]
487
- ) -> dict[str, list[Any]]:
473
+ def _restructure_file_data(self, file_data: dict[int, dict[str, Any]]) -> dict[str, list[Any]]:
488
474
  """Restructure file_data from dictionary of dictionaries to a dictionary of lists"""
489
475
  # Get the keys from the dictionary
490
476
  all_keys = set()
@@ -501,9 +487,7 @@ class SeaDrone(
501
487
  # Create the lists
502
488
  for file_id, file_dict in file_data.items():
503
489
  restructured_data["image_id"].append(file_id)
504
- restructured_data["label_box"].append(
505
- (file_dict.get("label", []), file_dict.get("box", []))
506
- )
490
+ restructured_data["label_box"].append((file_dict.get("label", []), file_dict.get("box", [])))
507
491
  for key in all_keys:
508
492
  restructured_data[key].append(file_dict.get(key, None))
509
493
 
@@ -528,12 +512,8 @@ class SeaDrone(
528
512
  json_name = folder
529
513
  if json_name == "test":
530
514
  json_name += "_nogt"
531
- annotation_file = (
532
- self.path / "annotations" / f"instances_{json_name}.json"
533
- )
534
- file_data = self._create_per_image_annotations(
535
- annotation_file, file_data
536
- )
515
+ annotation_file = self.path / "annotations" / f"instances_{json_name}.json"
516
+ file_data = self._create_per_image_annotations(annotation_file, file_data)
537
517
 
538
518
  meta_data = self._restructure_file_data(file_data)
539
519
  data = meta_data.pop("data_path")
@@ -45,9 +45,7 @@ VOCClassStringMap = Literal[
45
45
  "train",
46
46
  "tvmonitor",
47
47
  ]
48
- TVOCClassMap = TypeVar(
49
- "TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int]
50
- )
48
+ TVOCClassMap = TypeVar("TVOCClassMap", VOCClassStringMap, int, list[VOCClassStringMap], list[int])
51
49
 
52
50
 
53
51
  class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
@@ -170,13 +168,9 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
170
168
  base if base.stem == f"VOC{self.year}" else None,
171
169
  base / f"VOC{self.year}" if base.stem == "VOCdevkit" else None,
172
170
  base / "VOCdevkit" / f"VOC{self.year}",
173
- base / "TrainVal" / "VOCdevkit" / f"VOC{self.year}"
174
- if self.year == "2011"
175
- else None,
171
+ base / "TrainVal" / "VOCdevkit" / f"VOC{self.year}" if self.year == "2011" else None,
176
172
  dataset_dir / "VOCdevkit" / f"VOC{self.year}",
177
- dataset_dir / "TrainVal" / "VOCdevkit" / f"VOC{self.year}"
178
- if self.year == "2011"
179
- else None,
173
+ dataset_dir / "TrainVal" / "VOCdevkit" / f"VOC{self.year}" if self.year == "2011" else None,
180
174
  ]
181
175
 
182
176
  # Filter out None values and check each path
@@ -269,9 +263,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
269
263
 
270
264
  for img_set in ["test", "base"]:
271
265
  self.image_set = img_set
272
- resource_filepaths, resource_targets, resource_metadata = (
273
- self._load_data_inner()
274
- )
266
+ resource_filepaths, resource_targets, resource_metadata = self._load_data_inner()
275
267
  filepaths.extend(resource_filepaths)
276
268
  targets.extend(resource_targets)
277
269
  metadata_list.append(resource_metadata)
@@ -288,14 +280,10 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
288
280
  self._resource = self._resources[resource_idx[1]]
289
281
 
290
282
  if train_exists and not test_exists:
291
- _ensure_exists(
292
- *self._resource, tmp_path, self._root, self._download, self._verbose
293
- )
283
+ _ensure_exists(*self._resource, tmp_path, self._root, self._download, self._verbose)
294
284
  self._merge_voc_directories(tmp_path)
295
285
 
296
- resource_filepaths, resource_targets, resource_metadata = (
297
- self._load_try_and_update()
298
- )
286
+ resource_filepaths, resource_targets, resource_metadata = self._load_try_and_update()
299
287
  filepaths.extend(resource_filepaths)
300
288
  targets.extend(resource_targets)
301
289
  datum_metadata.update(resource_metadata)
@@ -341,9 +329,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
341
329
  if self._verbose:
342
330
  print("No download needed, loaded data successfully.")
343
331
  except FileNotFoundError:
344
- _ensure_exists(
345
- *self._resource, self.path, self._root, self._download, self._verbose
346
- )
332
+ _ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
347
333
  self._update_path()
348
334
  result = self._load_data_inner()
349
335
  return result
@@ -364,9 +350,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
364
350
  def _get_image_sets(self) -> dict[str, list[str]]:
365
351
  """Function to create the list of images in each image set"""
366
352
  image_folder = self.path / "JPEGImages"
367
- image_set_list = (
368
- ["train", "val", "trainval"] if self.image_set != "test" else ["test"]
369
- )
353
+ image_set_list = ["train", "val", "trainval"] if self.image_set != "test" else ["test"]
370
354
  image_sets = {}
371
355
  for image_set in image_set_list:
372
356
  text_file = self.path / "ImageSets" / "Main" / (image_set + ".txt")
@@ -408,9 +392,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
408
392
 
409
393
  return data, annotations, file_meta
410
394
 
411
- def _read_annotations(
412
- self, annotation: str
413
- ) -> tuple[list[list[float]], list[int], dict[str, Any]]:
395
+ def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
414
396
  boxes: list[list[float]] = []
415
397
  label_str = []
416
398
  if not Path(annotation).exists():
@@ -435,12 +417,8 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
435
417
  for obj in root.findall("object"):
436
418
  label_str.append(obj.findtext("name", default=""))
437
419
  additional_meta["pose"].append(obj.findtext("pose", default=""))
438
- additional_meta["truncated"].append(
439
- int(obj.findtext("truncated", default="-1"))
440
- )
441
- additional_meta["difficult"].append(
442
- int(obj.findtext("difficult", default="-1"))
443
- )
420
+ additional_meta["truncated"].append(int(obj.findtext("truncated", default="-1")))
421
+ additional_meta["difficult"].append(int(obj.findtext("difficult", default="-1")))
444
422
  boxes.append(
445
423
  [
446
424
  float(obj.findtext("bndbox/xmin", default="0")),
@@ -454,9 +432,7 @@ class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str], str]):
454
432
 
455
433
 
456
434
  class VOCDetection(
457
- BaseVOCDataset[
458
- NDArray[np.number[Any]], ObjectDetectionTarget[NDArray[np.number[Any]]]
459
- ],
435
+ BaseVOCDataset[NDArray[np.number[Any]], ObjectDetectionTarget[NDArray[np.number[Any]]]],
460
436
  BaseODDataset[NDArray[np.number[Any]], list[str], str],
461
437
  BaseDatasetNumpyMixin,
462
438
  ):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: maite-datasets
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: A collection of Image Classification and Object Detection task datasets conforming to the MAITE protocol.
5
5
  Author-email: Andrew Weng <andrew.weng@ariacoustics.com>, Ryan Wood <ryan.wood@ariacoustics.com>, Shaun Jullens <shaun.jullens@ariacoustics.com>
6
6
  License-Expression: MIT
@@ -0,0 +1,26 @@
1
+ maite_datasets/__init__.py,sha256=81LNxx03O7FzWNZQbIrSovDrdpO_x74WkLPKBJy91gU,483
2
+ maite_datasets/_base.py,sha256=BiWB_xvL4AtV0jxVjzpcZHuRTb52dTD0CQtu08DzoXA,8195
3
+ maite_datasets/_builder.py,sha256=URhRCedvuqsy88N4lzQrwI-uL1kS1_kavP9fS402sPw,10036
4
+ maite_datasets/_collate.py,sha256=-XuKeeMmOnSB0RgQbz8BjsoqQar9Tsf_qALZxijQ498,4063
5
+ maite_datasets/_fileio.py,sha256=7S-hF3xU60AdcsPsfYR7rjbeGZUlv3JjGEZhGJOxGYU,5622
6
+ maite_datasets/_protocols.py,sha256=uwnI2P-zJnpEHJ0eOJ7dO_7KehwHEtEqR4pYcJiEXNk,5312
7
+ maite_datasets/_types.py,sha256=S5DMyiUrkUjV9uM0ysKqxVoi7z5P7B3EPiLI4Fyq9Jc,1147
8
+ maite_datasets/_validate.py,sha256=sP-5lYXkmkiTadJcy_LtEMiZ0m82xR0yELoxWORrZDQ,6904
9
+ maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ maite_datasets/_mixin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ maite_datasets/_mixin/_numpy.py,sha256=GEuRyeprH-STh-_zktAp0Tg6NNyMdh1ThyhjW558NOo,860
12
+ maite_datasets/_mixin/_torch.py,sha256=pkN2vMNsDk_h5wnD5899zIHsPtEADbGfmRyI5CdGonI,827
13
+ maite_datasets/image_classification/__init__.py,sha256=pcZojkdsiMoLgY4mKjoQY6WyEwiGYHxNrAGpnvn3zsY,308
14
+ maite_datasets/image_classification/_cifar10.py,sha256=w7BPGZzUV1gXFoYRgxa6VOqKn1EgQi3x1rrA4nEUbeI,8470
15
+ maite_datasets/image_classification/_mnist.py,sha256=6xDWY4qbY1hlcUZKvVZeQMvYbF0vLtaVzOuQUKJkcJU,8248
16
+ maite_datasets/image_classification/_ships.py,sha256=_fkm4iu6xuvfRuivgIS8S3CYnQOgghi9Kc0Riz1Dr8g,5187
17
+ maite_datasets/object_detection/__init__.py,sha256=NE8apy2C0kTg_Ng_M15U21ZW66WC_LWezmdG8vk2WHM,590
18
+ maite_datasets/object_detection/_antiuav.py,sha256=2xFOOCT2aujkD6T9LHJfUd02zyTsoNlLZ_rxqztUBP0,8333
19
+ maite_datasets/object_detection/_milco.py,sha256=KEU4JFvCxfyMAb4RFMnxTMk_MggdEAV8y4LU-kjN3lE,7997
20
+ maite_datasets/object_detection/_seadrone.py,sha256=w_pSojLzgwdKrUSxaz8r7dPJVKGND6JSYl0S_BKOLH0,271282
21
+ maite_datasets/object_detection/_voc.py,sha256=VuokKaOzI1wSfgG5DC7ufMbRDlG-b6Se3hg4eQzNQbE,19731
22
+ maite_datasets/object_detection/_voc_torch.py,sha256=bjeawnNit7Llcf_cZY_9lcJYoUoAU-Wen6MMT-7QX3k,2917
23
+ maite_datasets-0.0.3.dist-info/METADATA,sha256=hoOvbKjGriS10siM8HsRvepA3nfi-QgUcrpjGsHr1lM,3747
24
+ maite_datasets-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
25
+ maite_datasets-0.0.3.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
26
+ maite_datasets-0.0.3.dist-info/RECORD,,
@@ -1,23 +0,0 @@
1
- maite_datasets/__init__.py,sha256=K-0CHtknkjv1JHlW0grduC3dZiPzGKqPxfaeWo8ymTw,59
2
- maite_datasets/_base.py,sha256=WhuyFJrfMLPnU1Yc-WUUTVqXPtRs6rnmiwUy-9P01eM,8399
3
- maite_datasets/_fileio.py,sha256=Nuzl1j8sUDpQxlqnRyfbIGAx8UHCxJFOQMyKuA9WTqk,5824
4
- maite_datasets/_protocols.py,sha256=JqtnXeRWwepWBolDFosAXZmJEXIjo4wPA0UMnjqmdOY,5318
5
- maite_datasets/_types.py,sha256=iOhN4UVlH_nVoWBMJVCT7bLz_3LKd6W9vl_zur1z4Aw,1159
6
- maite_datasets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- maite_datasets/_mixin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
- maite_datasets/_mixin/_numpy.py,sha256=GEuRyeprH-STh-_zktAp0Tg6NNyMdh1ThyhjW558NOo,860
9
- maite_datasets/_mixin/_torch.py,sha256=pkN2vMNsDk_h5wnD5899zIHsPtEADbGfmRyI5CdGonI,827
10
- maite_datasets/image_classification/__init__.py,sha256=pcZojkdsiMoLgY4mKjoQY6WyEwiGYHxNrAGpnvn3zsY,308
11
- maite_datasets/image_classification/_cifar10.py,sha256=muy43KfqJS2M7sY4d20nrLmYdwXf8_nIeYBcvOYcfuk,8552
12
- maite_datasets/image_classification/_mnist.py,sha256=sUvJ2QuOGVd2OsGZTP5q-gYqVw5hEONBqOqH9V19oHk,8366
13
- maite_datasets/image_classification/_ships.py,sha256=kahX8T-P2Sd0ovXxcartFsUzfsohreEWA49qp18Xf44,5203
14
- maite_datasets/object_detection/__init__.py,sha256=NE8apy2C0kTg_Ng_M15U21ZW66WC_LWezmdG8vk2WHM,590
15
- maite_datasets/object_detection/_antiuav.py,sha256=SHE5FvUD8vguucZXjZTik02Zm6Xc79UlqFrRZc7EoLY,8479
16
- maite_datasets/object_detection/_milco.py,sha256=Pqicus9nDfA4qOTyYbI_Emo7YiT18bnQSMyU6QsX5Vk,8033
17
- maite_datasets/object_detection/_seadrone.py,sha256=aGqRyEsn6OCQyySZL3DtPi6TDPwG0svIGZmpXkzyCbc,271558
18
- maite_datasets/object_detection/_voc.py,sha256=73ZFQPKfKbU3yVYQkacJoLBOwVKe726nGJeeqFeWBbo,20037
19
- maite_datasets/object_detection/_voc_torch.py,sha256=bjeawnNit7Llcf_cZY_9lcJYoUoAU-Wen6MMT-7QX3k,2917
20
- maite_datasets-0.0.1.dist-info/METADATA,sha256=mkhBQv_bHXDYSZiNwNj-gDwqr876Iwd5UEt3LXc57LA,3747
21
- maite_datasets-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
22
- maite_datasets-0.0.1.dist-info/licenses/LICENSE,sha256=6h3J3R-ajGHh_isDSftzS5_jJjB9HH4TaI0vU-VscaY,1082
23
- maite_datasets-0.0.1.dist-info/RECORD,,