maite-datasets 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
+ """Module for MAITE-compliant computer vision datasets."""
@@ -0,0 +1,254 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from abc import abstractmethod
+ from pathlib import Path
+ from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
+
+ import numpy as np
+
+ from maite_datasets._fileio import _ensure_exists
+ from maite_datasets._protocols import Array, Transform
+ from maite_datasets._types import (
+     AnnotatedDataset,
+     DatasetMetadata,
+     DatumMetadata,
+     ImageClassificationDataset,
+     ObjectDetectionDataset,
+     ObjectDetectionTarget,
+ )
+
+ _TArray = TypeVar("_TArray", bound=Array)
+ _TTarget = TypeVar("_TTarget")
+ _TRawTarget = TypeVar(
+     "_TRawTarget",
+     Sequence[int],
+     Sequence[str],
+     Sequence[tuple[list[int], list[list[float]]]],
+ )
+ _TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
+
+
+ def _to_datum_metadata(index: int, metadata: dict[str, Any]) -> DatumMetadata:
+     _id = metadata.pop("id", index)
+     return DatumMetadata(id=_id, **metadata)
+
+
+ class DataLocation(NamedTuple):
+     url: str
+     filename: str
+     md5: bool
+     checksum: str
+
+
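Concrete datasets declare their downloadable resources as `DataLocation` entries. A minimal sketch of such a declaration — the URL and checksum below are placeholders, not a real resource:

```python
# Hypothetical resource entry; url and checksum are placeholders.
resource = DataLocation(
    url="https://example.com/datasets/tinyshapes.zip",
    filename="tinyshapes.zip",
    md5=False,  # False means the checksum is validated as SHA-256, not MD5
    checksum="<sha256 hex digest of tinyshapes.zip>",
)
```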
+ class BaseDatasetMixin(Generic[_TArray]):
+     index2label: dict[int, str]
+
+     # Hooks implemented by the backend-specific (NumPy / torch) mixins.
+     def _as_array(self, raw: list[Any]) -> _TArray: ...
+     def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
+     def _read_file(self, path: str) -> _TArray: ...
+
+
+ class BaseDataset(
+     AnnotatedDataset[tuple[_TArray, _TTarget, DatumMetadata]],
+     Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
+ ):
+     """
+     Base class for datasets downloaded from the internet.
+     """
+
+     # Each subclass should override the attributes below.
+     # Each resource tuple must contain:
+     #   'url': str, the URL to download from
+     #   'filename': str, the name of the file once downloaded
+     #   'md5': bool, True if the checksum value is an MD5 digest (otherwise SHA-256)
+     #   'checksum': str, the associated checksum for the downloaded file
+     _resources: list[DataLocation]
+     _resource_index: int = 0
+     index2label: dict[int, str]
+
+     def __init__(
+         self,
+         root: str | Path,
+         image_set: Literal["train", "val", "test", "operational", "base"] = "train",
+         transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+         download: bool = False,
+         verbose: bool = False,
+     ) -> None:
+         self._root: Path = Path(root).absolute()
+         transforms = transforms if transforms is not None else []
+         self.transforms: Sequence[Transform[_TArray]] = (
+             transforms if isinstance(transforms, Sequence) else [transforms]
+         )
+         self.image_set = image_set
+         self._verbose = verbose
+
+         # Internal attributes
+         self._download = download
+         self._filepaths: list[str]
+         self._targets: _TRawTarget
+         self._datum_metadata: dict[str, list[Any]]
+         self._resource: DataLocation = self._resources[self._resource_index]
+         self._label2index = {v: k for k, v in self.index2label.items()}
+
+         self.metadata: DatasetMetadata = DatasetMetadata(
+             id=self._unique_id(),
+             index2label=self.index2label,
+             split=self.image_set,
+         )
+
+         # Load the data
+         self.path: Path = self._get_dataset_dir()
+         self._filepaths, self._targets, self._datum_metadata = self._load_data()
+         self.size: int = len(self._filepaths)
+
+     def __str__(self) -> str:
+         nt = "\n    "
+         title = f"{self.__class__.__name__} Dataset"
+         sep = "-" * len(title)
+         attrs = [
+             f"{k.capitalize()}: {v}"
+             for k, v in self.__dict__.items()
+             if not k.startswith("_")
+         ]
+         return f"{title}\n{sep}{nt}{nt.join(attrs)}"
+
+     @property
+     def label2index(self) -> dict[str, int]:
+         return self._label2index
+
+     def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, DatumMetadata]]:
+         for i in range(len(self)):
+             yield self[i]
+
+     def _get_dataset_dir(self) -> Path:
+         # Use the root directly if it is already named after this dataset class;
+         # otherwise create a designated subfolder named after the class.
+         if self._root.stem.lower() == self.__class__.__name__.lower():
+             dataset_dir = self._root
+         else:
+             dataset_dir = self._root / self.__class__.__name__.lower()
+         dataset_dir.mkdir(parents=True, exist_ok=True)
+         return dataset_dir
+
+     def _unique_id(self) -> str:
+         return f"{self.__class__.__name__}_{self.image_set}"
+
+     def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
+         """
+         Determine whether the data is already available locally, downloading
+         and/or extracting it first if necessary.
+         """
+         if self._verbose:
+             print(f"Determining if {self._resource.filename} needs to be downloaded.")
+
+         try:
+             result = self._load_data_inner()
+             if self._verbose:
+                 print("No download needed, loaded data successfully.")
+         except FileNotFoundError:
+             _ensure_exists(
+                 *self._resource, self.path, self._root, self._download, self._verbose
+             )
+             result = self._load_data_inner()
+         return result
+
+     @abstractmethod
+     def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
+
+     def _transform(self, image: _TArray) -> _TArray:
+         """Apply the configured transforms to the image before returning it."""
+         for transform in self.transforms:
+             image = transform(image)
+         return image
+
+     def __len__(self) -> int:
+         return self.size
+
+
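Transforms are plain callables applied in order by `_transform` each time a datum is fetched. A hedged usage sketch, with `MyDataset` standing in for any concrete subclass (the name is hypothetical):

```python
import numpy as np
from numpy.typing import NDArray

def scale_to_unit(image: NDArray[np.number]) -> NDArray[np.number]:
    # Example transform: rescale pixel values from [0, 255] to [0, 1].
    return image / 255.0

# A single callable or a sequence of callables is accepted;
# each is applied in order inside _transform.
dataset = MyDataset(root="./data", transforms=scale_to_unit, download=True)
```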
+ class BaseICDataset(
+     BaseDataset[_TArray, _TArray, list[int], int],
+     BaseDatasetMixin[_TArray],
+     ImageClassificationDataset[_TArray],
+ ):
+     """
+     Base class for image classification datasets.
+     """
+
+     def __getitem__(self, index: int) -> tuple[_TArray, _TArray, DatumMetadata]:
+         """
+         Args
+         ----
+         index : int
+             Index of the desired data point
+
+         Returns
+         -------
+         tuple[TArray, TArray, DatumMetadata]
+             Image, target, datum_metadata - where target is the one-hot encoding of the class.
+         """
+         # Get the associated label and one-hot encode it as the score
+         label = self._targets[index]
+         score = self._one_hot_encode(label)
+         # Read the image and apply any transforms
+         img = self._read_file(self._filepaths[index])
+         img = self._transform(img)
+
+         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+
+         return img, score, _to_datum_metadata(index, img_metadata)
+
+
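Each datum comes back as an `(image, target, metadata)` triple, where the target is a one-hot vector. An illustrative access pattern, again assuming a hypothetical concrete subclass:

```python
dataset = MyDataset(root="./data", image_set="train", download=True)

image, target, metadata = dataset[0]
# target is one-hot; argmax recovers the class index for the label map
print(dataset.index2label[int(target.argmax())], metadata["id"])

# __iter__ yields every datum in index order
for image, target, metadata in dataset:
    ...
```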
+ class BaseODDataset(
+     BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
+     BaseDatasetMixin[_TArray],
+     ObjectDetectionDataset[_TArray],
+ ):
+     """
+     Base class for object detection datasets.
+     """
+
+     _bboxes_per_size: bool = False
+
+     def __getitem__(
+         self, index: int
+     ) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
+         """
+         Args
+         ----
+         index : int
+             Index of the desired data point
+
+         Returns
+         -------
+         tuple[TArray, ObjectDetectionTarget[TArray], DatumMetadata]
+             Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
+         """
+         # Grab the bounding boxes and labels from the annotations
+         annotation = cast(_TAnnotation, self._targets[index])
+         boxes, labels, additional_metadata = self._read_annotations(annotation)
+         # Read the image and apply any transforms
+         img = self._read_file(self._filepaths[index])
+         img_size = img.shape
+         img = self._transform(img)
+         # Scale normalized bounding boxes up to the image size if necessary
+         if self._bboxes_per_size and boxes:
+             boxes = boxes * np.array(
+                 [[img_size[1], img_size[2], img_size[1], img_size[2]]]
+             )
+         # Create the ObjectDetectionTarget
+         target = ObjectDetectionTarget(
+             self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels)
+         )
+
+         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+         img_metadata = img_metadata | additional_metadata
+
+         return img, target, _to_datum_metadata(index, img_metadata)
+
+     @abstractmethod
+     def _read_annotations(
+         self, annotation: _TAnnotation
+     ) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
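A concrete dataset only has to supply `index2label`, `_resources`, and the loading hook(s); the base class handles download, extraction, transforms, and metadata. A minimal sketch of an image-classification subclass — the class name, on-disk layout, and mixin import path here are illustrative assumptions, not part of the package:

```python
from pathlib import Path
from typing import Any

from maite_datasets._base import BaseICDataset, DataLocation
from maite_datasets._mixin._numpy import BaseDatasetNumpyMixin  # assumed import path


class TinyShapes(BaseICDataset, BaseDatasetNumpyMixin):
    """Hypothetical two-class dataset, shown only to illustrate the hooks."""

    index2label = {0: "circle", 1: "square"}
    _resources = [
        DataLocation(
            url="https://example.com/tinyshapes.zip",  # placeholder
            filename="tinyshapes.zip",
            md5=False,
            checksum="<sha256>",  # placeholder
        ),
    ]

    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        # Expect one folder per class under self.path; raising FileNotFoundError
        # is the signal for the base class to download and extract the resource.
        filepaths: list[str] = []
        targets: list[int] = []
        for class_index, label in self.index2label.items():
            files = sorted((self.path / label).glob("*.png"))
            filepaths.extend(str(f) for f in files)
            targets.extend(class_index for _ in files)
        if not filepaths:
            raise FileNotFoundError(f"No images found under {self.path}")
        return filepaths, targets, {}
```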
@@ -0,0 +1,174 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ import hashlib
+ import tarfile
+ import zipfile
+ from pathlib import Path
+
+ import requests
+
+ try:
+     from tqdm.auto import tqdm
+ except ImportError:
+     tqdm = None
+
+ ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
+ COMPRESS_ENDINGS = [".gz", ".bz2"]
+
+
+ def _print(text: str, verbose: bool) -> None:
+     if verbose:
+         print(text)
+
+
+ def _validate_file(
+     fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535
+ ) -> bool:
+     hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
+     with open(fpath, "rb") as fpath_file:
+         while chunk := fpath_file.read(chunk_size):
+             hasher.update(chunk)
+     return hasher.hexdigest() == file_md5
+
+
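`_validate_file` streams the file through `hashlib` in fixed-size chunks, so large archives hash without being read fully into memory. The same pattern can produce the `checksum` value when registering a new resource — a small helper sketch:

```python
import hashlib
from pathlib import Path

def sha256_of(path: Path, chunk_size: int = 65535) -> str:
    # Compute the SHA-256 hex digest of a file, chunk by chunk.
    hasher = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            hasher.update(chunk)
    return hasher.hexdigest()

# Paste the digest into DataLocation.checksum (with md5=False).
print(sha256_of(Path("tinyshapes.zip")))
```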
+ def _download_dataset(
+     url: str, file_path: Path, timeout: int = 60, verbose: bool = False
+ ) -> None:
+     """Download a single resource from its URL to `file_path`."""
+     error_msg = "URL fetch failure on {}: {} -- {}"
+     try:
+         response = requests.get(url, stream=True, timeout=timeout)
+         response.raise_for_status()
+     except requests.exceptions.HTTPError as e:
+         raise RuntimeError(
+             f"{error_msg.format(url, e.response.status_code, e.response.reason)}"
+         ) from e
+     except requests.exceptions.RequestException as e:
+         raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
+
+     total_size = int(response.headers.get("content-length", 0))
+     block_size = 8192  # 8 KB
+     progress_bar = (
+         None
+         if tqdm is None
+         else tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
+     )
+
+     with open(file_path, "wb") as f:
+         for chunk in response.iter_content(block_size):
+             f.write(chunk)
+             if progress_bar is not None:
+                 progress_bar.update(len(chunk))
+     if progress_bar is not None:
+         progress_bar.close()
+
+
+ def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
+     """Extracts the zip file to the given directory, removing the archive afterwards."""
+     try:
+         with zipfile.ZipFile(file_path, "r") as zip_ref:
+             zip_ref.extractall(extract_to)  # noqa: S202
+         file_path.unlink()
+     except zipfile.BadZipFile as e:
+         raise FileNotFoundError(
+             f"{file_path.name} is not a valid zip file."
+         ) from e
+
+
+ def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
+     """Extracts a tar file (or compressed tar) to the specified directory, removing the archive afterwards."""
+     try:
+         with tarfile.open(file_path, "r:*") as tar_ref:
+             tar_ref.extractall(extract_to)  # noqa: S202
+         file_path.unlink()
+     except tarfile.TarError as e:
+         raise FileNotFoundError(
+             f"{file_path.name} is not a valid tar file."
+         ) from e
+
+
+ def _extract_archive(
+     file_ext: str,
+     file_path: Path,
+     directory: Path,
+     compression: bool = False,
+     verbose: bool = False,
+ ) -> None:
+     """
+     Extract an archive and flatten it if necessary.
+     Recursively extracts nested zip files as well,
+     extracting and flattening all folders to the base directory.
+     """
+     if file_ext != ".zip" or compression:
+         _extract_tar_archive(file_path, directory)
+     else:
+         _extract_zip_archive(file_path, directory)
+         # Look for nested zip files in the extraction directory and extract them recursively.
+         # Does NOT extract in place - extracts everything to directory
+         for child in directory.iterdir():
+             if child.suffix == ".zip":
+                 _print(f"Extracting nested zip: {child} to {directory}", verbose)
+                 _extract_zip_archive(child, directory)
+
+
+ def _ensure_exists(
+     url: str,
+     filename: str,
+     md5: bool,
+     checksum: str,
+     directory: Path,
+     root: Path,
+     download: bool = True,
+     verbose: bool = False,
+ ) -> None:
+     """
+     For each resource, download it if it doesn't exist in the dataset_dir.
+     If the resource is a zip file, extract it (including recursively extracting nested zips).
+     """
+     file_path = directory / str(filename)
+     alternate_path = root / str(filename)
+     _, file_ext = file_path.stem, file_path.suffix
+     compression = False
+     if file_ext in COMPRESS_ENDINGS:
+         file_ext = file_path.suffixes[0]
+         compression = True
+
+     check_path = (
+         alternate_path
+         if alternate_path.exists() and not file_path.exists()
+         else file_path
+     )
+
+     # Download the file if it doesn't exist.
+     if not check_path.exists() and download:
+         _print(f"Downloading {filename} from {url}", verbose)
+         _download_dataset(url, check_path, verbose=verbose)
+
+         if not _validate_file(check_path, checksum, md5):
+             raise Exception(
+                 "File checksum mismatch. Remove current file and retry download."
+             )
+
+         # If the file is a zip, tar or tgz extract it into the designated folder.
+         if file_ext in ARCHIVE_ENDINGS:
+             _print(f"Extracting {filename}...", verbose)
+             _extract_archive(file_ext, check_path, directory, compression, verbose)
+
+     elif not check_path.exists() and not download:
+         raise FileNotFoundError(
+             "Data could not be loaded with the provided root directory, "
+             f"the file path to the file {filename} does not exist, "
+             "and the download parameter is set to False."
+         )
+     else:
+         if not _validate_file(check_path, checksum, md5):
+             raise Exception(
+                 "File checksum mismatch. Remove current file and retry download."
+             )
+         _print(f"{filename} already exists, skipping download.", verbose)
+
+         if file_ext in ARCHIVE_ENDINGS:
+             _print(f"Extracting {filename}...", verbose)
+             _extract_archive(file_ext, check_path, directory, compression, verbose)
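`_ensure_exists` is the fallback that `BaseDataset._load_data` invokes when `_load_data_inner` raises `FileNotFoundError`: the `DataLocation` tuple unpacks into the first four arguments, followed by the dataset directory, root, and flags. A hedged sketch of the equivalent direct call, with placeholder paths and checksum:

```python
from pathlib import Path

from maite_datasets._base import DataLocation
from maite_datasets._fileio import _ensure_exists

resource = DataLocation(
    url="https://example.com/tinyshapes.zip",  # placeholder
    filename="tinyshapes.zip",
    md5=False,
    checksum="<sha256>",  # placeholder
)
# *resource expands to (url, filename, md5, checksum), mirroring the call
# made in BaseDataset._load_data.
_ensure_exists(*resource, Path("./data/tinyshapes"), Path("./data"), download=True, verbose=True)
```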
@@ -0,0 +1,28 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from PIL import Image
+
+ from maite_datasets._base import BaseDatasetMixin
+
+
+ class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[np.number[Any]]]):
+     def _as_array(self, raw: list[Any]) -> NDArray[np.number[Any]]:
+         return np.asarray(raw)
+
+     def _one_hot_encode(self, value: int | list[int]) -> NDArray[np.number[Any]]:
+         if isinstance(value, int):
+             encoded = np.zeros(len(self.index2label))
+             encoded[value] = 1
+         else:
+             encoded = np.zeros((len(value), len(self.index2label)))
+             encoded[np.arange(len(value)), value] = 1
+         return encoded
+
+     def _read_file(self, path: str) -> NDArray[np.number[Any]]:
+         # Read HWC image data and transpose to CHW (assumes a 3-channel image)
+         return np.array(Image.open(path)).transpose(2, 0, 1)
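The two `_one_hot_encode` branches produce different shapes: a single int yields a `(num_classes,)` vector, a list of ints a `(num_labels, num_classes)` matrix. Instantiating the mixin directly just to show the shapes (not something the package itself does):

```python
mixin = BaseDatasetNumpyMixin()
mixin.index2label = {0: "circle", 1: "square", 2: "triangle"}

mixin._one_hot_encode(1)       # array([0., 1., 0.])
mixin._one_hot_encode([2, 0])  # array([[0., 0., 1.],
                               #        [1., 0., 0.]])
```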
@@ -0,0 +1,28 @@
+ from __future__ import annotations
+
+ __all__ = []
+
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from PIL import Image
+
+ from maite_datasets._base import BaseDatasetMixin
+
+
+ class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
+     def _as_array(self, raw: list[Any]) -> torch.Tensor:
+         return torch.as_tensor(raw)
+
+     def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
+         if isinstance(value, int):
+             encoded = torch.zeros(len(self.index2label))
+             encoded[value] = 1
+         else:
+             encoded = torch.zeros((len(value), len(self.index2label)))
+             encoded[torch.arange(len(value)), value] = 1
+         return encoded
+
+     def _read_file(self, path: str) -> torch.Tensor:
+         # Read HWC image data and transpose to CHW (assumes a 3-channel image)
+         return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
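The NumPy and torch mixins are interchangeable backends: a dataset picks its array type by which mixin it inherits alongside `BaseICDataset` or `BaseODDataset`, while the loading hooks stay identical. A sketch of a torch-backed counterpart to the earlier hypothetical `TinyShapes`:

```python
import torch

from maite_datasets._base import BaseICDataset

class TinyShapesTorch(BaseICDataset[torch.Tensor], BaseDatasetTorchMixin):
    # Same label map and resource entries as the NumPy sketch above;
    # only the mixin (and therefore the array backend) changes.
    index2label = {0: "circle", 1: "square"}
    _resources = [...]               # same DataLocation entries as before
    def _load_data_inner(self): ...  # same hook as the NumPy sketch
```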