dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/config.py +21 -4
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_images.py +2 -3
- dataeval/data/_metadata.py +188 -178
- dataeval/data/_selection.py +1 -2
- dataeval/data/_split.py +4 -5
- dataeval/data/_targets.py +17 -13
- dataeval/data/selections/_classfilter.py +2 -5
- dataeval/data/selections/_prioritize.py +6 -9
- dataeval/data/selections/_shuffle.py +3 -1
- dataeval/detectors/drift/_base.py +4 -5
- dataeval/detectors/drift/_mmd.py +3 -6
- dataeval/detectors/drift/_nml/_base.py +4 -2
- dataeval/detectors/drift/_nml/_chunk.py +11 -19
- dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
- dataeval/detectors/drift/_nml/_result.py +8 -9
- dataeval/detectors/drift/_nml/_thresholds.py +66 -77
- dataeval/detectors/linters/outliers.py +7 -7
- dataeval/metadata/_distance.py +10 -7
- dataeval/metadata/_ood.py +11 -103
- dataeval/metrics/bias/_balance.py +23 -33
- dataeval/metrics/bias/_diversity.py +16 -14
- dataeval/metrics/bias/_parity.py +18 -18
- dataeval/metrics/estimators/_divergence.py +2 -4
- dataeval/metrics/stats/_base.py +103 -42
- dataeval/metrics/stats/_boxratiostats.py +21 -19
- dataeval/metrics/stats/_dimensionstats.py +14 -10
- dataeval/metrics/stats/_hashstats.py +1 -1
- dataeval/metrics/stats/_pixelstats.py +6 -6
- dataeval/metrics/stats/_visualstats.py +3 -3
- dataeval/outputs/_base.py +22 -7
- dataeval/outputs/_bias.py +24 -70
- dataeval/outputs/_drift.py +1 -9
- dataeval/outputs/_linters.py +11 -11
- dataeval/outputs/_stats.py +82 -23
- dataeval/outputs/_workflows.py +2 -2
- dataeval/utils/_array.py +6 -9
- dataeval/utils/_bin.py +1 -2
- dataeval/utils/_clusterer.py +7 -4
- dataeval/utils/_fast_mst.py +27 -13
- dataeval/utils/_image.py +65 -11
- dataeval/utils/_mst.py +1 -3
- dataeval/utils/_plot.py +15 -10
- dataeval/utils/data/_dataset.py +54 -28
- dataeval/utils/data/metadata.py +104 -82
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +189 -0
- dataeval/utils/datasets/_base.py +11 -8
- dataeval/utils/datasets/_cifar10.py +104 -45
- dataeval/utils/datasets/_fileio.py +21 -47
- dataeval/utils/datasets/_milco.py +22 -12
- dataeval/utils/datasets/_mixin.py +2 -4
- dataeval/utils/datasets/_mnist.py +3 -4
- dataeval/utils/datasets/_ships.py +14 -7
- dataeval/utils/datasets/_voc.py +229 -42
- dataeval/utils/torch/models.py +5 -10
- dataeval/utils/torch/trainer.py +3 -3
- dataeval/workflows/sufficiency.py +2 -2
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
- dataeval-0.86.2.dist-info/RECORD +114 -0
- dataeval/detectors/ood/vae.py +0 -74
- dataeval-0.86.0.dist-info/RECORD +0 -114
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/utils/datasets/_antiuav.py
ADDED
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+__all__ = []
+
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal, Sequence
+
+from defusedxml.ElementTree import parse
+from numpy.typing import NDArray
+
+from dataeval.utils.datasets._base import BaseODDataset, DataLocation
+from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
+
+if TYPE_CHECKING:
+    from dataeval.typing import Transform
+
+
+class AntiUAVDetection(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
+    """
+    A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
+
+    The dataset comes from the paper
+    `Vision-based Anti-UAV Detection and Tracking <https://ieeexplore.ieee.org/document/9785379>`_
+    by Jie Zhao et. al. (2022).
+
+    The dataset is approximately 1.3 GB and can be found `here <https://github.com/wangdongdut/DUT-Anti-UAV>`_.
+    Images are collected against a variety of different backgrounds with a variety in the number and type of UAV.
+    Ground truth labels are provided for the train, validation and test set.
+    There are 35 different types of drones along with a variety in lighting conditions and weather conditions.
+
+    There are 10,000 images: 5200 images in the training set, 2200 images in the validation set,
+    and 2600 images in the test set.
+    The dataset only has a single UAV class with the focus being on identifying object location in the image.
+    Ground-truth bounding boxes are provided in (x0, y0, x1, y1) format.
+    The images come in a variety of sizes from 3744 x 5616 to 160 x 240.
+
+    Parameters
+    ----------
+    root : str or pathlib.Path
+        Root directory where the data should be downloaded to or
+        the ``antiuavdetection`` folder of the already downloaded data.
+    image_set: "train", "val", "test", or "base", default "train"
+        If "base", then the full dataset is selected (train, val and test).
+    transforms : Transform, Sequence[Transform] or None, default None
+        Transform(s) to apply to the data.
+    download : bool, default False
+        If True, downloads the dataset from the internet and puts it in root directory.
+        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
+    verbose : bool, default False
+        If True, outputs print statements.
+
+    Attributes
+    ----------
+    path : pathlib.Path
+        Location of the folder containing the data.
+    image_set : "train", "val", "test", or "base"
+        The selected image set from the dataset.
+    index2label : dict[int, str]
+        Dictionary which translates from class integers to the associated class strings.
+    label2index : dict[str, int]
+        Dictionary which translates from class strings to the associated class integers.
+    metadata : DatasetMetadata
+        Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
+    transforms : Sequence[Transform]
+        The transforms to be applied to the data.
+    size : int
+        The size of the dataset.
+
+    Note
+    ----
+    Data License: `Apache 2.0 <https://www.apache.org/licenses/LICENSE-2.0.txt>`_
+    """
+
+    # Need to run the sha256 on the files and then store that
+    _resources = [
+        DataLocation(
+            url="https://drive.usercontent.google.com/download?id=1RVsSGPUKTdmoyoPTBTWwroyulLek1eTj&export=download&authuser=0&confirm=t&uuid=6bca4f94-a242-4bc2-9663-fb03cd94ef2c&at=APcmpox0--NroQ_3bqeTFaJxP7Pw%3A1746552902927",
+            filename="train.zip",
+            md5=False,
+            checksum="14f927290556df60e23cedfa80dffc10dc21e4a3b6843e150cfc49644376eece",
+        ),
+        DataLocation(
+            url="https://drive.usercontent.google.com/download?id=1333uEQfGuqTKslRkkeLSCxylh6AQ0X6n&export=download&authuser=0&confirm=t&uuid=c2ad2f01-aca8-4a85-96bb-b8ef6e40feea&at=APcmpozY-8bhk3nZSFaYbE8rq1Fi%3A1746551543297",
+            filename="val.zip",
+            md5=False,
+            checksum="238be0ceb3e7c5be6711ee3247e49df2750d52f91f54f5366c68bebac112ebf8",
+        ),
+        DataLocation(
+            url="https://drive.usercontent.google.com/download?id=1L1zeW1EMDLlXHClSDcCjl3rs_A6sVai0&export=download&authuser=0&confirm=t&uuid=5a1d7650-d8cd-4461-8354-7daf7292f06c&at=APcmpozLQC1CuP-n5_UX2JnP53Zo%3A1746551676177",
+            filename="test.zip",
+            md5=False,
+            checksum="a671989a01cff98c684aeb084e59b86f4152c50499d86152eb970a9fc7fb1cbe",
+        ),
+    ]
+
+    index2label: dict[int, str] = {
+        0: "unknown",
+        1: "UAV",
+    }
+
+    def __init__(
+        self,
+        root: str | Path,
+        image_set: Literal["train", "val", "test", "base"] = "train",
+        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
+        download: bool = False,
+        verbose: bool = False,
+    ) -> None:
+        super().__init__(
+            root,
+            image_set,
+            transforms,
+            download,
+            verbose,
+        )
+
+    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
+        filepaths: list[str] = []
+        targets: list[str] = []
+        datum_metadata: dict[str, list[Any]] = {}
+
+        # If base, load all resources
+        if self.image_set == "base":
+            metadata_list: list[dict[str, Any]] = []
+
+            for resource in self._resources:
+                self._resource = resource
+                resource_filepaths, resource_targets, resource_metadata = super()._load_data()
+                filepaths.extend(resource_filepaths)
+                targets.extend(resource_targets)
+                metadata_list.append(resource_metadata)
+
+            # Combine metadata
+            for data_dict in metadata_list:
+                for key, val in data_dict.items():
+                    str_key = str(key)  # Ensure key is string
+                    if str_key not in datum_metadata:
+                        datum_metadata[str_key] = []
+                    datum_metadata[str_key].extend(val)
+
+        else:
+            # Grab only the desired data
+            for resource in self._resources:
+                if self.image_set in resource.filename:
+                    self._resource = resource
+                    resource_filepaths, resource_targets, resource_metadata = super()._load_data()
+                    filepaths.extend(resource_filepaths)
+                    targets.extend(resource_targets)
+                    datum_metadata.update(resource_metadata)
+
+        return filepaths, targets, datum_metadata
+
+    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
+        resource_name = self._resource.filename[:-4]
+        base_dir = self.path / resource_name
+        data_folder = sorted((base_dir / "img").glob("*.jpg"))
+        if not data_folder:
+            raise FileNotFoundError
+
+        file_data = {"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]}
+        data = [str(entry) for entry in data_folder]
+        annotations = sorted(str(entry) for entry in (base_dir / "xml").glob("*.xml"))
+
+        return data, annotations, file_data
+
+    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
+        """Function for extracting the info for the label and boxes"""
+        boxes: list[list[float]] = []
+        labels = []
+        root = parse(annotation).getroot()
+        if root is None:
+            raise ValueError(f"Unable to parse {annotation}")
+        additional_meta: dict[str, Any] = {
+            "image_width": int(root.findtext("size/width", default="-1")),
+            "image_height": int(root.findtext("size/height", default="-1")),
+            "image_depth": int(root.findtext("size/depth", default="-1")),
+        }
+        for obj in root.findall("object"):
+            labels.append(1 if obj.findtext("name", default="") == "UAV" else 0)
+            boxes.append(
+                [
+                    float(obj.findtext("bndbox/xmin", default="0")),
+                    float(obj.findtext("bndbox/ymin", default="0")),
+                    float(obj.findtext("bndbox/xmax", default="0")),
+                    float(obj.findtext("bndbox/ymax", default="0")),
+                ]
+            )
+
+        return boxes, labels, additional_meta
dataeval/utils/datasets/_base.py
CHANGED
@@ -6,6 +6,8 @@ from abc import abstractmethod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
 
+import numpy as np
+
 from dataeval.utils.datasets._fileio import _ensure_exists
 from dataeval.utils.datasets._mixin import BaseDatasetMixin
 from dataeval.utils.datasets._types import (
@@ -101,11 +103,7 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Ge
 
     def _get_dataset_dir(self) -> Path:
         # Create a designated folder for this dataset (named after the class)
-        if self._root.stem in [
-            self.__class__.__name__.lower(),
-            self.__class__.__name__.upper(),
-            self.__class__.__name__,
-        ]:
+        if self._root.stem.lower() == self.__class__.__name__.lower():
             dataset_dir: Path = self._root
         else:
             dataset_dir: Path = self._root / self.__class__.__name__.lower()
@@ -114,8 +112,7 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Ge
         return dataset_dir
 
     def _unique_id(self) -> str:
-        return unique_id
+        return f"{self.__class__.__name__}_{self.image_set}"
 
     def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
         """
@@ -188,6 +185,8 @@ class BaseODDataset(
     Base class for object detection datasets.
     """
 
+    _bboxes_per_size: bool = False
+
     def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
         """
         Args
@@ -204,8 +203,12 @@ class BaseODDataset(
         boxes, labels, additional_metadata = self._read_annotations(self._targets[index])
         # Get the image
         img = self._read_file(self._filepaths[index])
+        img_size = img.shape
         img = self._transform(img)
-
+        # Adjust labels if necessary
+        if self._bboxes_per_size and boxes:
+            boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
+        # Create the Object Detection Target
         target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
 
         img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
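The new _bboxes_per_size flag lets subclasses store fractional boxes and have __getitem__ rescale them by the image's shape. A small standalone numpy sketch of that multiplication as written in the hunk above (the image array is channels-first, so shape[1] is height and shape[2] is width):

    import numpy as np

    img_shape = (3, 480, 640)  # (C, H, W), as returned by _read_file
    boxes = np.array([[0.25, 0.40, 0.75, 0.60]])  # fractional (x0, y0, x1, y1)

    # Mirrors the added line: multiply by (shape[1], shape[2], shape[1], shape[2]).
    scaled = boxes * np.array([[img_shape[1], img_shape[2], img_shape[1], img_shape[2]]])
    print(scaled)  # [[120. 256. 360. 384.]]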
dataeval/utils/datasets/_cifar10.py
CHANGED
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
-from PIL import Image
 
 from dataeval.utils.datasets._base import BaseICDataset, DataLocation
 from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
@@ -26,7 +25,7 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     Parameters
     ----------
     root : str or pathlib.Path
-        Root directory
+        Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
     image_set : "train", "test" or "base", default "train"
         If "base", returns all of the data to allow the user to create their own splits.
     transforms : Transform, Sequence[Transform] or None, default None
@@ -93,50 +92,110 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
             verbose,
         )
 
+    def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
+        batch_nums = np.zeros(60000, dtype=np.uint8)
+        all_labels = np.zeros(60000, dtype=np.uint8)
+        all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
+        # Process each batch file, skipping .meta and .html files
+        for batch_file in data_folder:
+            # Get batch parameters
+            batch_type = "test" if "test" in batch_file.stem else "train"
+            batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
+
+            # Load data
+            batch_images, batch_labels = self._unpack_batch_files(batch_file)
+
+            # Stack data
+            num_images = batch_images.shape[0]
+            batch_start = batch_num * num_images
+            all_images[batch_start : batch_start + num_images] = batch_images
+            all_labels[batch_start : batch_start + num_images] = batch_labels
+            batch_nums[batch_start : batch_start + num_images] = batch_num
+
+        # Save data
+        self._loaded_data = all_images
+        np.savez(self.path / "cifar10", images=self._loaded_data, labels=all_labels, batches=batch_nums)
+
+        # Select data
+        image_list = np.arange(all_labels.shape[0]).astype(str)
+        if self.image_set == "train":
+            return (
+                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
+                all_labels[batch_nums != 5].tolist(),
+                {"batch_num": batch_nums[batch_nums != 5].tolist()},
+            )
+        if self.image_set == "test":
+            return (
+                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
+                all_labels[batch_nums == 5].tolist(),
+                {"batch_num": batch_nums[batch_nums == 5].tolist()},
+            )
+        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
+
     def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
         """Function to load in the file paths for the data and labels and retrieve metadata"""
-        if not save_folder.exists():
-            save_folder.mkdir(exist_ok=True)
-        for i, file in enumerate(image_sets["base"]):
-            Image.fromarray(images[i].transpose(1, 2, 0).astype(np.uint8)).save(file)
-
-        return image_sets[self.image_set], labels, file_meta
-
-    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[Any], list[int]]:
+        data_file = self.path / "cifar10.npz"
+        if not data_file.exists():
+            data_folder = sorted((self.path / "cifar-10-batches-bin").glob("*.bin"))
+            if not data_folder:
+                raise FileNotFoundError
+            return self._load_bin_data(data_folder)
+
+        # Load data
+        data = np.load(data_file)
+        self._loaded_data = data["images"]
+        all_labels = data["labels"]
+        batch_nums = data["batches"]
+
+        # Select data
+        image_list = np.arange(all_labels.shape[0]).astype(str)
+        if self.image_set == "train":
+            return (
+                image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
+                all_labels[batch_nums != 5].tolist(),
+                {"batch_num": batch_nums[batch_nums != 5].tolist()},
+            )
+        if self.image_set == "test":
+            return (
+                image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
+                all_labels[batch_nums == 5].tolist(),
+                {"batch_num": batch_nums[batch_nums == 5].tolist()},
+            )
+        return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
+
+    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
         # Load pickle data with latin1 encoding
         with file_path.open("rb") as f:
-            buffer = np.frombuffer(f.read(),
+            buffer = np.frombuffer(f.read(), dtype=np.uint8)
+        # Each entry is 1 byte for label + 3072 bytes for image (3*32*32)
+        entry_size = 1 + 3072
+        num_entries = buffer.size // entry_size
+        # Extract labels (first byte of each entry)
+        labels = buffer[::entry_size]
+
+        # Extract image data and reshape to (N, 3, 32, 32)
+        images = np.zeros((num_entries, 3, 32, 32), dtype=np.uint8)
+        for i in range(num_entries):
+            # Skip the label byte and get image data for this entry
+            start_idx = i * entry_size + 1  # +1 to skip label
+            img_flat = buffer[start_idx : start_idx + 3072]
+
+            # The CIFAR format stores channels in blocks (all R, then all G, then all B)
+            # Each channel block is 1024 bytes (32x32)
+            red_channel = img_flat[0:1024].reshape(32, 32)
+            green_channel = img_flat[1024:2048].reshape(32, 32)
+            blue_channel = img_flat[2048:3072].reshape(32, 32)
+
+            # Stack the channels in the proper C×H×W format
+            images[i, 0] = red_channel  # Red channel
+            images[i, 1] = green_channel  # Green channel
+            images[i, 2] = blue_channel  # Blue channel
+        return images, labels
+
+    def _read_file(self, path: str) -> NDArray[Any]:
+        """
+        Function to grab the correct image from the loaded data.
+        Overwrite of the base `_read_file` because data is an all or nothing load.
+        """
+        index = int(path)
+        return self._loaded_data[index]
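The per-record loop in _unpack_batch_files follows the CIFAR-10 binary layout: each record is one label byte followed by 3072 image bytes stored as whole channel planes (all red, then all green, then all blue). A vectorized sketch of the same unpacking, assuming the buffer length is an exact multiple of 3073 bytes:

    import numpy as np

    def unpack_cifar_records(buffer: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        # One record = 1 label byte + 3 * 32 * 32 image bytes = 3073 bytes.
        records = buffer.reshape(-1, 3073)
        labels = records[:, 0]
        # Channel planes are already in R, G, B order, so a reshape yields (N, C, H, W).
        images = records[:, 1:].reshape(-1, 3, 32, 32)
        return images, labels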
dataeval/utils/datasets/_fileio.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 __all__ = []
 
 import hashlib
-import shutil
 import tarfile
 import zipfile
 from pathlib import Path
@@ -15,7 +14,12 @@ ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
 COMPRESS_ENDINGS = [".gz", ".bz2"]
 
 
-def _validate_file(fpath, file_md5, md5: bool = False, chunk_size=65535) -> bool:
+def _print(text: str, verbose: bool) -> None:
+    if verbose:
+        print(text)
+
+
+def _validate_file(fpath: Path | str, file_md5: str, md5: bool = False, chunk_size: int = 65535) -> bool:
     hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
     with open(fpath, "rb") as fpath_file:
         while chunk := fpath_file.read(chunk_size):
@@ -23,7 +27,7 @@ def _validate_file(fpath, file_md5, md5: bool = False, chunk_size=65535) -> bool
     return hasher.hexdigest() == file_md5
 
 
-def _download_dataset(url: str, file_path: Path, timeout: int = 60) -> None:
+def _download_dataset(url: str, file_path: Path, timeout: int = 60, verbose: bool = False) -> None:
     """Download a single resource from its URL to the `data_folder`."""
     error_msg = "URL fetch failure on {}: {} -- {}"
     try:
@@ -36,7 +40,7 @@ def _download_dataset(url: str, file_path: Path, timeout: int = 60) -> None:
 
     total_size = int(response.headers.get("content-length", 0))
     block_size = 8192  # 8 KB
-    progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True)
+    progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True, disable=not verbose)
 
     with open(file_path, "wb") as f:
         for chunk in response.iter_content(block_size):
@@ -49,7 +53,7 @@ def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
     """Extracts the zip file to the given directory."""
     try:
         with zipfile.ZipFile(file_path, "r") as zip_ref:
-            zip_ref.extractall(extract_to)
+            zip_ref.extractall(extract_to)  # noqa: S202
         file_path.unlink()
     except zipfile.BadZipFile:
         raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
@@ -59,36 +63,15 @@ def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
     """Extracts a tar file (or compressed tar) to the specified directory."""
     try:
         with tarfile.open(file_path, "r:*") as tar_ref:
-            tar_ref.extractall(extract_to)
+            tar_ref.extractall(extract_to)  # noqa: S202
         file_path.unlink()
     except tarfile.TarError:
         raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
 
 
-def _flatten_extraction(base_directory, verbose):
-    """
-    move all its subfolders to the dataset_dir and remove the now-empty folder.
-    """
-    for child in base_directory.iterdir():
-        if child.is_dir():
-            inner_list = list(child.iterdir())
-            if all(subchild.is_dir() for subchild in inner_list):
-                for subchild in child.iterdir():
-                    if verbose:
-                        print(f"Moving {subchild.stem} to {base_directory}")
-                    shutil.move(subchild, base_directory)
-
-                if verbose:
-                    print(f"Removing empty folder {child.stem}")
-                child.rmdir()
-
-            # Checking for additional placeholder folders
-            if len(inner_list) == 1:
-                _flatten_extraction(base_directory, verbose)
-
-
-def _archive_extraction(file_ext, file_path, directory, compression: bool = False, verbose: bool = False):
+def _extract_archive(
+    file_ext: str, file_path: Path, directory: Path, compression: bool = False, verbose: bool = False
+) -> None:
     """
     Single function to extract and then flatten if necessary.
     Recursively extracts nested zip files as well.
@@ -102,14 +85,9 @@ def _archive_extraction(file_ext, file_path, directory, compression: bool = Fals
     # Does NOT extract in place - extracts everything to directory
     for child in directory.iterdir():
         if child.suffix == ".zip":
-            print(f"Extracting nested zip: {child} to {directory}")
+            _print(f"Extracting nested zip: {child} to {directory}", verbose)
             _extract_zip_archive(child, directory)
 
-    # Determine if there are nested folders and remove them
-    # Helps ensure there that data is at most one folder below main directory
-    _flatten_extraction(directory, verbose)
-
 
 def _ensure_exists(
     url: str,
@@ -137,18 +115,16 @@ def _ensure_exists(
 
     # Download file if it doesn't exist.
     if not check_path.exists() and download:
-        _download_dataset(url, check_path)
+        _print(f"Downloading {filename} from {url}", verbose)
+        _download_dataset(url, check_path, verbose=verbose)
 
         if not _validate_file(check_path, checksum, md5):
            raise Exception("File checksum mismatch. Remove current file and retry download.")
 
        # If the file is a zip, tar or tgz extract it into the designated folder.
        if file_ext in ARCHIVE_ENDINGS:
-            _archive_extraction(file_ext, check_path, directory, compression, verbose)
+            _print(f"Extracting {filename}...", verbose)
+            _extract_archive(file_ext, check_path, directory, compression, verbose)
 
     elif not check_path.exists() and not download:
         raise FileNotFoundError(
@@ -159,10 +135,8 @@ def _ensure_exists(
     else:
         if not _validate_file(check_path, checksum, md5):
             raise Exception("File checksum mismatch. Remove current file and retry download.")
-        print(f"{filename} already exists, skipping download.")
+        _print(f"{filename} already exists, skipping download.", verbose)
 
         if file_ext in ARCHIVE_ENDINGS:
-            _archive_extraction(file_ext, check_path, directory, compression, verbose)
+            _print(f"Extracting {filename}...", verbose)
+            _extract_archive(file_ext, check_path, directory, compression, verbose)
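_ensure_exists validates every download (and every pre-existing archive) against the checksum recorded in its DataLocation, using md5 only when a resource explicitly asks for it and sha256 otherwise. A standalone sketch of that check, with a hypothetical helper name:

    import hashlib
    from pathlib import Path

    def checksum_matches(fpath: Path, expected: str, md5: bool = False, chunk_size: int = 65535) -> bool:
        # Stream the file in chunks so large archives do not need to fit in memory.
        hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
        with open(fpath, "rb") as f:
            while chunk := f.read(chunk_size):
                hasher.update(chunk)
        return hasher.hexdigest() == expected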
dataeval/utils/datasets/_milco.py
CHANGED
@@ -38,7 +38,7 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     Parameters
     ----------
     root : str or pathlib.Path
-        Root directory
+        Root directory where the data should be downloaded to or the ``milco`` folder of the already downloaded data.
     image_set: "train", "operational", or "base", default "train"
         If "train", then the images from 2015, 2017 and 2021 are selected,
         resulting in 315 MILCO objects and 177 NOMBO objects.
@@ -128,6 +128,7 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
             download,
             verbose,
         )
+        self._bboxes_per_size = True
 
     def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
         filepaths: list[str] = []
@@ -160,15 +161,17 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
 
     def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
         file_data = {"year": [], "image_id": [], "data_path": [], "label_path": []}
-        data_folder = self.path / self._resource.filename[:-4]
+        data_folder = sorted((self.path / self._resource.filename[:-4]).glob("*.jpg"))
+        if not data_folder:
+            raise FileNotFoundError
+
+        for entry in data_folder:
+            # Remove file extension and split by "_"
+            parts = entry.stem.split("_")
+            file_data["image_id"].append(parts[0])
+            file_data["year"].append(parts[1])
+            file_data["data_path"].append(str(entry))
+            file_data["label_path"].append(str(entry.parent / entry.stem) + ".txt")
         data = file_data.pop("data_path")
         annotations = file_data.pop("label_path")
 
@@ -180,8 +183,15 @@ class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):
         boxes: list[list[float]] = []
         with open(annotation) as f:
             for line in f.readlines():
-                out = line.strip().split(
+                out = line.strip().split()
                 labels.append(int(out[0]))
+
+                xcenter, ycenter, width, height = [float(out[1]), float(out[2]), float(out[3]), float(out[4])]
+
+                x0 = xcenter - width / 2
+                x1 = x0 + width
+                y0 = ycenter - height / 2
+                y1 = y0 + height
+                boxes.append([x0, y0, x1, y1])
 
         return boxes, labels, {}
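The new _read_annotations body parses MILCO's YOLO-style label lines, "class x_center y_center width height", and converts the center/size box to corner format; the _bboxes_per_size flag set in __init__ then rescales those values by the image shape in BaseODDataset.__getitem__. A small sketch of the conversion arithmetic:

    def yolo_to_corners(xc: float, yc: float, w: float, h: float) -> list[float]:
        # Center/size box to (x0, y0, x1, y1) in the same (normalized) coordinates.
        x0 = xc - w / 2
        y0 = yc - h / 2
        return [x0, y0, x0 + w, y0 + h]

    print(yolo_to_corners(0.5, 0.5, 0.2, 0.4))  # [0.4, 0.3, 0.6, 0.7]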
dataeval/utils/datasets/_mixin.py
CHANGED
@@ -34,8 +34,7 @@ class BaseDatasetNumpyMixin(BaseDatasetMixin[NDArray[Any]]):
         return encoded
 
     def _read_file(self, path: str) -> NDArray[Any]:
-        return x
+        return np.array(Image.open(path)).transpose(2, 0, 1)
 
 
 class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
@@ -52,5 +51,4 @@ class BaseDatasetTorchMixin(BaseDatasetMixin[torch.Tensor]):
         return encoded
 
     def _read_file(self, path: str) -> torch.Tensor:
-        return x
+        return torch.as_tensor(np.array(Image.open(path)).transpose(2, 0, 1))
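Both mixins now read an image with PIL and transpose it from PIL's height x width x channel layout to the channels-first layout the rest of DataEval expects. A short standalone sketch (the file name is hypothetical):

    import numpy as np
    from PIL import Image

    img_hwc = np.array(Image.open("example.jpg"))  # (H, W, C) as loaded by PIL
    img_chw = img_hwc.transpose(2, 0, 1)           # (C, H, W) as returned by _read_file
    print(img_hwc.shape, "->", img_chw.shape)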
dataeval/utils/datasets/_mnist.py
CHANGED
@@ -48,7 +48,7 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     Parameters
     ----------
     root : str or pathlib.Path
-        Root directory
+        Root directory where the data should be downloaded to or the ``minst`` folder of the already downloaded data.
     image_set : "train", "test" or "base", default "train"
         If "base", returns all of the data to allow the user to create their own splits.
     corruption : "identity", "shot_noise", "impulse_noise", "glass_blur", "motion_blur", \
@@ -154,7 +154,7 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
     def _load_corruption(self) -> tuple[NDArray[Any], NDArray[np.uintp]]:
         """Function to load in the file paths for the data and labels for the different corrupt data formats"""
         corruption = self.corruption if self.corruption is not None else "identity"
-        base_path = self.path / corruption
+        base_path = self.path / "mnist_c" / corruption
         if self.image_set == "base":
             raw_data = []
             raw_labels = []
@@ -191,8 +191,7 @@ class MNIST(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
 
     def _grab_corruption_data(self, path: Path) -> NDArray[Any]:
         """Function to load in the data numpy array for the previously chosen corrupt format"""
-        return x
+        return np.load(path, allow_pickle=False)
 
     def _read_file(self, path: str) -> NDArray[Any]:
         """