maite-datasets 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maite_datasets/__init__.py +1 -0
- maite_datasets/_base.py +254 -0
- maite_datasets/_fileio.py +174 -0
- maite_datasets/_mixin/__init__.py +0 -0
- maite_datasets/_mixin/_numpy.py +28 -0
- maite_datasets/_mixin/_torch.py +28 -0
- maite_datasets/_protocols.py +224 -0
- maite_datasets/_types.py +54 -0
- maite_datasets/image_classification/__init__.py +11 -0
- maite_datasets/image_classification/_cifar10.py +233 -0
- maite_datasets/image_classification/_mnist.py +215 -0
- maite_datasets/image_classification/_ships.py +150 -0
- maite_datasets/object_detection/__init__.py +20 -0
- maite_datasets/object_detection/_antiuav.py +200 -0
- maite_datasets/object_detection/_milco.py +207 -0
- maite_datasets/object_detection/_seadrone.py +551 -0
- maite_datasets/object_detection/_voc.py +510 -0
- maite_datasets/object_detection/_voc_torch.py +65 -0
- maite_datasets/py.typed +0 -0
- maite_datasets-0.0.1.dist-info/METADATA +91 -0
- maite_datasets-0.0.1.dist-info/RECORD +23 -0
- maite_datasets-0.0.1.dist-info/WHEEL +4 -0
- maite_datasets-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1 @@
|
|
1
|
+
"""Module for MAITE compliant Computer Vision datasets."""
|
maite_datasets/_base.py
ADDED
@@ -0,0 +1,254 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from abc import abstractmethod
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
from maite_datasets._fileio import _ensure_exists
|
12
|
+
from maite_datasets._protocols import Array, Transform
|
13
|
+
from maite_datasets._types import (
|
14
|
+
AnnotatedDataset,
|
15
|
+
DatasetMetadata,
|
16
|
+
DatumMetadata,
|
17
|
+
ImageClassificationDataset,
|
18
|
+
ObjectDetectionDataset,
|
19
|
+
ObjectDetectionTarget,
|
20
|
+
)
|
21
|
+
|
22
|
+
# Array type a dataset produces; the mixins in this package bind it to a NumPy
# ``NDArray`` or a ``torch.Tensor``.
_TArray = TypeVar("_TArray", bound=Array)
# Per-datum target type returned by ``__getitem__`` (the classification base uses
# the array type itself, the detection base uses ``ObjectDetectionTarget``).
_TTarget = TypeVar("_TTarget")
# Container of raw, not-yet-decoded targets produced by ``_load_data_inner``.
_TRawTarget = TypeVar(
    "_TRawTarget", Sequence[int], Sequence[str], Sequence[tuple[list[int], list[list[float]]]]
)
# A single raw annotation, i.e. one element of a ``_TRawTarget`` container.
_TAnnotation = TypeVar(
    "_TAnnotation",
    int,
    str,
    tuple[list[int], list[list[float]]],
)
|
31
|
+
|
32
|
+
|
33
|
+
def _to_datum_metadata(index: int, metadata: dict[str, Any]) -> DatumMetadata:
    """
    Build the DatumMetadata for a single datum.

    Parameters
    ----------
    index : int
        Position of the datum; used as its id when ``metadata`` has no ``"id"``.
    metadata : dict[str, Any]
        Per-datum fields. An ``"id"`` entry, if present, becomes the datum id;
        every other entry is passed through as a keyword to ``DatumMetadata``.

    Returns
    -------
    DatumMetadata

    Notes
    -----
    The caller's dictionary is not modified (the original implementation popped
    ``"id"`` from it in place).
    """
    fields = dict(metadata)  # shallow copy so the caller's dict keeps its "id"
    datum_id = fields.pop("id", index)
    return DatumMetadata(id=datum_id, **fields)
|
36
|
+
|
37
|
+
|
38
|
+
class DataLocation(NamedTuple):
    """
    Where and how to fetch one downloadable dataset resource.

    The fields are unpacked positionally (url, filename, md5, checksum) when the
    resource is handed to the download/validation helper.
    """

    url: str  # address the resource is downloaded from
    filename: str  # local file name once downloaded
    md5: bool  # True -> ``checksum`` is an MD5 digest, False -> SHA-256
    checksum: str  # expected hexadecimal digest of the downloaded file
|
43
|
+
|
44
|
+
|
45
|
+
class BaseDatasetMixin(Generic[_TArray]):
    """
    Interface for the array-backend helpers the dataset base classes rely on.

    Concrete backend mixins override these methods; the versions here are
    placeholders whose bodies do nothing and return None.
    """

    # Class index -> class name; expected to be defined by the concrete dataset.
    index2label: dict[int, str]

    def _as_array(self, raw: list[Any]) -> _TArray:
        """Convert a (possibly nested) list into the backend array type."""
        ...

    def _one_hot_encode(self, value: int | list[int]) -> _TArray:
        """One-hot encode a single class index or a list of class indices."""
        ...

    def _read_file(self, path: str) -> _TArray:
        """Load the image stored at ``path`` as a backend array."""
        ...
|
51
|
+
|
52
|
+
|
53
|
+
class BaseDataset(
    AnnotatedDataset[tuple[_TArray, _TTarget, DatumMetadata]],
    Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation],
):
    """
    Base class for internet downloaded datasets.

    Construction resolves the on-disk dataset folder and then loads file paths,
    raw targets and per-datum metadata through ``_load_data_inner``. If that
    raises ``FileNotFoundError``, the selected resource is fetched/extracted via
    ``_ensure_exists`` and loading is attempted once more.

    Parameters
    ----------
    root : str | Path
        Directory under which the dataset folder lives (see ``_get_dataset_dir``).
    image_set : {"train", "val", "test", "operational", "base"}, default "train"
        Split to load; also used for the dataset id and metadata.
    transforms : Transform | Sequence[Transform] | None, default None
        Transform or transforms applied, in order, to every image.
    download : bool, default False
        Allow downloading the resource if it is not found locally.
    verbose : bool, default False
        Print progress messages.
    """

    # Subclasses are expected to override the attributes below.
    # ``_resources`` holds DataLocation tuples whose fields are:
    #   url: str       - the URL to download from
    #   filename: str  - the name of the file once downloaded
    #   md5: bool      - True if the checksum is an MD5 digest (otherwise SHA-256)
    #   checksum: str  - the expected checksum of the downloaded file
    _resources: list[DataLocation]
    # Which entry of ``_resources`` this dataset uses.
    _resource_index: int = 0
    # Class index -> class name.
    index2label: dict[int, str]

    def __init__(
        self,
        root: str | Path,
        image_set: Literal["train", "val", "test", "operational", "base"] = "train",
        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
        download: bool = False,
        verbose: bool = False,
    ) -> None:
        self._root: Path = (
            Path(root).absolute() if not isinstance(root, Path) else root.absolute()
        )
        if transforms is None:
            transforms = []
        # Normalise a single transform into a one-element list.
        self.transforms: Sequence[Transform[_TArray]] = (
            [transforms] if not isinstance(transforms, Sequence) else transforms
        )
        self.image_set = image_set
        self._verbose = verbose

        # Internal Attributes
        self._download = download
        self._filepaths: list[str]
        self._targets: _TRawTarget
        self._datum_metadata: dict[str, list[Any]]
        self._resource: DataLocation = self._resources[self._resource_index]
        # Reverse lookup: class name -> class index.
        self._label2index = dict(zip(self.index2label.values(), self.index2label.keys()))

        self.metadata: DatasetMetadata = DatasetMetadata(
            id=self._unique_id(),
            index2label=self.index2label,
            split=self.image_set,
        )

        # Resolve the folder, then load (downloading first if necessary).
        self.path: Path = self._get_dataset_dir()
        filepaths, targets, datum_metadata = self._load_data()
        self._filepaths, self._targets, self._datum_metadata = (
            filepaths,
            targets,
            datum_metadata,
        )
        self.size: int = len(self._filepaths)

    def __str__(self) -> str:
        """Title, underline and one ``Key: value`` line per public attribute."""
        header = f"{self.__class__.__name__} Dataset"
        indent = "\n "
        lines = []
        for key, value in self.__dict__.items():
            if key.startswith("_"):
                continue
            lines.append(f"{key.capitalize()}: {value}")
        return header + "\n" + "-" * len(header) + indent + indent.join(lines)

    @property
    def label2index(self) -> dict[str, int]:
        """Mapping from class name to class index."""
        return self._label2index

    def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, DatumMetadata]]:
        """Yield every datum in index order via ``__getitem__``."""
        for position in range(len(self)):
            yield self[position]

    def _get_dataset_dir(self) -> Path:
        """
        Return the folder holding this dataset, creating it if missing.

        ``root`` is used directly when its stem already matches the class name
        (case-insensitive); otherwise a lowercase subfolder named after the
        class is used.
        """
        folder_name = self.__class__.__name__.lower()
        root_is_dataset_dir = self._root.stem.lower() == folder_name
        dataset_dir = self._root if root_is_dataset_dir else self._root / folder_name
        if not dataset_dir.exists():
            dataset_dir.mkdir(parents=True, exist_ok=True)
        return dataset_dir

    def _unique_id(self) -> str:
        """Identifier of the form ``<ClassName>_<image_set>``."""
        return f"{self.__class__.__name__}_{self.image_set}"

    def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
        """
        Function to determine if data can be accessed or if it needs to be downloaded and/or extracted.

        Tries ``_load_data_inner`` first; on ``FileNotFoundError`` it ensures the
        resource is present (downloading/extracting as permitted) and retries.
        """
        if self._verbose:
            print(f"Determining if {self._resource.filename} needs to be downloaded.")

        try:
            loaded = self._load_data_inner()
        except FileNotFoundError:
            _ensure_exists(
                *self._resource, self.path, self._root, self._download, self._verbose
            )
            return self._load_data_inner()
        if self._verbose:
            print("No download needed, loaded data successfully.")
        return loaded

    @abstractmethod
    def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
        """Return ``(filepaths, raw_targets, datum_metadata)`` from local files."""
        ...

    def _transform(self, image: _TArray) -> _TArray:
        """Function to transform the image prior to returning based on parameters passed in."""
        for step in self.transforms:
            image = step(image)
        return image

    def __len__(self) -> int:
        """Number of data points (fixed at construction)."""
        return self.size
|
169
|
+
|
170
|
+
|
171
|
+
class BaseICDataset(
    BaseDataset[_TArray, _TArray, list[int], int],
    BaseDatasetMixin[_TArray],
    ImageClassificationDataset[_TArray],
):
    """
    Base class for image classification datasets.

    Raw targets are integer class indices; each is returned as a one-hot array.
    """

    def __getitem__(self, index: int) -> tuple[_TArray, _TArray, DatumMetadata]:
        """
        Args
        ----
        index : int
            Position of the desired data point.

        Returns
        -------
        tuple[TArray, TArray, DatumMetadata]
            Image, target, datum_metadata - where target is the one-hot encoding of the class.
        """
        # One-hot encode the stored class index.
        target = self._one_hot_encode(self._targets[index])
        # Load the image and run it through the configured transforms.
        raw_image = self._read_file(self._filepaths[index])
        image = self._transform(raw_image)
        # Take this datum's entry from every metadata column.
        datum_fields = {}
        for field, column in self._datum_metadata.items():
            datum_fields[field] = column[index]
        return image, target, _to_datum_metadata(index, datum_fields)
|
202
|
+
|
203
|
+
|
204
|
+
class BaseODDataset(
    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
    BaseDatasetMixin[_TArray],
    ObjectDetectionDataset[_TArray],
):
    """
    Base class for object detection datasets.

    Subclasses implement ``_read_annotations`` to convert one raw annotation into
    ``(boxes, labels, extra_metadata)``.
    """

    # True when ``_read_annotations`` returns box coordinates as fractions of the
    # image width/height; ``__getitem__`` then rescales them to pixels.
    _bboxes_per_size: bool = False

    def __getitem__(
        self, index: int
    ) -> tuple[_TArray, ObjectDetectionTarget[_TArray], DatumMetadata]:
        """
        Args
        ----
        index : int
            Value of the desired data point

        Returns
        -------
        tuple[TArray, ObjectDetectionTarget[TArray], DatumMetadata]
            Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
        """
        # Grab the bounding boxes and labels from the annotations
        annotation = cast(_TAnnotation, self._targets[index])
        boxes, labels, additional_metadata = self._read_annotations(annotation)
        # Get the image. Its shape is captured before transforms run, because the
        # annotations describe the image as stored on disk.
        img = self._read_file(self._filepaths[index])
        img_size = img.shape
        img = self._transform(img)
        # Rescale fractional boxes to pixels. The NumPy/torch mixins' _read_file
        # returns channel-first (C, H, W) images, so x0/x1 scale with the width
        # (img_size[2]) and y0/y1 with the height (img_size[1]).
        if self._bboxes_per_size and boxes:
            height, width = img_size[1], img_size[2]
            boxes = boxes * np.array([[width, height, width, height]])
        # Create the Object Detection Target
        target = ObjectDetectionTarget(
            self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels)
        )

        img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
        img_metadata = img_metadata | additional_metadata

        return img, target, _to_datum_metadata(index, img_metadata)

    @abstractmethod
    def _read_annotations(
        self, annotation: _TAnnotation
    ) -> tuple[list[list[float]], list[int], dict[str, Any]]:
        """Return ``(boxes, labels, extra_metadata)`` for one raw annotation.

        Boxes are ``[x0, y0, x1, y1]`` lists; when ``_bboxes_per_size`` is True
        they are expected as fractions of the image width/height.
        """
        ...
|
@@ -0,0 +1,174 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import hashlib
|
6
|
+
import tarfile
|
7
|
+
import zipfile
|
8
|
+
from pathlib import Path
|
9
|
+
|
10
|
+
import requests
|
11
|
+
|
12
|
+
# tqdm is optional: when it is not installed, downloads run without a progress bar.
try:
    from tqdm.auto import tqdm
except ImportError:
    tqdm = None

# File extensions that are extracted after being located/downloaded.
ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
# Compression suffixes that may wrap an archive (e.g. "name.tar.gz").
COMPRESS_ENDINGS = [".gz", ".bz2"]
|
19
|
+
|
20
|
+
|
21
|
+
def _print(text: str, verbose: bool) -> None:
|
22
|
+
if verbose:
|
23
|
+
print(text)
|
24
|
+
|
25
|
+
|
26
|
+
def _validate_file(
|
27
|
+
fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535
|
28
|
+
) -> bool:
|
29
|
+
hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
|
30
|
+
with open(fpath, "rb") as fpath_file:
|
31
|
+
while chunk := fpath_file.read(chunk_size):
|
32
|
+
hasher.update(chunk)
|
33
|
+
return hasher.hexdigest() == file_md5
|
34
|
+
|
35
|
+
|
36
|
+
def _download_dataset(
    url: str, file_path: Path, timeout: int = 60, verbose: bool = False
) -> None:
    """
    Download a single resource from ``url`` and write it to ``file_path``.

    Parameters
    ----------
    url : str
        Address of the resource.
    file_path : Path
        Destination file; overwritten if it already exists.
    timeout : int, default 60
        Timeout in seconds passed to ``requests.get``.
    verbose : bool, default False
        Show a progress bar (only if tqdm is installed).

    Raises
    ------
    RuntimeError
        If the server answers with an HTTP error status.
    ValueError
        For any other ``requests`` failure while connecting.

    Notes
    -----
    If the transfer fails part-way (including KeyboardInterrupt), the partially
    written file is removed. Otherwise a later run would find it and fail
    checksum validation instead of downloading it again.
    """
    error_msg = "URL fetch failure on {}: {} -- {}"
    try:
        response = requests.get(url, stream=True, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        message = error_msg.format(url, e.response.status_code, e.response.reason)
        # stream=True keeps the connection open until the response is closed.
        e.response.close()
        raise RuntimeError(message) from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e

    # Closing the response releases the pooled connection even on error.
    with response:
        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192  # 8 KB
        progress_bar = (
            None
            if tqdm is None
            else tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
        )
        try:
            with open(file_path, "wb") as f:
                try:
                    for chunk in response.iter_content(block_size):
                        f.write(chunk)
                        if progress_bar is not None:
                            progress_bar.update(len(chunk))
                except BaseException:
                    # Close before unlinking (required on Windows), then drop the
                    # truncated file and re-raise the original error.
                    f.close()
                    file_path.unlink(missing_ok=True)
                    raise
        finally:
            if progress_bar is not None:
                progress_bar.close()
|
66
|
+
|
67
|
+
|
68
|
+
def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
|
69
|
+
"""Extracts the zip file to the given directory."""
|
70
|
+
try:
|
71
|
+
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
72
|
+
zip_ref.extractall(extract_to) # noqa: S202
|
73
|
+
file_path.unlink()
|
74
|
+
except zipfile.BadZipFile:
|
75
|
+
raise FileNotFoundError(
|
76
|
+
f"{file_path.name} is not a valid zip file, skipping extraction."
|
77
|
+
)
|
78
|
+
|
79
|
+
|
80
|
+
def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
|
81
|
+
"""Extracts a tar file (or compressed tar) to the specified directory."""
|
82
|
+
try:
|
83
|
+
with tarfile.open(file_path, "r:*") as tar_ref:
|
84
|
+
tar_ref.extractall(extract_to) # noqa: S202
|
85
|
+
file_path.unlink()
|
86
|
+
except tarfile.TarError:
|
87
|
+
raise FileNotFoundError(
|
88
|
+
f"{file_path.name} is not a valid tar file, skipping extraction."
|
89
|
+
)
|
90
|
+
|
91
|
+
|
92
|
+
def _extract_archive(
    file_ext: str,
    file_path: Path,
    directory: Path,
    compression: bool = False,
    verbose: bool = False,
) -> None:
    """
    Extract an archive into ``directory``, then unpack any zip files found there.

    ``.zip`` files without an extra compression suffix are opened as zip
    archives. Everything else (``.tar``, ``.tgz``, compressed tars) is opened as
    a tar. Afterwards, every ``*.zip`` entry sitting directly in ``directory`` is
    extracted into ``directory`` as well. This is one level only and not in
    place; nested archives inside those zips are not processed.

    Parameters
    ----------
    file_ext : str
        Archive extension, e.g. ``".zip"`` or ``".tar"``.
    file_path : Path
        Archive to extract (deleted after successful extraction).
    directory : Path
        Destination directory.
    compression : bool, default False
        True when the archive carries a compression suffix such as ``.gz``.
    verbose : bool, default False
        Print a message for each nested zip that is extracted.

    Raises
    ------
    FileNotFoundError
        If an archive is not a valid zip/tar file.
    """
    if file_ext != ".zip" or compression:
        _extract_tar_archive(file_path, directory)
    else:
        _extract_zip_archive(file_path, directory)
    # Take a snapshot of the listing before extracting. Each extraction adds
    # files and deletes the zip, and mutating a directory while iterating it
    # gives unspecified results. Sorting makes the order deterministic.
    nested_zips = sorted(child for child in directory.iterdir() if child.suffix == ".zip")
    for child in nested_zips:
        _print(f"Extracting nested zip: {child} to {directory}", verbose)
        _extract_zip_archive(child, directory)
|
114
|
+
|
115
|
+
|
116
|
+
def _ensure_exists(
    url: str,
    filename: str,
    md5: bool,
    checksum: str,
    directory: Path,
    root: Path,
    download: bool = True,
    verbose: bool = False,
) -> None:
    """
    Make sure a resource is present locally, downloading it if allowed.

    The file is looked for at ``directory / filename``. If that is missing but
    ``root / filename`` exists, the latter is used. A missing file is downloaded
    to ``directory / filename`` when ``download`` is True. The file's checksum is
    then verified. ``.zip``/``.tar``/``.tgz`` archives, optionally with a ``.gz``
    or ``.bz2`` compression suffix, are extracted into ``directory``, including
    zips they contain.

    Raises
    ------
    FileNotFoundError
        If the file is missing and ``download`` is False.
    ValueError
        If the file's checksum does not match ``checksum``.
    """
    file_path = directory / str(filename)
    alternate_path = root / str(filename)

    # For compressed archives such as "name.tar.gz", the archive type is the
    # suffix just before the compression suffix. Use [-2] rather than [0] so
    # that dotted names like "data.v1.tar.gz" still resolve to ".tar".
    file_ext = file_path.suffix
    compression = file_ext in COMPRESS_ENDINGS
    if compression and len(file_path.suffixes) > 1:
        file_ext = file_path.suffixes[-2]

    check_path = (
        alternate_path
        if alternate_path.exists() and not file_path.exists()
        else file_path
    )

    if check_path.exists():
        already_present = True
    elif download:
        _print(f"Downloading {filename} from {url}", verbose)
        _download_dataset(url, check_path, verbose=verbose)
        already_present = False
    else:
        raise FileNotFoundError(
            "Data could not be loaded with the provided root directory, "
            f"the file path to the file {filename} does not exist, "
            "and the download parameter is set to False."
        )

    if not _validate_file(check_path, checksum, md5):
        raise ValueError(
            "File checksum mismatch. Remove current file and retry download."
        )
    if already_present:
        _print(f"{filename} already exists, skipping download.", verbose)

    # If the file is a zip, tar or tgz extract it into the designated folder.
    if file_ext in ARCHIVE_ENDINGS:
        _print(f"Extracting {filename}...", verbose)
        _extract_archive(file_ext, check_path, directory, compression, verbose)
|
File without changes
|
@@ -0,0 +1,28 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from numpy.typing import NDArray
|
9
|
+
from PIL import Image
|
10
|
+
|
11
|
+
from maite_datasets._base import BaseDatasetMixin
|
12
|
+
|
13
|
+
|
14
|
+
class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[np.number[Any]]]):
    """NumPy implementation of the array helpers used by the dataset base classes."""

    def _as_array(self, raw: list[Any]) -> NDArray[np.number[Any]]:
        """Convert ``raw`` into a NumPy array via ``np.asarray``."""
        return np.asarray(raw)

    def _one_hot_encode(self, value: int | list[int]) -> NDArray[np.number[Any]]:
        """
        One-hot encode class indices into a float array.

        A scalar index (Python ``int`` or NumPy integer) gives shape
        ``(num_classes,)``. A list of ``n`` indices gives shape
        ``(n, num_classes)``. Here ``num_classes == len(self.index2label)``.
        """
        # np.integer is accepted too, so labels loaded from NumPy arrays don't
        # fall through to the list branch and fail on len().
        if isinstance(value, (int, np.integer)):
            encoded = np.zeros(len(self.index2label))
            encoded[value] = 1
        else:
            encoded = np.zeros((len(value), len(self.index2label)))
            encoded[np.arange(len(value)), value] = 1
        return encoded

    def _read_file(self, path: str) -> NDArray[np.number[Any]]:
        """
        Load an image file as a channel-first ``(C, H, W)`` array.

        Single-channel images, which PIL decodes to 2-D arrays, get a leading
        channel axis of size 1 instead of failing in ``transpose``.
        """
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as img:
            arr = np.array(img)
        if arr.ndim == 2:
            return arr[np.newaxis, ...]
        return arr.transpose(2, 0, 1)
|
@@ -0,0 +1,28 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
import torch
|
9
|
+
from PIL import Image
|
10
|
+
|
11
|
+
from maite_datasets._base import BaseDatasetMixin
|
12
|
+
|
13
|
+
|
14
|
+
class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
    """PyTorch implementation of the array helpers used by the dataset base classes."""

    def _as_array(self, raw: list[Any]) -> torch.Tensor:
        """Convert ``raw`` into a tensor via ``torch.as_tensor``."""
        return torch.as_tensor(raw)

    def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
        """
        One-hot encode class indices into a float tensor.

        A scalar index (Python ``int`` or NumPy integer) gives shape
        ``(num_classes,)``. A list of ``n`` indices gives shape
        ``(n, num_classes)``. Here ``num_classes == len(self.index2label)``.
        """
        # np.integer is accepted too, so labels loaded from NumPy arrays don't
        # fall through to the list branch and fail on len().
        if isinstance(value, (int, np.integer)):
            encoded = torch.zeros(len(self.index2label))
            encoded[value] = 1
        else:
            encoded = torch.zeros((len(value), len(self.index2label)))
            encoded[torch.arange(len(value)), value] = 1
        return encoded

    def _read_file(self, path: str) -> torch.Tensor:
        """
        Load an image file as a channel-first ``(C, H, W)`` tensor.

        Single-channel images, which PIL decodes to 2-D arrays, get a leading
        channel axis of size 1 instead of failing in ``transpose``.
        """
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as img:
            arr = np.array(img)
        if arr.ndim == 2:
            arr = arr[np.newaxis, ...]
        else:
            arr = arr.transpose(2, 0, 1)
        return torch.as_tensor(arr)
|