dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,103 @@
|
|
1
|
+
"""
|
2
|
+
Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
from typing import Any, Iterable, Sequence, TypeVar
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
import torch
|
11
|
+
from numpy.typing import NDArray
|
12
|
+
|
13
|
+
from dataeval.typing import ArrayLike
|
14
|
+
from dataeval.utils._array import as_numpy
|
15
|
+
|
16
|
+
# Generic element types for the collate functions below: the per-datum input,
# target and metadata types respectively.
T_in = TypeVar("T_in")
T_tgt = TypeVar("T_tgt")
T_md = TypeVar("T_md")
|
19
|
+
|
20
|
+
|
21
|
+
def list_collate_fn(
    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate a batch of ``(input, target, metadata)`` datums into three lists.

    Nothing is stacked or converted, which makes this suitable as the ``collate_fn``
    of a ``torch.utils.data.DataLoader`` when the targets and metadata are not tensors.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[T_in, T_tgt, T_md]]
        An iterable of ``(input, target, metadata)`` tuples.

    Returns
    -------
    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
        The input batch, the target batch and the metadata batch, each a list in the
        order the datums were received.
    """
    # Unpacking each datum up front enforces the three-element structure.
    triples = [(datum_in, datum_tgt, datum_md) for datum_in, datum_tgt, datum_md in batch_data_as_singles]
    return (
        [triple[0] for triple in triples],
        [triple[1] for triple in triples],
        [triple[2] for triple in triples],
    )
|
48
|
+
|
49
|
+
|
50
|
+
def numpy_collate_fn(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` datums into a stacked NumPy array plus two lists.

    Each input is converted with ``as_numpy`` and the results are stacked along a new
    leading batch axis, so all inputs must share the same shape. Targets and metadata
    are gathered into lists unchanged.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
        An iterable of ``(input, target, metadata)`` tuples.

    Returns
    -------
    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
        The stacked input batch (an empty 1-D array when the batch is empty), the
        target batch and the metadata batch.
    """
    arrays: list[NDArray[Any]] = []
    targets: list[T_tgt] = []
    metadatas: list[T_md] = []
    for raw_input, target, metadata in batch_data_as_singles:
        arrays.append(as_numpy(raw_input))
        targets.append(target)
        metadatas.append(metadata)

    if not arrays:
        # np.stack needs at least one array, so an empty batch gets an empty array.
        return np.array([]), targets, metadatas
    return np.stack(arrays), targets, metadatas
|
76
|
+
|
77
|
+
|
78
|
+
def torch_collate_fn(
    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
    """
    Collate ``(input, target, metadata)`` datums into a stacked tensor plus two lists.

    Each input is converted with ``torch.as_tensor`` and the results are stacked along
    a new leading batch dimension, so all inputs must share the same shape. Targets and
    metadata are gathered into lists unchanged.

    Parameters
    ----------
    batch_data_as_singles : Iterable[tuple[ArrayLike, T_tgt, T_md]]
        An iterable of ``(input, target, metadata)`` tuples.

    Returns
    -------
    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
        The stacked input batch (an empty 1-D tensor when the batch is empty), the
        target batch and the metadata batch.
    """
    tensors: list[torch.Tensor] = []
    targets: list[T_tgt] = []
    metadatas: list[T_md] = []
    for raw_input, target, metadata in batch_data_as_singles:
        tensors.append(torch.as_tensor(raw_input))
        targets.append(target)
        metadatas.append(metadata)

    if not tensors:
        # torch.stack rejects an empty list, so an empty batch gets an empty tensor.
        return torch.tensor([]), targets, metadatas
    return torch.stack(tensors), targets, metadatas
|
@@ -0,0 +1,17 @@
|
|
1
|
+
"""Provides access to common Computer Vision datasets."""
|
2
|
+
|
3
|
+
from dataeval.utils.data.datasets._cifar10 import CIFAR10
|
4
|
+
from dataeval.utils.data.datasets._milco import MILCO
|
5
|
+
from dataeval.utils.data.datasets._mnist import MNIST
|
6
|
+
from dataeval.utils.data.datasets._ships import Ships
|
7
|
+
from dataeval.utils.data.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
|
8
|
+
|
9
|
+
# Public dataset classes re-exported from the private modules imported above.
__all__ = [
    "MNIST",
    "Ships",
    "CIFAR10",
    "MILCO",
    "VOCDetection",
    "VOCDetectionTorch",
    "VOCSegmentation",
]
|
@@ -0,0 +1,254 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []  # private module: ``from ... import *`` exports nothing
|
4
|
+
|
5
|
+
from abc import abstractmethod
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
|
8
|
+
|
9
|
+
from dataeval.utils.data.datasets._fileio import _ensure_exists
|
10
|
+
from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
|
11
|
+
from dataeval.utils.data.datasets._types import (
|
12
|
+
AnnotatedDataset,
|
13
|
+
DatasetMetadata,
|
14
|
+
ImageClassificationDataset,
|
15
|
+
ObjectDetectionDataset,
|
16
|
+
ObjectDetectionTarget,
|
17
|
+
SegmentationDataset,
|
18
|
+
SegmentationTarget,
|
19
|
+
Transform,
|
20
|
+
)
|
21
|
+
|
22
|
+
# Array type of the images a dataset returns.
_TArray = TypeVar("_TArray")
# Type of the per-datum target returned by ``__getitem__``.
_TTarget = TypeVar("_TTarget")
# Type of the raw targets stored in ``self._targets``: class indices (list[int]) for
# classification, or per-datum annotation references (list[str]) that are parsed by
# ``_read_annotations`` when an item is requested.
_TRawTarget = TypeVar("_TRawTarget", list[int], list[str])
|
25
|
+
|
26
|
+
|
27
|
+
class DataLocation(NamedTuple):
    """
    Download location and integrity information for one dataset resource.

    Field order matters: instances are unpacked positionally (``*resource``) into the
    download/validation helper as ``(url, filename, md5, checksum)``.
    """

    url: str  # URL to download the resource from
    filename: str  # name of the file once downloaded
    md5: bool  # True if ``checksum`` is an MD5 digest, False for SHA-256
    checksum: str  # expected hex digest of the downloaded file
|
32
|
+
|
33
|
+
|
34
|
+
class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget]):
    """
    Base class for internet downloaded datasets.

    Subclasses set ``_resources`` (optionally ``_resource_index``) and ``index2label``
    and implement ``_load_data_inner``. That method returns the file paths, raw targets
    and per-datum metadata for the selected ``image_set``. If it raises
    ``FileNotFoundError``, the selected resource is fetched, validated and extracted,
    and loading is retried once.
    """

    # Each subclass should override the attributes below.
    # Every DataLocation in ``_resources`` carries:
    #   url      - the URL to download from
    #   filename - the name of the file once downloaded
    #   md5      - True if the checksum value is an MD5 digest
    #   checksum - the associated checksum for the downloaded file
    _resources: list[DataLocation]
    _resource_index: int = 0  # which entry of ``_resources`` this dataset uses
    index2label: dict[int, str]

    def __init__(
        self,
        root: str | Path,
        download: bool = False,
        image_set: Literal["train", "val", "test", "base"] = "train",
        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
        verbose: bool = False,
    ) -> None:
        root_path = root if isinstance(root, Path) else Path(root)
        self._root: Path = root_path.absolute()
        if transforms is None:
            transforms = []
        # Normalize a single transform into a one-element list.
        self.transforms: Sequence[Transform[_TArray]] = (
            transforms if isinstance(transforms, Sequence) else [transforms]
        )
        self.image_set = image_set
        self._verbose = verbose

        # Internal attributes (the data-related ones are filled in by ``_load_data``)
        self._download = download
        self._filepaths: list[str]
        self._targets: _TRawTarget
        self._datum_metadata: dict[str, list[Any]]
        self._resource: DataLocation = self._resources[self._resource_index]
        self._label2index = dict(zip(self.index2label.values(), self.index2label.keys()))

        self.metadata: DatasetMetadata = DatasetMetadata(
            id=self._unique_id(),
            index2label=self.index2label,
            split=self.image_set,
        )

        # Resolve (and create if needed) the dataset folder, then load the data.
        self.path: Path = self._get_dataset_dir()
        self._filepaths, self._targets, self._datum_metadata = self._load_data()
        self.size: int = len(self._filepaths)

    def __str__(self) -> str:
        indent = "\n "
        header = f"{self.__class__.__name__} Dataset"
        # Only public attributes are listed, in assignment order.
        public = [f"{key.capitalize()}: {value}" for key, value in self.__dict__.items() if not key.startswith("_")]
        return header + "\n" + "-" * len(header) + indent + indent.join(public)

    @property
    def label2index(self) -> dict[str, int]:
        """Mapping from class name to class index (the inverse of ``index2label``)."""
        return self._label2index

    def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, dict[str, Any]]]:
        yield from (self[position] for position in range(len(self)))

    def _get_dataset_dir(self) -> Path:
        """Return the folder for this dataset, named after the class, creating it if needed."""
        class_name = self.__class__.__name__
        # Reuse ``root`` directly when it already is the class-named folder.
        if self._root.stem in [class_name.lower(), class_name.upper(), class_name]:
            dataset_dir: Path = self._root
        else:
            dataset_dir: Path = self._root / class_name.lower()
        if not dataset_dir.exists():
            dataset_dir.mkdir(parents=True, exist_ok=True)
        return dataset_dir

    def _unique_id(self) -> str:
        """Identifier of the form ``<ClassName>_<image_set>``."""
        return f"{self.__class__.__name__}_{self.image_set}"

    def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
        """
        Load the data, downloading and/or extracting the resource first when it is missing.
        """
        if self._verbose:
            print(f"Determining if {self._resource.filename} needs to be downloaded.")

        try:
            loaded = self._load_data_inner()
        except FileNotFoundError:
            _ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
            return self._load_data_inner()
        else:
            if self._verbose:
                print("No download needed, loaded data successfully.")
            return loaded

    @abstractmethod
    def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...

    def _transform(self, image: _TArray) -> _TArray:
        """Apply each entry of ``self.transforms`` to the image, in order."""
        transformed = image
        for step in self.transforms:
            transformed = step(transformed)
        return transformed

    def __len__(self) -> int:
        return self.size
|
142
|
+
|
143
|
+
|
144
|
+
class BaseICDataset(
    BaseDataset[_TArray, _TArray, list[int]],
    BaseDatasetMixin[_TArray],
    ImageClassificationDataset[_TArray],
):
    """
    Base class for image classification datasets.
    """

    def __getitem__(self, index: int) -> tuple[_TArray, _TArray, dict[str, Any]]:
        """
        Args
        ----
        index : int
            Position of the requested datum.

        Returns
        -------
        tuple[TArray, TArray, dict[str, Any]]
            Transformed image, one-hot encoded class target, and that datum's metadata.
        """
        one_hot = self._one_hot_encode(self._targets[index])
        image = self._transform(self._read_file(self._filepaths[index]))
        datum_metadata = {name: values[index] for name, values in self._datum_metadata.items()}
        return image, one_hot, datum_metadata
|
175
|
+
|
176
|
+
|
177
|
+
class BaseODDataset(
    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], list[str]],
    BaseDatasetMixin[_TArray],
    ObjectDetectionDataset[_TArray],
):
    """
    Base class for object detection datasets.
    """

    def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
        """
        Args
        ----
        index : int
            Position of the requested datum.

        Returns
        -------
        tuple[TArray, ObjectDetectionTarget[TArray], dict[str, Any]]
            Transformed image, detection target (boxes, labels, one-hot scores) and the
            datum metadata merged with any metadata parsed from the annotation.
            target.boxes returns boxes in x0, y0, x1, y1 format.
        """
        # Parse this datum's annotation into boxes, labels and extra metadata.
        boxes, labels, extra_metadata = self._read_annotations(self._targets[index])
        image = self._transform(self._read_file(self._filepaths[index]))

        target = ObjectDetectionTarget(
            self._as_array(boxes),
            self._as_array(labels),
            self._one_hot_encode(labels),
        )

        datum_metadata = {name: values[index] for name, values in self._datum_metadata.items()}
        # Annotation-derived keys override dataset-level keys on collision.
        return image, target, datum_metadata | extra_metadata

    @abstractmethod
    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
213
|
+
|
214
|
+
|
215
|
+
class BaseSegDataset(
    BaseDataset[_TArray, SegmentationTarget[_TArray], list[str]],
    BaseDatasetMixin[_TArray],
    SegmentationDataset[_TArray],
):
    """
    Base class for segmentation datasets.
    """

    # Paths of the ground-truth mask files, parallel to ``_filepaths``.
    _masks: Sequence[str]

    def __getitem__(self, index: int) -> tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]:
        """
        Args
        ----
        index : int
            Position of the requested datum.

        Returns
        -------
        tuple[TArray, SegmentationTarget[TArray], dict[str, Any]]
            Transformed image, segmentation target (ground-truth mask, labels, one-hot
            scores) and the datum metadata merged with annotation metadata.
        """
        # Boxes are not used for segmentation; only labels and extra metadata are kept.
        _, labels, extra_metadata = self._read_annotations(self._targets[index])
        mask = self._read_file(self._masks[index])
        # NOTE(review): transforms are applied to the image only, not to the mask -
        # confirm this is intended when geometric transforms are used.
        image = self._transform(self._read_file(self._filepaths[index]))

        target = SegmentationTarget(mask, self._as_array(labels), self._one_hot_encode(labels))

        datum_metadata = {name: values[index] for name, values in self._datum_metadata.items()}
        return image, target, datum_metadata | extra_metadata

    @abstractmethod
    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
|
@@ -0,0 +1,134 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []  # private module: ``from ... import *`` exports nothing
|
4
|
+
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Literal, Sequence, TypeVar
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
from numpy.typing import NDArray
|
10
|
+
from PIL import Image
|
11
|
+
|
12
|
+
from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
|
13
|
+
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
|
14
|
+
from dataeval.utils.data.datasets._types import Transform
|
15
|
+
|
16
|
+
# The ten CIFAR-10 class names, listed in class-index order (0 = "airplane" ... 9 = "truck").
CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
# Type variable over a class given by name or index, or a list of either.
# NOTE(review): not referenced elsewhere in this module - presumably used for class selection by callers.
TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])
|
18
|
+
|
19
|
+
|
20
|
+
class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
    """
    `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset as NumPy arrays.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of dataset where the ``cifar10`` folder exists.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in root directory.
        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
    image_set : "train", "test" or "base", default "train"
        If "base", returns all of the data to allow the user to create their own splits.
    transforms : Transform | Sequence[Transform] | None, default None
        Transform(s) to apply to the data.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : DatasetMetadata
        Dataset metadata, such as `id`, the dataset class name joined with the image set
        (e.g. ``CIFAR10_train``).

    Notes
    -----
    Each datum's metadata contains ``batch_num``: 0-4 for the five training batches
    and 5 for the test batch.
    """

    _resources = [
        DataLocation(
            url="https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz",
            filename="cifar-10-binary.tar.gz",
            md5=True,
            checksum="c32a1d4ab5d03f1284b67883e8d87530",
        ),
    ]

    index2label: dict[int, str] = {
        0: "airplane",
        1: "automobile",
        2: "bird",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "frog",
        7: "horse",
        8: "ship",
        9: "truck",
    }

    def __init__(
        self,
        root: str | Path,
        download: bool = False,
        image_set: Literal["train", "test", "base"] = "train",
        transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            root,
            download,
            image_set,
            transforms,
            verbose,
        )

    def _load_data_inner(self) -> tuple[list[str], list[int], dict[str, Any]]:
        """
        Load the file paths, labels and datum metadata for the selected ``image_set``.

        All binary batch files are decoded. If ``<path>/images`` does not exist yet,
        every image (train and test) is written there as a PNG named
        ``<global index>_<class name>.png``.

        Returns
        -------
        tuple[list[str], list[int], dict[str, Any]]
            Image file paths, class indices and datum metadata (``batch_num`` per image).
        """
        file_meta: dict[str, list[int]] = {"batch_num": []}
        raw_data: list[NDArray[Any]] = []
        labels: list[int] = []
        data_folder = self.path / "cifar-10-batches-bin"
        save_folder = self.path / "images"
        image_sets: dict[str, list[str]] = {"base": [], "train": [], "test": []}

        # Process each batch file, skipping .meta and .html files. Sorting makes the
        # "base" ordering independent of the filesystem's directory order.
        for entry in sorted(data_folder.iterdir()):
            if entry.suffix != ".bin":
                continue
            batch_data, batch_labels = self._unpack_batch_files(entry)
            raw_data.append(batch_data)
            group = "test" if "test" in entry.stem else "train"
            # data_batch_1..5 -> 0..4; the test batch is numbered 5.
            batch_num = int(entry.stem.split("_")[-1]) - 1 if group == "train" else 5
            file_names = [
                str(save_folder / f"{i + 10000 * batch_num:05d}_{self.index2label[label]}.png")
                for i, label in enumerate(batch_labels)
            ]
            image_sets["base"].extend(file_names)
            image_sets[group].extend(file_names)

            if self.image_set in (group, "base"):
                labels.extend(batch_labels)
                # One entry per image of *this* batch (not the running label count).
                file_meta["batch_num"].extend([batch_num] * len(batch_labels))

        # Save the raw data into images if not already there
        if not save_folder.exists():
            save_folder.mkdir(exist_ok=True)
            # Stack the flat records into channel-first (N, 3, 32, 32) images; rows
            # follow the same file order as image_sets["base"].
            images = np.vstack(raw_data).reshape(-1, 3, 32, 32)
            for i, file in enumerate(image_sets["base"]):
                Image.fromarray(images[i].transpose(1, 2, 0).astype(np.uint8)).save(file)

        return image_sets[self.image_set], labels, file_meta

    def _unpack_batch_files(self, file_path: Path) -> tuple[NDArray[Any], list[int]]:
        """
        Decode one CIFAR-10 binary batch file.

        The file is a sequence of 3073-byte records: one label byte followed by the
        3072 pixel bytes of a 3 x 32 x 32 (channel-first) image.

        Returns
        -------
        tuple[NDArray[Any], list[int]]
            ``(N, 3072)`` uint8 pixel rows and the ``N`` class labels.
        """
        with file_path.open("rb") as f:
            buffer = np.frombuffer(f.read(), "B")
        records = buffer.reshape(-1, 3073)
        labels = records[:, 0]
        images = records[:, 1:]
        return images, labels.tolist()
|
@@ -0,0 +1,168 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []  # private module: ``from ... import *`` exports nothing
|
4
|
+
|
5
|
+
import hashlib
|
6
|
+
import shutil
|
7
|
+
import tarfile
|
8
|
+
import zipfile
|
9
|
+
from pathlib import Path
|
10
|
+
|
11
|
+
import requests
|
12
|
+
from tqdm import tqdm
|
13
|
+
|
14
|
+
# File suffixes that are treated as archives and extracted after download.
ARCHIVE_ENDINGS = [".zip", ".tar", ".tgz"]
# Compression suffixes: a file ending in one of these is treated as compressed, and the
# suffix underneath the compression ending decides whether it is an archive.
COMPRESS_ENDINGS = [".gz", ".bz2"]
|
16
|
+
|
17
|
+
|
18
|
+
def _validate_file(fpath, file_md5, md5: bool = False, chunk_size=65535) -> bool:
|
19
|
+
hasher = hashlib.md5(usedforsecurity=False) if md5 else hashlib.sha256()
|
20
|
+
with open(fpath, "rb") as fpath_file:
|
21
|
+
while chunk := fpath_file.read(chunk_size):
|
22
|
+
hasher.update(chunk)
|
23
|
+
return hasher.hexdigest() == file_md5
|
24
|
+
|
25
|
+
|
26
|
+
def _download_dataset(url: str, file_path: Path, timeout: int = 60) -> None:
    """
    Download a single resource from ``url`` and write it to ``file_path``.

    The response is streamed in 8 KB chunks while a tqdm progress bar is shown. If the
    transfer fails or is interrupted, the partially written file is removed so that a
    retry downloads again instead of failing the checksum check.

    Raises
    ------
    RuntimeError
        If the server answers with an HTTP error status.
    ValueError
        For any other ``requests`` failure (connection error, timeout, ...).
    """
    error_msg = "URL fetch failure on {}: {} -- {}"
    try:
        response = requests.get(url, stream=True, timeout=timeout)
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        raise RuntimeError(f"{error_msg.format(url, e.response.status_code, e.response.reason)}") from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f"{error_msg.format(url, 'Unknown error', str(e))}") from e

    total_size = int(response.headers.get("content-length", 0))
    block_size = 8192  # 8 KB

    try:
        # Context managers close the connection, the file and the progress bar even on error.
        with response, open(file_path, "wb") as f, tqdm(total=total_size, unit="iB", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(block_size):
                f.write(chunk)
                progress_bar.update(len(chunk))
    except BaseException:
        # Includes KeyboardInterrupt: never leave a truncated download behind.
        Path(file_path).unlink(missing_ok=True)
        raise
|
46
|
+
|
47
|
+
|
48
|
+
def _extract_zip_archive(file_path: Path, extract_to: Path) -> None:
|
49
|
+
"""Extracts the zip file to the given directory."""
|
50
|
+
try:
|
51
|
+
with zipfile.ZipFile(file_path, "r") as zip_ref:
|
52
|
+
zip_ref.extractall(extract_to)
|
53
|
+
file_path.unlink()
|
54
|
+
except zipfile.BadZipFile:
|
55
|
+
raise FileNotFoundError(f"{file_path.name} is not a valid zip file, skipping extraction.")
|
56
|
+
|
57
|
+
|
58
|
+
def _extract_tar_archive(file_path: Path, extract_to: Path) -> None:
|
59
|
+
"""Extracts a tar file (or compressed tar) to the specified directory."""
|
60
|
+
try:
|
61
|
+
with tarfile.open(file_path, "r:*") as tar_ref:
|
62
|
+
tar_ref.extractall(extract_to)
|
63
|
+
file_path.unlink()
|
64
|
+
except tarfile.TarError:
|
65
|
+
raise FileNotFoundError(f"{file_path.name} is not a valid tar file, skipping extraction.")
|
66
|
+
|
67
|
+
|
68
|
+
def _flatten_extraction(base_directory: Path, verbose: bool = False) -> None:
|
69
|
+
"""
|
70
|
+
If the extracted folder contains only directories (and no files),
|
71
|
+
move all its subfolders to the dataset_dir and remove the now-empty folder.
|
72
|
+
"""
|
73
|
+
for child in base_directory.iterdir():
|
74
|
+
if child.is_dir():
|
75
|
+
inner_list = list(child.iterdir())
|
76
|
+
if all(subchild.is_dir() for subchild in inner_list):
|
77
|
+
for subchild in child.iterdir():
|
78
|
+
if verbose:
|
79
|
+
print(f"Moving {subchild.stem} to {base_directory}")
|
80
|
+
shutil.move(subchild, base_directory)
|
81
|
+
|
82
|
+
if verbose:
|
83
|
+
print(f"Removing empty folder {child.stem}")
|
84
|
+
child.rmdir()
|
85
|
+
|
86
|
+
# Checking for additional placeholder folders
|
87
|
+
if len(inner_list) == 1:
|
88
|
+
_flatten_extraction(base_directory, verbose)
|
89
|
+
|
90
|
+
|
91
|
+
def _archive_extraction(file_ext, file_path, directory, compression: bool = False, verbose: bool = False):
    """
    Extract an archive into ``directory``, then any zip files it produced, then flatten.

    Parameters
    ----------
    file_ext : str
        Archive suffix (``".zip"``, ``".tar"``, ``".tgz"``). For compressed files, this is
        the suffix underneath the compression ending.
    file_path : Path
        Archive to extract; it is deleted after a successful extraction.
    directory : Path
        Destination folder. Zip files found directly inside it afterwards are also
        extracted into it (not in place), repeatedly, until none remain. Nested zips
        are therefore unpacked at any depth of zip-in-zip.
    compression : bool, default False
        True when the file carries a compression suffix; such files are always opened
        as (compressed) tar archives.
    verbose : bool, default False
        If True, print progress messages.
    """
    if file_ext != ".zip" or compression:
        _extract_tar_archive(file_path, directory)
    else:
        _extract_zip_archive(file_path, directory)

    # Each zip extraction deletes its archive, so looping until no top-level zip is left
    # also unpacks zips that were themselves inside zips.
    while nested := sorted(child for child in directory.iterdir() if child.suffix == ".zip"):
        for child in nested:
            if verbose:
                print(f"Extracting nested zip: {child} to {directory}")
            _extract_zip_archive(child, directory)

    # Determine if there are nested folders and remove them
    # Helps ensure that data is at most one folder below main directory
    _flatten_extraction(directory, verbose)
|
112
|
+
|
113
|
+
|
114
|
+
def _ensure_exists(
    url: str,
    filename: str,
    md5: bool,
    checksum: str,
    directory: Path,
    root: Path,
    download: bool = True,
    verbose: bool = False,
) -> None:
    """
    Make sure the resource ``filename`` is present and valid, then extract it if it is an archive.

    The file is looked for in ``directory`` and, if only present there, in ``root``. A
    missing file is downloaded from ``url`` into ``directory`` when ``download`` is True.
    The file's digest is always checked against ``checksum`` (MD5 if ``md5`` else SHA-256).
    Archives (``.zip``, ``.tar``, ``.tgz``, optionally followed by ``.gz``/``.bz2``) are
    then extracted into ``directory``.

    Raises
    ------
    FileNotFoundError
        If the file does not exist and ``download`` is False.
    Exception
        If the file's checksum does not match.
    """
    file_path = directory / str(filename)
    alternate_path = root / str(filename)
    file_ext = file_path.suffix
    compression = False
    if file_ext in COMPRESS_ENDINGS:
        # e.g. "data_1.0.tar.gz": the suffix right before the compression ending
        # (".tar") decides the archive type, regardless of other dots in the name.
        file_ext = Path(file_path.stem).suffix
        compression = True

    check_path = alternate_path if alternate_path.exists() and not file_path.exists() else file_path

    needs_download = not check_path.exists()
    if needs_download:
        if not download:
            raise FileNotFoundError(
                "Data could not be loaded with the provided root directory, "
                f"the file path to the file {filename} does not exist, "
                "and the download parameter is set to False."
            )
        if verbose:
            print(f"Downloading {filename} from {url}")
        _download_dataset(url, check_path)

    if not _validate_file(check_path, checksum, md5):
        raise Exception("File checksum mismatch. Remove current file and retry download.")
    if verbose and not needs_download:
        print(f"{filename} already exists, skipping download.")

    # If the file is a zip, tar or tgz extract it into the designated folder.
    if file_ext in ARCHIVE_ENDINGS:
        if verbose:
            print(f"Extracting {filename}...")
        _archive_extraction(file_ext, check_path, directory, compression, verbose)
|