dataeval 0.86.9__py3-none-any.whl → 0.87.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_version.py +2 -2
- dataeval/config.py +4 -19
- dataeval/data/_metadata.py +56 -27
- dataeval/data/_split.py +1 -1
- dataeval/data/selections/_classbalance.py +4 -3
- dataeval/data/selections/_classfilter.py +5 -5
- dataeval/data/selections/_indices.py +2 -2
- dataeval/data/selections/_prioritize.py +249 -29
- dataeval/data/selections/_reverse.py +1 -1
- dataeval/data/selections/_shuffle.py +2 -2
- dataeval/detectors/ood/__init__.py +2 -1
- dataeval/detectors/ood/base.py +38 -1
- dataeval/detectors/ood/knn.py +95 -0
- dataeval/metrics/bias/_balance.py +28 -21
- dataeval/metrics/bias/_diversity.py +4 -4
- dataeval/metrics/bias/_parity.py +2 -2
- dataeval/metrics/stats/_hashstats.py +19 -2
- dataeval/outputs/_workflows.py +20 -7
- dataeval/typing.py +14 -2
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_bin.py +7 -6
- dataeval/utils/data/__init__.py +2 -0
- dataeval/utils/data/_dataset.py +13 -6
- dataeval/utils/data/_validate.py +169 -0
- {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/METADATA +5 -17
- {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/RECORD +29 -39
- dataeval/utils/datasets/__init__.py +0 -21
- dataeval/utils/datasets/_antiuav.py +0 -189
- dataeval/utils/datasets/_base.py +0 -266
- dataeval/utils/datasets/_cifar10.py +0 -201
- dataeval/utils/datasets/_fileio.py +0 -142
- dataeval/utils/datasets/_milco.py +0 -197
- dataeval/utils/datasets/_mixin.py +0 -54
- dataeval/utils/datasets/_mnist.py +0 -202
- dataeval/utils/datasets/_seadrone.py +0 -512
- dataeval/utils/datasets/_ships.py +0 -144
- dataeval/utils/datasets/_types.py +0 -48
- dataeval/utils/datasets/_voc.py +0 -583
- {dataeval-0.86.9.dist-info → dataeval-0.87.0.dist-info}/WHEEL +0 -0
- /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.87.0.dist-info/licenses/LICENSE +0 -0
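Among the additions listed above, dataeval/detectors/ood/knn.py (+95) introduces a k-nearest-neighbor based out-of-distribution detector. Its contents are not shown in this diff, so the sketch below is only a generic illustration of distance-to-k-th-neighbor OOD scoring; it is not dataeval's implementation, and the function name and API here are hypothetical:

    import numpy as np

    def knn_ood_scores(reference: np.ndarray, queries: np.ndarray, k: int = 10) -> np.ndarray:
        """Distance from each query embedding to its k-th nearest reference embedding.

        Larger values suggest the query lies farther from the reference (in-distribution) data.
        Generic illustration only -- not dataeval's implementation.
        """
        # Pairwise Euclidean distances, shape (n_queries, n_reference)
        dists = np.linalg.norm(queries[:, None, :] - reference[None, :, :], axis=-1)
        # k-th smallest distance per query (np.partition avoids a full sort)
        return np.partition(dists, k - 1, axis=1)[:, k - 1]

    rng = np.random.default_rng(0)
    train = rng.normal(size=(2000, 16))          # in-distribution reference embeddings
    holdout = rng.normal(size=(500, 16))         # held-out in-distribution split used to set the threshold
    test = rng.normal(loc=4.0, size=(100, 16))   # shifted samples that should score as OOD

    threshold = np.quantile(knn_ood_scores(train, holdout, k=10), 0.95)
    is_ood = knn_ood_scores(train, test, k=10) > threshold

In this scheme a reference set of in-distribution embeddings defines "normal"; queries whose k-th-neighbor distance exceeds a chosen quantile of the held-out scores are flagged as out of distribution.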
dataeval/utils/datasets/_cifar10.py (removed)
@@ -1,201 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
-
-import numpy as np
-from numpy.typing import NDArray
-
-from dataeval.utils.datasets._base import BaseICDataset, DataLocation
-from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
-if TYPE_CHECKING:
-    from dataeval.typing import Transform
-
-CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
-TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
-
-
-class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
-    """
-    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
-
-    Parameters
-    ----------
-    root : str or pathlib.Path
-        Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
-    image_set : "train", "test" or "base", default "train"
-        If "base", returns all of the data to allow the user to create their own splits.
-    transforms : Transform, Sequence[Transform] or None, default None
-        Transform(s) to apply to the data.
-    download : bool, default False
-        If True, downloads the dataset from the internet and puts it in root directory.
-        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-    verbose : bool, default False
-        If True, outputs print statements.
-
-    Attributes
-    ----------
-    path : pathlib.Path
-        Location of the folder containing the data.
-    image_set : "train", "test" or "base"
-        The selected image set from the dataset.
-    transforms : Sequence[Transform]
-        The transforms to be applied to the data.
-    size : int
-        The size of the dataset.
-    index2label : dict[int, str]
-        Dictionary which translates from class integers to the associated class strings.
-    label2index : dict[str, int]
-        Dictionary which translates from class strings to the associated class integers.
-    metadata : DatasetMetadata
-        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-    """
-
-    _resources = [
-        DataLocation(
-            url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
-            filename="cifar-10-binary.tar.gz",
-            md5=True,
-            checksum="c32a1d4ab5d03f1284b67883e8d87530",
-        ),
-    ]
-
-    index2label: dict[int, str] = {
-        0: "airplane",
-        1: "automobile",
-        2: "bird",
-        3: "cat",
-        4: "deer",
-        5: "dog",
-        6: "frog",
-        7: "horse",
-        8: "ship",
-        9: "truck",
-    }
-
-    def __init__(
-        self,
-        root: str | Path,
-        image_set: Literal["train", "test", "base"] = "train",
-        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-        download: bool = False,
-        verbose: bool = False,
-    ) -> None:
-        super().__init__(
-            root,
-            image_set,
-            transforms,
-            download,
-            verbose,
-        )
-
-    def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
-        batch_nums = np.zeros(60000, dtype=np.uint8)
-        all_labels = np.zeros(60000, dtype=np.uint8)
-        all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
-        # Process each batch file, skipping .meta and .html files
-        for batch_file in data_folder:
-            # Get batch parameters
-            batch_type = "test" if "test" in batch_file.stem else "train"
-            batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
-
-            # Load data
-            batch_images, batch_labels = self._unpack_batch_files(batch_file)
-
-            # Stack data
-            num_images = batch_images.shape[0]
-            batch_start = batch_num * num_images
-            all_images[batch_start : batch_start + num_images] = batch_images
-            all_labels[batch_start : batch_start + num_images] = batch_labels
-            batch_nums[batch_start : batch_start + num_images] = batch_num
-
-        # Save data
-        self._loaded_data = all_images
-        np.savez(self.path / "cifar10", images=self._loaded_data, labels=all_labels, batches=batch_nums)
-
-        # Select data
-        image_list = np.arange(all_labels.shape[0]).astype(str)
-        if self.image_set == "train":
-            return (
-                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
-                all_labels[batch_nums != 5].tolist(),
-                {"batch_num": batch_nums[batch_nums != 5].tolist()},
-            )
-        if self.image_set == "test":
-            return (
-                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
-                all_labels[batch_nums == 5].tolist(),
-                {"batch_num": batch_nums[batch_nums == 5].tolist()},
-            )
-        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
-
-    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
-        """Function to load in the file paths for the data and labels and retrieve metadata"""
-        data_file = self.path / "cifar10.npz"
-        if not data_file.exists():
-            data_folder = sorted((self.path / "cifar-10-batches-bin").glob("*.bin"))
-            if not data_folder:
-                raise FileNotFoundError
-            return self._load_bin_data(data_folder)
-
-        # Load data
-        data = np.load(data_file)
-        self._loaded_data = data["images"]
-        all_labels = data["labels"]
-        batch_nums = data["batches"]
-
-        # Select data
-        image_list = np.arange(all_labels.shape[0]).astype(str)
-        if self.image_set == "train":
-            return (
-                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
-                all_labels[batch_nums != 5].tolist(),
-                {"batch_num": batch_nums[batch_nums != 5].tolist()},
-            )
-        if self.image_set == "test":
-            return (
-                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
-                all_labels[batch_nums == 5].tolist(),
-                {"batch_num": batch_nums[batch_nums == 5].tolist()},
-            )
-        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
-
-    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
-        # Load pickle data with latin1 encoding
-        with file_path.open("rb") as f:
-            buffer = np.frombuffer(f.read(), dtype=np.uint8)
-        # Each entry is 1 byte for label + 3072 bytes for image (3*32*32)
-        entry_size = 1 + 3072
-        num_entries = buffer.size // entry_size
-        # Extract labels (first byte of each entry)
-        labels = buffer[::entry_size]
-
-        # Extract image data and reshape to (N, 3, 32, 32)
-        images = np.zeros((num_entries, 3, 32, 32), dtype=np.uint8)
-        for i in range(num_entries):
-            # Skip the label byte and get image data for this entry
-            start_idx = i * entry_size + 1  # +1 to skip label
-            img_flat = buffer[start_idx : start_idx + 3072]
-
-            # The CIFAR format stores channels in blocks (all R, then all G, then all B)
-            # Each channel block is 1024 bytes (32x32)
-            red_channel = img_flat[0:1024].reshape(32, 32)
-            green_channel = img_flat[1024:2048].reshape(32, 32)
-            blue_channel = img_flat[2048:3072].reshape(32, 32)
-
-            # Stack the channels in the proper C×H×W format
-            images[i, 0] = red_channel  # Red channel
-            images[i, 1] = green_channel  # Green channel
-            images[i, 2] = blue_channel  # Blue channel
-        return images, labels
-
-    def _read_file(self, path: str) -> NDArray[Any]:
-        """
-        Function to grab the correct image from the loaded data.
-        Overwrite of the base `_read_file` because data is an all or nothing load.
-        """
-        index = int(path)
-        return self._loaded_data[index]
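For reference, the removed class above could be used as its docstring describes. A minimal sketch, assuming the pre-0.87.0 import path exported by the (also removed) dataeval.utils.datasets package:

    # Pre-0.87.0 only: the dataeval.utils.datasets module is removed in 0.87.0.
    from dataeval.utils.datasets import CIFAR10

    # Download the CIFAR-10 binary archive into ./data and select the training split.
    train_ds = CIFAR10(root="./data", image_set="train", download=True, verbose=True)

    print(train_ds.size)         # documented attribute: number of samples in the selected split
    print(train_ds.index2label)  # {0: "airplane", 1: "automobile", ..., 9: "truck"}
    print(train_ds.path)         # folder containing the downloaded and extracted data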
dataeval/utils/datasets/_fileio.py (removed)
@@ -1,142 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-import hashlib
-import tarfile
-import zipfile
-from pathlib import Path
-
-import requests
-from tqdm.auto import tqdm
-
-ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
-COMPRESS_ENDINGS = [".gz", ".bz2"]
-
-
-def _print(text: str, verbose: bool) -> None:
-    if verbose:
-        print(text)
-
-
-def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
-    hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
-    with open(fpath, "rb") as fpath_file:
-        while chunk := fpath_file.read(chunk_size):
-            hasher.update(chunk)
-    return hasher.hexdigest() == file_md5
-
-
-def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
-    """Download a single resource from its URL to the `data_folder`."""
-    error_msg = "URL fetch failure on {}: {} -- {}"
-    try:
-        response = requests.get(url, stream=True, timeout=timeout)
-        response.raise_for_status()
-    except requests.exceptions.HTTPError as e:
-        raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
-    except requests.exceptions.RequestException as e:
-        raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e
-
-    total_size = int(response.headers.get("content-length", 0))
-    block_size = 8192  # 8 KB
-    progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
-
-    with open(file_path, "wb") as f:
-        for chunk in response.iter_content(block_size):
-            f.write(chunk)
-            progress_bar.update(len(chunk))
-    progress_bar.close()
-
-
-def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
-    """Extracts the zip file to the given directory."""
-    try:
-        with zipfile.ZipFile(file_path, "r") as zip_ref:
-            zip_ref.extractall(extract_to)  # noqa: S202
-        file_path.unlink()
-    except zipfile.BadZipFile:
-        raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
-
-
-def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
-    """Extracts a tar file (or compressed tar) to the specified directory."""
-    try:
-        with tarfile.open(file_path, "r:*") as tar_ref:
-            tar_ref.extractall(extract_to)  # noqa: S202
-        file_path.unlink()
-    except tarfile.TarError:
-        raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
-
-
-def _extract_archive(
-    file_ext: str, file_path: Path, directory: Path, compression: bool = False, verbose: bool = False
-) -> None:
-    """
-    Single function to extract and then flatten if necessary.
-    Recursively extracts nested zip files as well.
-    Extracts and flattens all folders to the base directory.
-    """
-    if file_ext != ".zip" or compression:
-        _extract_tar_archive(file_path, directory)
-    else:
-        _extract_zip_archive(file_path, directory)
-        # Look for nested zip files in the extraction directory and extract them recursively.
-        # Does NOT extract in place - extracts everything to directory
-        for child in directory.iterdir():
-            if child.suffix == ".zip":
-                _print(f"Extracting nested zip: {child} to {directory}", verbose)
-                _extract_zip_archive(child, directory)
-
-
-def _ensure_exists(
-    url: str,
-    filename: str,
-    md5: bool,
-    checksum: str,
-    directory: Path,
-    root: Path,
-    download: bool = True,
-    verbose: bool = False,
-) -> None:
-    """
-    For each resource, download it if it doesn't exist in the dataset_dir.
-    If the resource is a zip file, extract it (including recursively extracting nested zips).
-    """
-    file_path = directory / str(filename)
-    alternate_path = root / str(filename)
-    _, file_ext = file_path.stem, file_path.suffix
-    compression = False
-    if file_ext in COMPRESS_ENDINGS:
-        file_ext = file_path.suffixes[0]
-        compression = True
-
-    check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path
-
-    # Download file if it doesn't exist.
-    if not check_path.exists() and download:
-        _print(f"Downloading {filename} from {url}", verbose)
-        _download_dataset(url, check_path, verbose=verbose)
-
-        if not _validate_file(check_path, checksum, md5):
-            raise Exception("File checksum mismatch. Remove current file and retry download.")
-
-        # If the file is a zip, tar or tgz extract it into the designated folder.
-        if file_ext in ARCHIVE_ENDINGS:
-            _print(f"Extracting {filename}...", verbose)
-            _extract_archive(file_ext, check_path, directory, compression, verbose)
-
-    elif not check_path.exists() and not download:
-        raise FileNotFoundError(
-            "Data could not be loaded with the provided root directory, "
-            f"the file path to the file {filename} does not exist, "
-            "and the download parameter is set to False."
-        )
-    else:
-        if not _validate_file(check_path, checksum, md5):
-            raise Exception("File checksum mismatch. Remove current file and retry download.")
-        _print(f"{filename} already exists, skipping download.", verbose)
-
-        if file_ext in ARCHIVE_ENDINGS:
-            _print(f"Extracting {filename}...", verbose)
-            _extract_archive(file_ext, check_path, directory, compression, verbose)
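The removed helpers above implement a download, checksum-validate, then extract flow. As a standalone illustration of the validation step (the expected digest is the md5 listed in the CIFAR10 _resources entry earlier in this diff; the local archive path is hypothetical):

    import hashlib
    from pathlib import Path

    def md5_of_file(path: Path, chunk_size: int = 65535) -> str:
        """Hash the file in chunks so large archives never need to fit in memory."""
        hasher = hashlib.md5(usedforsecurity=False)
        with path.open("rb") as f:
            while chunk := f.read(chunk_size):
                hasher.update(chunk)
        return hasher.hexdigest()

    archive = Path("./data/cifar-10-binary.tar.gz")  # hypothetical local path
    expected = "c32a1d4ab5d03f1284b67883e8d87530"     # md5 from the CIFAR10 _resources entry
    if md5_of_file(archive) != expected:
        raise RuntimeError("File checksum mismatch. Remove current file and retry download.")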
dataeval/utils/datasets/_milco.py (removed)
@@ -1,197 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Sequence
-
-from numpy.typing import NDArray
-
-from dataeval.utils.datasets._base import BaseODDataset, DataLocation
-from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
-
-if TYPE_CHECKING:
-    from dataeval.typing import Transform
-
-
-class MILCO(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
-    """
-    A side-scan sonar dataset focused on mine-like object detection.
-
-    The dataset comes from the paper
-    `Side-scan sonar imaging data of underwater vehicles for mine detection <https://doi.org/10.1016/j.dib.2024.110132>`_
-    by N.P. Santos et. al. (2024).
-
-    The full dataset contains 1170 side-scan sonar images collected using a 900-1800 kHz Marine Sonic
-    dual frequency side-scan sonar of a Teledyne Marine Gavia Autonomous Underwater Vehicle.
-    All the images were carefully analyzed and annotated, including the image coordinates of the
-    Bounding Box (BB) of the detected objects divided into NOn-Mine-like BOttom Objects (NOMBO)
-    and MIne-Like COntacts (MILCO) classes.
-
-    This dataset is consists of 345 images from 2010, 120 images from 2015, 93 images from 2017, 564 images from 2018,
-    and 48 images from 2021). In these 1170 images, there are 432 MILCO objects, and 235 NOMBO objects.
-    The class “0” corresponds to a MILCO object and the class “1” corresponds to a NOMBO object.
-    The raw BB coordinates provided in the downloaded text files are (x, y, w, h),
-    given as percentages of the image (x_BB = x/img_width, y_BB = y/img_height, etc.).
-    The images come in 2 sizes, 416 x 416 or 1024 x 1024.
-
-    Parameters
-    ----------
-    root : str or pathlib.Path
-        Root directory where the data should be downloaded to or the ``milco`` folder of the already downloaded data.
-    image_set: "train", "operational", or "base", default "train"
-        If "train", then the images from 2015, 2017 and 2021 are selected,
-        resulting in 315 MILCO objects and 177 NOMBO objects.
-        If "operational", then the images from 2010 and 2018 are selected,
-        resulting in 117 MILCO objects and 58 NOMBO objects.
-        If "base", then the full dataset is selected.
-    transforms : Transform, Sequence[Transform] or None, default None
-        Transform(s) to apply to the data.
-    download : bool, default False
-        If True, downloads the dataset from the internet and puts it in root directory.
-        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
-    verbose : bool, default False
-        If True, outputs print statements.
-
-    Attributes
-    ----------
-    path : pathlib.Path
-        Location of the folder containing the data.
-    image_set : "train", "operational" or "base"
-        The selected image set from the dataset.
-    index2label : dict[int, str]
-        Dictionary which translates from class integers to the associated class strings.
-    label2index : dict[str, int]
-        Dictionary which translates from class strings to the associated class integers.
-    metadata : DatasetMetadata
-        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
-    transforms : Sequence[Transform]
-        The transforms to be applied to the data.
-    size : int
-        The size of the dataset.
-
-    Note
-    ----
-    Data License: `CC BY 4.0 <https://creativecommons.org/licenses/by/4.0/>`_
-    """
-
-    _resources = [
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169002",
-            filename="2015.zip",
-            md5=True,
-            checksum="93dfbb4fb7987734152c372496b4884c",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169005",
-            filename="2017.zip",
-            md5=True,
-            checksum="9c2de230a2bbf654921416bea6fc0f42",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43168999",
-            filename="2021.zip",
-            md5=True,
-            checksum="b84749b21fa95a4a4c7de3741db78bc7",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169008",
-            filename="2010.zip",
-            md5=True,
-            checksum="43347a0cc383c0d3dbe0d24ae56f328d",
-        ),
-        DataLocation(
-            url="https://figshare.com/ndownloader/files/43169011",
-            filename="2018.zip",
-            md5=True,
-            checksum="25d091044a10c78674fedad655023e3b",
-        ),
-    ]
-
-    index2label: dict[int, str] = {
-        0: "MILCO",
-        1: "NOMBO",
-    }
-
-    def __init__(
-        self,
-        root: str | Path,
-        image_set: Literal["train", "operational", "base"] = "train",
-        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
-        download: bool = False,
-        verbose: bool = False,
-    ) -> None:
-        super().__init__(
-            root,
-            image_set,
-            transforms,
-            download,
-            verbose,
-        )
-        self._bboxes_per_size = True
-
-    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
-        filepaths: list[str] = []
-        targets: list[str] = []
-        datum_metadata: dict[str, list[Any]] = {}
-        metadata_list: list[dict[str, Any]] = []
-        image_sets: dict[str, list[int]] = {
-            "base": list(range(len(self._resources))),
-            "train": list(range(3)),
-            "operational": list(range(3, len(self._resources))),
-        }
-
-        # Load the data
-        resource_indices = image_sets[self.image_set]
-        for idx in resource_indices:
-            self._resource = self._resources[idx]
-            filepath, target, metadata = super()._load_data()
-            filepaths.extend(filepath)
-            targets.extend(target)
-            metadata_list.append(metadata)
-
-        # Adjust datum metadata to correct format
-        for data_dict in metadata_list:
-            for key, val in data_dict.items():
-                if key not in datum_metadata:
-                    datum_metadata[str(key)] = []
-                datum_metadata[str(key)].extend(val)
-
-        return filepaths, targets, datum_metadata
-
-    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
-        file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
-        data_folder = sorted((self.path / self._resource.filename[:-4]).glob("*.jpg"))
-        if not data_folder:
-            raise FileNotFoundError
-
-        for entry in data_folder:
-            # Remove file extension and split by "_"
-            parts = entry.stem.split("_")
-            file_data["image_id"].append(parts[0])
-            file_data["year"].append(parts[1])
-            file_data["data_path"].append(str(entry))
-            file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
-        data = file_data.pop("data_path")
-        annotations = file_data.pop("label_path")
-
-        return data, annotations, file_data
-
-    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
-        """Function for extracting the info out of the text files"""
-        labels: list[int] = []
-        boxes: list[list[float]] = []
-        with open(annotation) as f:
-            for line in f.readlines():
-                out = line.strip().split()
-                labels.append(int(out[0]))
-
-                xcenter, ycenter, width, height = [float(out[1]), float(out[2]), float(out[3]), float(out[4])]
-
-                x0 = xcenter - width / 2
-                x1 = x0 + width
-                y0 = ycenter - height / 2
-                y1 = y0 + height
-                boxes.append([x0, y0, x1, y1])
-
-        return boxes, labels, {}
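The _read_annotations method above converts each label line from normalized center format (x_center, y_center, width, height) to corner format (x0, y0, x1, y1). A small worked example with hypothetical annotation values:

    # One hypothetical line from a MILCO label file: class 0, normalized center-format box
    line = "0 0.50 0.40 0.20 0.10"
    label, xc, yc, w, h = line.split()
    label = int(label)  # 0 -> "MILCO"
    xc, yc, w, h = float(xc), float(yc), float(w), float(h)

    # Same arithmetic as _read_annotations (coordinates stay normalized to the image size)
    x0 = xc - w / 2   # 0.40
    y0 = yc - h / 2   # 0.35
    x1 = x0 + w       # 0.60
    y1 = y0 + h       # 0.45

    # Scaling to pixels for a 416 x 416 image (the dataset also contains 1024 x 1024 images)
    box_px = [x0 * 416, y0 * 416, x1 * 416, y1 * 416]   # [166.4, 145.6, 249.6, 187.2]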
dataeval/utils/datasets/_mixin.py (removed)
@@ -1,54 +0,0 @@
-from __future__ import annotations
-
-__all__ = []
-
-from typing import Any, Generic, TypeVar
-
-import numpy as np
-import torch
-from numpy.typing import NDArray
-from PIL import Image
-
-_TArray = TypeVar("_TArray")
-
-
-class BaseDatasetMixin(Generic[_TArray]):
-    index2label: dict[int, str]
-
-    def _as_array(self, raw: list[Any]) -> _TArray: ...
-    def _one_hot_encode(self, value: int | list[int]) -> _TArray: ...
-    def _read_file(self, path: str) -> _TArray: ...
-
-
-class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
-    def _as_array(self, raw: list[Any]) -> NDArray[Any]:
-        return np.asarray(raw)
-
-    def _one_hot_encode(self, value: int | list[int]) -> NDArray[Any]:
-        if isinstance(value, int):
-            encoded = np.zeros(len(self.index2label))
-            encoded[value] = 1
-        else:
-            encoded = np.zeros((len(value), len(self.index2label)))
-            encoded[np.arange(len(value)), value] = 1
-        return encoded
-
-    def _read_file(self, path: str) -> NDArray[Any]:
-        return np.array(Image.open(path)).transpose(2, 0, 1)
-
-
-class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
-    def _as_array(self, raw: list[Any]) -> torch.Tensor:
-        return torch.as_tensor(raw)
-
-    def _one_hot_encode(self, value: int | list[int]) -> torch.Tensor:
-        if isinstance(value, int):
-            encoded = torch.zeros(len(self.index2label))
-            encoded[value] = 1
-        else:
-            encoded = torch.zeros((len(value), len(self.index2label)))
-            encoded[torch.arange(len(value)), value] = 1
-        return encoded
-
-    def _read_file(self, path: str) -> torch.Tensor:
-        return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))