dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/_version.py +2 -2
- dataeval/config.py +4 -19
- dataeval/data/_embeddings.py +78 -35
- dataeval/data/_images.py +41 -8
- dataeval/data/_metadata.py +348 -66
- dataeval/data/_selection.py +22 -7
- dataeval/data/_split.py +3 -2
- dataeval/data/selections/_classbalance.py +4 -3
- dataeval/data/selections/_classfilter.py +9 -8
- dataeval/data/selections/_indices.py +4 -3
- dataeval/data/selections/_prioritize.py +249 -29
- dataeval/data/selections/_reverse.py +1 -1
- dataeval/data/selections/_shuffle.py +5 -4
- dataeval/detectors/drift/_base.py +2 -1
- dataeval/detectors/drift/_mmd.py +2 -1
- dataeval/detectors/drift/_nml/_base.py +1 -1
- dataeval/detectors/drift/_nml/_chunk.py +2 -1
- dataeval/detectors/drift/_nml/_result.py +3 -2
- dataeval/detectors/drift/_nml/_thresholds.py +6 -5
- dataeval/detectors/drift/_uncertainty.py +2 -1
- dataeval/detectors/linters/duplicates.py +2 -1
- dataeval/detectors/linters/outliers.py +4 -3
- dataeval/detectors/ood/__init__.py +2 -1
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/detectors/ood/base.py +39 -1
- dataeval/detectors/ood/knn.py +95 -0
- dataeval/detectors/ood/mixin.py +2 -1
- dataeval/metadata/_utils.py +1 -1
- dataeval/metrics/bias/_balance.py +29 -22
- dataeval/metrics/bias/_diversity.py +4 -4
- dataeval/metrics/bias/_parity.py +2 -2
- dataeval/metrics/stats/_base.py +3 -29
- dataeval/metrics/stats/_boxratiostats.py +2 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -1
- dataeval/metrics/stats/_hashstats.py +21 -3
- dataeval/metrics/stats/_pixelstats.py +2 -1
- dataeval/metrics/stats/_visualstats.py +2 -1
- dataeval/outputs/_base.py +2 -3
- dataeval/outputs/_bias.py +2 -1
- dataeval/outputs/_estimators.py +1 -1
- dataeval/outputs/_linters.py +3 -3
- dataeval/outputs/_stats.py +3 -3
- dataeval/outputs/_utils.py +1 -1
- dataeval/outputs/_workflows.py +49 -31
- dataeval/typing.py +23 -9
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +3 -2
- dataeval/utils/_bin.py +9 -7
- dataeval/utils/_method.py +2 -3
- dataeval/utils/_multiprocessing.py +34 -0
- dataeval/utils/_plot.py +2 -1
- dataeval/utils/data/__init__.py +6 -5
- dataeval/utils/data/{metadata.py → _merge.py} +3 -2
- dataeval/utils/data/_validate.py +170 -0
- dataeval/utils/data/collate.py +2 -1
- dataeval/utils/torch/_internal.py +2 -1
- dataeval/utils/torch/trainer.py +1 -1
- dataeval/workflows/sufficiency.py +13 -9
- {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
- dataeval-0.88.0.dist-info/RECORD +105 -0
- dataeval/utils/data/_dataset.py +0 -246
- dataeval/utils/datasets/__init__.py +0 -21
- dataeval/utils/datasets/_antiuav.py +0 -189
- dataeval/utils/datasets/_base.py +0 -266
- dataeval/utils/datasets/_cifar10.py +0 -201
- dataeval/utils/datasets/_fileio.py +0 -142
- dataeval/utils/datasets/_milco.py +0 -197
- dataeval/utils/datasets/_mixin.py +0 -54
- dataeval/utils/datasets/_mnist.py +0 -202
- dataeval/utils/datasets/_seadrone.py +0 -512
- dataeval/utils/datasets/_ships.py +0 -144
- dataeval/utils/datasets/_types.py +0 -48
- dataeval/utils/datasets/_voc.py +0 -583
- dataeval-0.86.9.dist-info/RECORD +0 -115
- {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
- /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
@@ -1,189 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
from pathlib import Path
|
6
|
-
from typing import TYPE_CHECKING, Any, Literal, Sequence
|
7
|
-
|
8
|
-
from defusedxml.ElementTree import parse
|
9
|
-
from numpy.typing import NDArray
|
10
|
-
|
11
|
-
from dataeval.utils.datasets._base import BaseODDataset, DataLocation
|
12
|
-
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
|
13
|
-
|
14
|
-
if TYPE_CHECKING:
|
15
|
-
from dataeval.typing import Transform
|
16
|
-
|
17
|
-
|
18
|
-
class AntiUAVDetection(BaseODDataset[NDArray[Any], list[str], str], BaseDatasetNumpyMixin):
|
19
|
-
"""
|
20
|
-
A UAV detection dataset focused on detecting UAVs in natural images against large variation in backgrounds.
|
21
|
-
|
22
|
-
The dataset comes from the paper
|
23
|
-
`Vision-based Anti-UAV Detection and Tracking <https://ieeexplore.ieee.org/document/9785379>`_
|
24
|
-
by Jie Zhao et. al. (2022).
|
25
|
-
|
26
|
-
The dataset is approximately 1.3 GB and can be found `here <https://github.com/wangdongdut/DUT-Anti-UAV>`_.
|
27
|
-
Images are collected against a variety of different backgrounds with a variety in the number and type of UAV.
|
28
|
-
Ground truth labels are provided for the train, validation and test set.
|
29
|
-
There are 35 different types of drones along with a variety in lighting conditions and weather conditions.
|
30
|
-
|
31
|
-
There are 10,000 images: 5200 images in the training set, 2200 images in the validation set,
|
32
|
-
and 2600 images in the test set.
|
33
|
-
The dataset only has a single UAV class with the focus being on identifying object location in the image.
|
34
|
-
Ground-truth bounding boxes are provided in (x0, y0, x1, y1) format.
|
35
|
-
The images come in a variety of sizes from 3744 x 5616 to 160 x 240.
|
36
|
-
|
37
|
-
Parameters
|
38
|
-
----------
|
39
|
-
root : str or pathlib.Path
|
40
|
-
Root directory where the data should be downloaded to or
|
41
|
-
the ``antiuavdetection`` folder of the already downloaded data.
|
42
|
-
image_set: "train", "val", "test", or "base", default "train"
|
43
|
-
If "base", then the full dataset is selected (train, val and test).
|
44
|
-
transforms : Transform, Sequence[Transform] or None, default None
|
45
|
-
Transform(s) to apply to the data.
|
46
|
-
download : bool, default False
|
47
|
-
If True, downloads the dataset from the internet and puts it in root directory.
|
48
|
-
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
49
|
-
verbose : bool, default False
|
50
|
-
If True, outputs print statements.
|
51
|
-
|
52
|
-
Attributes
|
53
|
-
----------
|
54
|
-
path : pathlib.Path
|
55
|
-
Location of the folder containing the data.
|
56
|
-
image_set : "train", "val", "test", or "base"
|
57
|
-
The selected image set from the dataset.
|
58
|
-
index2label : dict[int, str]
|
59
|
-
Dictionary which translates from class integers to the associated class strings.
|
60
|
-
label2index : dict[str, int]
|
61
|
-
Dictionary which translates from class strings to the associated class integers.
|
62
|
-
metadata : DatasetMetadata
|
63
|
-
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
64
|
-
transforms : Sequence[Transform]
|
65
|
-
The transforms to be applied to the data.
|
66
|
-
size : int
|
67
|
-
The size of the dataset.
|
68
|
-
|
69
|
-
Note
|
70
|
-
----
|
71
|
-
Data License: `Apache 2.0 <https://www.apache.org/licenses/LICENSE-2.0.txt>`_
|
72
|
-
"""
|
73
|
-
|
74
|
-
# Need to run the sha256 on the files and then store that
|
75
|
-
_resources = [
|
76
|
-
DataLocation(
|
77
|
-
url="https://drive.usercontent.google.com/download?id=1RVsSGPUKTdmoyoPTBTWwroyulLek1eTj&export=download&authuser=0&confirm=t&uuid=6bca4f94-a242-4bc2-9663-fb03cd94ef2c&at=APcmpox0--NroQ_3bqeTFaJxP7Pw%3A1746552902927",
|
78
|
-
filename="train.zip",
|
79
|
-
md5=False,
|
80
|
-
checksum="14f927290556df60e23cedfa80dffc10dc21e4a3b6843e150cfc49644376eece",
|
81
|
-
),
|
82
|
-
DataLocation(
|
83
|
-
url="https://drive.usercontent.google.com/download?id=1333uEQfGuqTKslRkkeLSCxylh6AQ0X6n&export=download&authuser=0&confirm=t&uuid=c2ad2f01-aca8-4a85-96bb-b8ef6e40feea&at=APcmpozY-8bhk3nZSFaYbE8rq1Fi%3A1746551543297",
|
84
|
-
filename="val.zip",
|
85
|
-
md5=False,
|
86
|
-
checksum="238be0ceb3e7c5be6711ee3247e49df2750d52f91f54f5366c68bebac112ebf8",
|
87
|
-
),
|
88
|
-
DataLocation(
|
89
|
-
url="https://drive.usercontent.google.com/download?id=1L1zeW1EMDLlXHClSDcCjl3rs_A6sVai0&export=download&authuser=0&confirm=t&uuid=5a1d7650-d8cd-4461-8354-7daf7292f06c&at=APcmpozLQC1CuP-n5_UX2JnP53Zo%3A1746551676177",
|
90
|
-
filename="test.zip",
|
91
|
-
md5=False,
|
92
|
-
checksum="a671989a01cff98c684aeb084e59b86f4152c50499d86152eb970a9fc7fb1cbe",
|
93
|
-
),
|
94
|
-
]
|
95
|
-
|
96
|
-
index2label: dict[int, str] = {
|
97
|
-
0: "unknown",
|
98
|
-
1: "UAV",
|
99
|
-
}
|
100
|
-
|
101
|
-
def __init__(
|
102
|
-
self,
|
103
|
-
root: str | Path,
|
104
|
-
image_set: Literal["train", "val", "test", "base"] = "train",
|
105
|
-
transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
|
106
|
-
download: bool = False,
|
107
|
-
verbose: bool = False,
|
108
|
-
) -> None:
|
109
|
-
super().__init__(
|
110
|
-
root,
|
111
|
-
image_set,
|
112
|
-
transforms,
|
113
|
-
download,
|
114
|
-
verbose,
|
115
|
-
)
|
116
|
-
|
117
|
-
def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
|
118
|
-
filepaths: list[str] = []
|
119
|
-
targets: list[str] = []
|
120
|
-
datum_metadata: dict[str, list[Any]] = {}
|
121
|
-
|
122
|
-
# If base, load all resources
|
123
|
-
if self.image_set == "base":
|
124
|
-
metadata_list: list[dict[str, Any]] = []
|
125
|
-
|
126
|
-
for resource in self._resources:
|
127
|
-
self._resource = resource
|
128
|
-
resource_filepaths, resource_targets, resource_metadata = super()._load_data()
|
129
|
-
filepaths.extend(resource_filepaths)
|
130
|
-
targets.extend(resource_targets)
|
131
|
-
metadata_list.append(resource_metadata)
|
132
|
-
|
133
|
-
# Combine metadata
|
134
|
-
for data_dict in metadata_list:
|
135
|
-
for key, val in data_dict.items():
|
136
|
-
str_key = str(key) # Ensure key is string
|
137
|
-
if str_key not in datum_metadata:
|
138
|
-
datum_metadata[str_key] = []
|
139
|
-
datum_metadata[str_key].extend(val)
|
140
|
-
|
141
|
-
else:
|
142
|
-
# Grab only the desired data
|
143
|
-
for resource in self._resources:
|
144
|
-
if self.image_set in resource.filename:
|
145
|
-
self._resource = resource
|
146
|
-
resource_filepaths, resource_targets, resource_metadata = super()._load_data()
|
147
|
-
filepaths.extend(resource_filepaths)
|
148
|
-
targets.extend(resource_targets)
|
149
|
-
datum_metadata.update(resource_metadata)
|
150
|
-
|
151
|
-
return filepaths, targets, datum_metadata
|
152
|
-
|
153
|
-
def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
|
154
|
-
resource_name = self._resource.filename[:-4]
|
155
|
-
base_dir = self.path / resource_name
|
156
|
-
data_folder = sorted((base_dir / "img").glob("*.jpg"))
|
157
|
-
if not data_folder:
|
158
|
-
raise FileNotFoundError
|
159
|
-
|
160
|
-
file_data = {"image_id": [f"{resource_name}_{entry.name}" for entry in data_folder]}
|
161
|
-
data = [str(entry) for entry in data_folder]
|
162
|
-
annotations = sorted(str(entry) for entry in (base_dir / "xml").glob("*.xml"))
|
163
|
-
|
164
|
-
return data, annotations, file_data
|
165
|
-
|
166
|
-
def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
|
167
|
-
"""Function for extracting the info for the label and boxes"""
|
168
|
-
boxes: list[list[float]] = []
|
169
|
-
labels = []
|
170
|
-
root = parse(annotation).getroot()
|
171
|
-
if root is None:
|
172
|
-
raise ValueError(f"Unable to parse {annotation}")
|
173
|
-
additional_meta: dict[str, Any] = {
|
174
|
-
"image_width": int(root.findtext("size/width", default="-1")),
|
175
|
-
"image_height": int(root.findtext("size/height", default="-1")),
|
176
|
-
"image_depth": int(root.findtext("size/depth", default="-1")),
|
177
|
-
}
|
178
|
-
for obj in root.findall("object"):
|
179
|
-
labels.append(1 if obj.findtext("name", default="") == "UAV" else 0)
|
180
|
-
boxes.append(
|
181
|
-
[
|
182
|
-
float(obj.findtext("bndbox/xmin", default="0")),
|
183
|
-
float(obj.findtext("bndbox/ymin", default="0")),
|
184
|
-
float(obj.findtext("bndbox/xmax", default="0")),
|
185
|
-
float(obj.findtext("bndbox/ymax", default="0")),
|
186
|
-
]
|
187
|
-
)
|
188
|
-
|
189
|
-
return boxes, labels, additional_meta
|
dataeval/utils/datasets/_base.py
DELETED
@@ -1,266 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
from abc import abstractmethod
|
6
|
-
from pathlib import Path
|
7
|
-
from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar, cast
|
8
|
-
|
9
|
-
import numpy as np
|
10
|
-
|
11
|
-
from dataeval.utils.datasets._fileio import _ensure_exists
|
12
|
-
from dataeval.utils.datasets._mixin import BaseDatasetMixin
|
13
|
-
from dataeval.utils.datasets._types import (
|
14
|
-
AnnotatedDataset,
|
15
|
-
DatasetMetadata,
|
16
|
-
ImageClassificationDataset,
|
17
|
-
ObjectDetectionDataset,
|
18
|
-
ObjectDetectionTarget,
|
19
|
-
SegmentationDataset,
|
20
|
-
SegmentationTarget,
|
21
|
-
)
|
22
|
-
|
23
|
-
if TYPE_CHECKING:
|
24
|
-
from dataeval.typing import Array, Transform
|
25
|
-
|
26
|
-
_TArray = TypeVar("_TArray", bound=Array)
|
27
|
-
else:
|
28
|
-
_TArray = TypeVar("_TArray")
|
29
|
-
|
30
|
-
_TTarget = TypeVar("_TTarget")
|
31
|
-
_TRawTarget = TypeVar("_TRawTarget", Sequence[int], Sequence[str], Sequence[tuple[list[int], list[list[float]]]])
|
32
|
-
_TAnnotation = TypeVar("_TAnnotation", int, str, tuple[list[int], list[list[float]]])
|
33
|
-
|
34
|
-
|
35
|
-
class DataLocation(NamedTuple):
|
36
|
-
url: str
|
37
|
-
filename: str
|
38
|
-
md5: bool
|
39
|
-
checksum: str
|
40
|
-
|
41
|
-
|
42
|
-
class BaseDataset(
|
43
|
-
AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget, _TAnnotation]
|
44
|
-
):
|
45
|
-
"""
|
46
|
-
Base class for internet downloaded datasets.
|
47
|
-
"""
|
48
|
-
|
49
|
-
# Each subclass should override the attributes below.
|
50
|
-
# Each resource tuple must contain:
|
51
|
-
# 'url': str, the URL to download from
|
52
|
-
# 'filename': str, the name of the file once downloaded
|
53
|
-
# 'md5': boolean, True if it's the checksum value is md5
|
54
|
-
# 'checksum': str, the associated checksum for the downloaded file
|
55
|
-
_resources: list[DataLocation]
|
56
|
-
_resource_index: int = 0
|
57
|
-
index2label: dict[int, str]
|
58
|
-
|
59
|
-
def __init__(
|
60
|
-
self,
|
61
|
-
root: str | Path,
|
62
|
-
image_set: Literal["train", "val", "test", "operational", "base"] = "train",
|
63
|
-
transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
|
64
|
-
download: bool = False,
|
65
|
-
verbose: bool = False,
|
66
|
-
) -> None:
|
67
|
-
self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
|
68
|
-
transforms = transforms if transforms is not None else []
|
69
|
-
self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
|
70
|
-
self.image_set = image_set
|
71
|
-
self._verbose = verbose
|
72
|
-
|
73
|
-
# Internal Attributes
|
74
|
-
self._download = download
|
75
|
-
self._filepaths: list[str]
|
76
|
-
self._targets: _TRawTarget
|
77
|
-
self._datum_metadata: dict[str, list[Any]]
|
78
|
-
self._resource: DataLocation = self._resources[self._resource_index]
|
79
|
-
self._label2index = {v: k for k, v in self.index2label.items()}
|
80
|
-
|
81
|
-
self.metadata: DatasetMetadata = DatasetMetadata(
|
82
|
-
id=self._unique_id(),
|
83
|
-
index2label=self.index2label,
|
84
|
-
split=self.image_set,
|
85
|
-
)
|
86
|
-
|
87
|
-
# Load the data
|
88
|
-
self.path: Path = self._get_dataset_dir()
|
89
|
-
self._filepaths, self._targets, self._datum_metadata = self._load_data()
|
90
|
-
self.size: int = len(self._filepaths)
|
91
|
-
|
92
|
-
def __str__(self) -> str:
|
93
|
-
nt = "\n "
|
94
|
-
title = f"{self.__class__.__name__} Dataset"
|
95
|
-
sep = "-" * len(title)
|
96
|
-
attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
|
97
|
-
return f"{title}\n{sep}{nt}{nt.join(attrs)}"
|
98
|
-
|
99
|
-
@property
|
100
|
-
def label2index(self) -> dict[str, int]:
|
101
|
-
return self._label2index
|
102
|
-
|
103
|
-
def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, dict[str, Any]]]:
|
104
|
-
for i in range(len(self)):
|
105
|
-
yield self[i]
|
106
|
-
|
107
|
-
def _get_dataset_dir(self) -> Path:
|
108
|
-
# Create a designated folder for this dataset (named after the class)
|
109
|
-
if self._root.stem.lower() == self.__class__.__name__.lower():
|
110
|
-
dataset_dir: Path = self._root
|
111
|
-
else:
|
112
|
-
dataset_dir: Path = self._root / self.__class__.__name__.lower()
|
113
|
-
if not dataset_dir.exists():
|
114
|
-
dataset_dir.mkdir(parents=True, exist_ok=True)
|
115
|
-
return dataset_dir
|
116
|
-
|
117
|
-
def _unique_id(self) -> str:
|
118
|
-
return f"{self.__class__.__name__}_{self.image_set}"
|
119
|
-
|
120
|
-
def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
|
121
|
-
"""
|
122
|
-
Function to determine if data can be accessed or if it needs to be downloaded and/or extracted.
|
123
|
-
"""
|
124
|
-
if self._verbose:
|
125
|
-
print(f"Determining if {self._resource.filename} needs to be downloaded.")
|
126
|
-
|
127
|
-
try:
|
128
|
-
result = self._load_data_inner()
|
129
|
-
if self._verbose:
|
130
|
-
print("No download needed, loaded data successfully.")
|
131
|
-
except FileNotFoundError:
|
132
|
-
_ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
|
133
|
-
result = self._load_data_inner()
|
134
|
-
return result
|
135
|
-
|
136
|
-
@abstractmethod
|
137
|
-
def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
|
138
|
-
|
139
|
-
def _transform(self, image: _TArray) -> _TArray:
|
140
|
-
"""Function to transform the image prior to returning based on parameters passed in."""
|
141
|
-
for transform in self.transforms:
|
142
|
-
image = transform(image)
|
143
|
-
return image
|
144
|
-
|
145
|
-
def __len__(self) -> int:
|
146
|
-
return self.size
|
147
|
-
|
148
|
-
|
149
|
-
class BaseICDataset(
|
150
|
-
BaseDataset[_TArray, _TArray, list[int], int],
|
151
|
-
BaseDatasetMixin[_TArray],
|
152
|
-
ImageClassificationDataset[_TArray],
|
153
|
-
):
|
154
|
-
"""
|
155
|
-
Base class for image classification datasets.
|
156
|
-
"""
|
157
|
-
|
158
|
-
def __getitem__(self, index: int) -> tuple[_TArray, _TArray, dict[str, Any]]:
|
159
|
-
"""
|
160
|
-
Args
|
161
|
-
----
|
162
|
-
index : int
|
163
|
-
Value of the desired data point
|
164
|
-
|
165
|
-
Returns
|
166
|
-
-------
|
167
|
-
tuple[TArray, TArray, dict[str, Any]]
|
168
|
-
Image, target, datum_metadata - where target is one-hot encoding of class.
|
169
|
-
"""
|
170
|
-
# Get the associated label and score
|
171
|
-
label = self._targets[index]
|
172
|
-
score = self._one_hot_encode(label)
|
173
|
-
# Get the image
|
174
|
-
img = self._read_file(self._filepaths[index])
|
175
|
-
img = self._transform(img)
|
176
|
-
|
177
|
-
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
178
|
-
|
179
|
-
return img, score, img_metadata
|
180
|
-
|
181
|
-
|
182
|
-
class BaseODDataset(
|
183
|
-
BaseDataset[_TArray, ObjectDetectionTarget[_TArray], _TRawTarget, _TAnnotation],
|
184
|
-
BaseDatasetMixin[_TArray],
|
185
|
-
ObjectDetectionDataset[_TArray],
|
186
|
-
):
|
187
|
-
"""
|
188
|
-
Base class for object detection datasets.
|
189
|
-
"""
|
190
|
-
|
191
|
-
_bboxes_per_size: bool = False
|
192
|
-
|
193
|
-
def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
|
194
|
-
"""
|
195
|
-
Args
|
196
|
-
----
|
197
|
-
index : int
|
198
|
-
Value of the desired data point
|
199
|
-
|
200
|
-
Returns
|
201
|
-
-------
|
202
|
-
tuple[TArray, ObjectDetectionTarget[TArray], dict[str, Any]]
|
203
|
-
Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
|
204
|
-
"""
|
205
|
-
# Grab the bounding boxes and labels from the annotations
|
206
|
-
annotation = cast(_TAnnotation, self._targets[index])
|
207
|
-
boxes, labels, additional_metadata = self._read_annotations(annotation)
|
208
|
-
# Get the image
|
209
|
-
img = self._read_file(self._filepaths[index])
|
210
|
-
img_size = img.shape
|
211
|
-
img = self._transform(img)
|
212
|
-
# Adjust labels if necessary
|
213
|
-
if self._bboxes_per_size and boxes:
|
214
|
-
boxes = boxes * np.array([[img_size[1], img_size[2], img_size[1], img_size[2]]])
|
215
|
-
# Create the Object Detection Target
|
216
|
-
target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
|
217
|
-
|
218
|
-
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
219
|
-
img_metadata = img_metadata | additional_metadata
|
220
|
-
|
221
|
-
return img, target, img_metadata
|
222
|
-
|
223
|
-
@abstractmethod
|
224
|
-
def _read_annotations(self, annotation: _TAnnotation) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
225
|
-
|
226
|
-
|
227
|
-
class BaseSegDataset(
|
228
|
-
BaseDataset[_TArray, SegmentationTarget[_TArray], list[str], str],
|
229
|
-
BaseDatasetMixin[_TArray],
|
230
|
-
SegmentationDataset[_TArray],
|
231
|
-
):
|
232
|
-
"""
|
233
|
-
Base class for segmentation datasets.
|
234
|
-
"""
|
235
|
-
|
236
|
-
_masks: Sequence[str]
|
237
|
-
|
238
|
-
def __getitem__(self, index: int) -> tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]:
|
239
|
-
"""
|
240
|
-
Args
|
241
|
-
----
|
242
|
-
index : int
|
243
|
-
Value of the desired data point
|
244
|
-
|
245
|
-
Returns
|
246
|
-
-------
|
247
|
-
tuple[TArray, SegmentationTarget[TArray], dict[str, Any]]
|
248
|
-
Image, target, datum_metadata - target.mask returns the ground truth mask
|
249
|
-
"""
|
250
|
-
# Grab the labels from the annotations
|
251
|
-
_, labels, additional_metadata = self._read_annotations(self._targets[index])
|
252
|
-
# Grab the ground truth masks
|
253
|
-
mask = self._read_file(self._masks[index])
|
254
|
-
# Get the image
|
255
|
-
img = self._read_file(self._filepaths[index])
|
256
|
-
img = self._transform(img)
|
257
|
-
|
258
|
-
target = SegmentationTarget(mask, self._as_array(labels), self._one_hot_encode(labels))
|
259
|
-
|
260
|
-
img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
|
261
|
-
img_metadata = img_metadata | additional_metadata
|
262
|
-
|
263
|
-
return img, target, img_metadata
|
264
|
-
|
265
|
-
@abstractmethod
|
266
|
-
def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
@@ -1,201 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
from pathlib import Path
|
6
|
-
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar
|
7
|
-
|
8
|
-
import numpy as np
|
9
|
-
from numpy.typing import NDArray
|
10
|
-
|
11
|
-
from dataeval.utils.datasets._base import BaseICDataset, DataLocation
|
12
|
-
from dataeval.utils.datasets._mixin import BaseDatasetNumpyMixin
|
13
|
-
|
14
|
-
if TYPE_CHECKING:
|
15
|
-
from dataeval.typing import Transform
|
16
|
-
|
17
|
-
CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
|
18
|
-
TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
|
19
|
-
|
20
|
-
|
21
|
-
class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
22
|
-
"""
|
23
|
-
`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.
|
24
|
-
|
25
|
-
Parameters
|
26
|
-
----------
|
27
|
-
root : str or pathlib.Path
|
28
|
-
Root directory where the data should be downloaded to or the ``cifar10`` folder of the already downloaded data.
|
29
|
-
image_set : "train", "test" or "base", default "train"
|
30
|
-
If "base", returns all of the data to allow the user to create their own splits.
|
31
|
-
transforms : Transform, Sequence[Transform] or None, default None
|
32
|
-
Transform(s) to apply to the data.
|
33
|
-
download : bool, default False
|
34
|
-
If True, downloads the dataset from the internet and puts it in root directory.
|
35
|
-
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
36
|
-
verbose : bool, default False
|
37
|
-
If True, outputs print statements.
|
38
|
-
|
39
|
-
Attributes
|
40
|
-
----------
|
41
|
-
path : pathlib.Path
|
42
|
-
Location of the folder containing the data.
|
43
|
-
image_set : "train", "test" or "base"
|
44
|
-
The selected image set from the dataset.
|
45
|
-
transforms : Sequence[Transform]
|
46
|
-
The transforms to be applied to the data.
|
47
|
-
size : int
|
48
|
-
The size of the dataset.
|
49
|
-
index2label : dict[int, str]
|
50
|
-
Dictionary which translates from class integers to the associated class strings.
|
51
|
-
label2index : dict[str, int]
|
52
|
-
Dictionary which translates from class strings to the associated class integers.
|
53
|
-
metadata : DatasetMetadata
|
54
|
-
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
55
|
-
"""
|
56
|
-
|
57
|
-
_resources = [
|
58
|
-
DataLocation(
|
59
|
-
url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
|
60
|
-
filename="cifar-10-binary.tar.gz",
|
61
|
-
md5=True,
|
62
|
-
checksum="c32a1d4ab5d03f1284b67883e8d87530",
|
63
|
-
),
|
64
|
-
]
|
65
|
-
|
66
|
-
index2label: dict[int, str] = {
|
67
|
-
0: "airplane",
|
68
|
-
1: "automobile",
|
69
|
-
2: "bird",
|
70
|
-
3: "cat",
|
71
|
-
4: "deer",
|
72
|
-
5: "dog",
|
73
|
-
6: "frog",
|
74
|
-
7: "horse",
|
75
|
-
8: "ship",
|
76
|
-
9: "truck",
|
77
|
-
}
|
78
|
-
|
79
|
-
def __init__(
|
80
|
-
self,
|
81
|
-
root: str | Path,
|
82
|
-
image_set: Literal["train", "test", "base"] = "train",
|
83
|
-
transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
|
84
|
-
download: bool = False,
|
85
|
-
verbose: bool = False,
|
86
|
-
) -> None:
|
87
|
-
super().__init__(
|
88
|
-
root,
|
89
|
-
image_set,
|
90
|
-
transforms,
|
91
|
-
download,
|
92
|
-
verbose,
|
93
|
-
)
|
94
|
-
|
95
|
-
def _load_bin_data(self, data_folder: list[Path]) -> tuple[list[str], list[int], dict[str, Any]]:
|
96
|
-
batch_nums = np.zeros(60000, dtype=np.uint8)
|
97
|
-
all_labels = np.zeros(60000, dtype=np.uint8)
|
98
|
-
all_images = np.zeros((60000, 3, 32, 32), dtype=np.uint8)
|
99
|
-
# Process each batch file, skipping .meta and .html files
|
100
|
-
for batch_file in data_folder:
|
101
|
-
# Get batch parameters
|
102
|
-
batch_type = "test" if "test" in batch_file.stem else "train"
|
103
|
-
batch_num = 5 if batch_type == "test" else int(batch_file.stem.split("_")[-1]) - 1
|
104
|
-
|
105
|
-
# Load data
|
106
|
-
batch_images, batch_labels = self._unpack_batch_files(batch_file)
|
107
|
-
|
108
|
-
# Stack data
|
109
|
-
num_images = batch_images.shape[0]
|
110
|
-
batch_start = batch_num * num_images
|
111
|
-
all_images[batch_start : batch_start + num_images] = batch_images
|
112
|
-
all_labels[batch_start : batch_start + num_images] = batch_labels
|
113
|
-
batch_nums[batch_start : batch_start + num_images] = batch_num
|
114
|
-
|
115
|
-
# Save data
|
116
|
-
self._loaded_data = all_images
|
117
|
-
np.savez(self.path / "cifar10", images=self._loaded_data, labels=all_labels, batches=batch_nums)
|
118
|
-
|
119
|
-
# Select data
|
120
|
-
image_list = np.arange(all_labels.shape[0]).astype(str)
|
121
|
-
if self.image_set == "train":
|
122
|
-
return (
|
123
|
-
image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
|
124
|
-
all_labels[batch_nums != 5].tolist(),
|
125
|
-
{"batch_num": batch_nums[batch_nums != 5].tolist()},
|
126
|
-
)
|
127
|
-
if self.image_set == "test":
|
128
|
-
return (
|
129
|
-
image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
|
130
|
-
all_labels[batch_nums == 5].tolist(),
|
131
|
-
{"batch_num": batch_nums[batch_nums == 5].tolist()},
|
132
|
-
)
|
133
|
-
return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
|
134
|
-
|
135
|
-
def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
|
136
|
-
"""Function to load in the file paths for the data and labels and retrieve metadata"""
|
137
|
-
data_file = self.path / "cifar10.npz"
|
138
|
-
if not data_file.exists():
|
139
|
-
data_folder = sorted((self.path / "cifar-10-batches-bin").glob("*.bin"))
|
140
|
-
if not data_folder:
|
141
|
-
raise FileNotFoundError
|
142
|
-
return self._load_bin_data(data_folder)
|
143
|
-
|
144
|
-
# Load data
|
145
|
-
data = np.load(data_file)
|
146
|
-
self._loaded_data = data["images"]
|
147
|
-
all_labels = data["labels"]
|
148
|
-
batch_nums = data["batches"]
|
149
|
-
|
150
|
-
# Select data
|
151
|
-
image_list = np.arange(all_labels.shape[0]).astype(str)
|
152
|
-
if self.image_set == "train":
|
153
|
-
return (
|
154
|
-
image_list[np.nonzero(batch_nums != 5)[0]].tolist(),
|
155
|
-
all_labels[batch_nums != 5].tolist(),
|
156
|
-
{"batch_num": batch_nums[batch_nums != 5].tolist()},
|
157
|
-
)
|
158
|
-
if self.image_set == "test":
|
159
|
-
return (
|
160
|
-
image_list[np.nonzero(batch_nums == 5)[0]].tolist(),
|
161
|
-
all_labels[batch_nums == 5].tolist(),
|
162
|
-
{"batch_num": batch_nums[batch_nums == 5].tolist()},
|
163
|
-
)
|
164
|
-
return image_list.tolist(), all_labels.tolist(), {"batch_num": batch_nums.tolist()}
|
165
|
-
|
166
|
-
def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[np.uint8], NDArray[np.uint8]]:
|
167
|
-
# Load pickle data with latin1 encoding
|
168
|
-
with file_path.open("rb") as f:
|
169
|
-
buffer = np.frombuffer(f.read(), dtype=np.uint8)
|
170
|
-
# Each entry is 1 byte for label + 3072 bytes for image (3*32*32)
|
171
|
-
entry_size = 1 + 3072
|
172
|
-
num_entries = buffer.size // entry_size
|
173
|
-
# Extract labels (first byte of each entry)
|
174
|
-
labels = buffer[::entry_size]
|
175
|
-
|
176
|
-
# Extract image data and reshape to (N, 3, 32, 32)
|
177
|
-
images = np.zeros((num_entries, 3, 32, 32), dtype=np.uint8)
|
178
|
-
for i in range(num_entries):
|
179
|
-
# Skip the label byte and get image data for this entry
|
180
|
-
start_idx = i * entry_size + 1 # +1 to skip label
|
181
|
-
img_flat = buffer[start_idx : start_idx + 3072]
|
182
|
-
|
183
|
-
# The CIFAR format stores channels in blocks (all R, then all G, then all B)
|
184
|
-
# Each channel block is 1024 bytes (32x32)
|
185
|
-
red_channel = img_flat[0:1024].reshape(32, 32)
|
186
|
-
green_channel = img_flat[1024:2048].reshape(32, 32)
|
187
|
-
blue_channel = img_flat[2048:3072].reshape(32, 32)
|
188
|
-
|
189
|
-
# Stack the channels in the proper C×H×W format
|
190
|
-
images[i, 0] = red_channel # Red channel
|
191
|
-
images[i, 1] = green_channel # Green channel
|
192
|
-
images[i, 2] = blue_channel # Blue channel
|
193
|
-
return images, labels
|
194
|
-
|
195
|
-
def _read_file(self, path: str) -> NDArray[Any]:
|
196
|
-
"""
|
197
|
-
Function to grab the correct image from the loaded data.
|
198
|
-
Overwrite of the base `_read_file` because data is an all or nothing load.
|
199
|
-
"""
|
200
|
-
index = int(path)
|
201
|
-
return self._loaded_data[index]
|