dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,352 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Literal, Sequence, TypeVar
|
7
|
+
|
8
|
+
import torch
|
9
|
+
from defusedxml.ElementTree import parse
|
10
|
+
from numpy.typing import NDArray
|
11
|
+
|
12
|
+
from dataeval.utils.data.datasets._base import (
|
13
|
+
BaseDataset,
|
14
|
+
BaseODDataset,
|
15
|
+
BaseSegDataset,
|
16
|
+
DataLocation,
|
17
|
+
)
|
18
|
+
from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
|
19
|
+
from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget, Transform
|
20
|
+
|
21
|
+
# Generic placeholders used to parameterize BaseVOCDataset over the image array
# type and the per-datum target type.
_TArray = TypeVar("_TArray")
_TTarget = TypeVar("_TTarget")

# The twenty Pascal VOC class names, listed in the same order as the values of
# BaseVOCDataset.index2label.
VOCClassStringMap = Literal[
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]  # fmt: skip

# A class reference given as a name, an integer index, or a list of either.
TVOCClassMap = TypeVar(
    "TVOCClassMap",
    VOCClassStringMap,
    int,
    list[VOCClassStringMap],
    list[int],
)
|
47
|
+
|
48
|
+
|
49
|
+
class BaseVOCDataset(BaseDataset[_TArray, _TTarget, list[str]]):
    """
    Shared file-discovery and annotation-parsing logic for the Pascal VOC datasets.

    Parameters
    ----------
    root : str or pathlib.Path
        Forwarded to ``BaseDataset``. The data is looked up in ``<root>/VOC<year>``
        unless ``root`` is itself named ``VOC<year>`` (see ``_get_dataset_dir``).
    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
        The dataset year; selects which entry of ``_resources`` is used.
    image_set : "train", "val", "test", or "base", default "train"
        If "test", then ``year`` must be "2007".
        "base" maps to the ``trainval`` image-set file.
    download : bool, default False
        Forwarded to ``BaseDataset``.
    transforms : Transform | Sequence[Transform] | None, default None
        Forwarded to ``BaseDataset``.
    verbose : bool, default False
        Forwarded to ``BaseDataset``.

    Raises
    ------
    ValueError
        If ``image_set`` is "test" and ``year`` is not "2007".
    """

    # Ordered newest (2012) to oldest (2007) trainval archive, followed by the
    # 2007 test archive as the LAST entry; _get_year_image_set_index relies on
    # exactly this ordering.
    _resources = [
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
            filename="VOCtrainval_11-May-2012.tar",
            md5=True,
            checksum="6cd6e144f989b92b3379bac3b3de84fd",
        ),
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
            filename="VOCtrainval_25-May-2011.tar",
            md5=True,
            checksum="6c3384ef61512963050cb5d687e5bf1e",
        ),
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
            filename="VOCtrainval_03-May-2010.tar",
            md5=True,
            checksum="da459979d0c395079b5c75ee67908abb",
        ),
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
            filename="VOCtrainval_11-May-2009.tar",
            md5=True,
            # NOTE(review): this checksum is identical to the 2010 archive's above.
            # Two different archives should not share an MD5, so this looks like a
            # copy-paste slip - verify against the upstream archive before relying on it.
            checksum="da459979d0c395079b5c75ee67908abb",
        ),
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
            filename="VOCtrainval_14-Jul-2008.tar",
            md5=True,
            checksum="2629fa636546599198acfcfbfcf1904a",
        ),
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
            filename="VOCtrainval_06-Nov-2007.tar",
            md5=True,
            checksum="c52e279531787c972589f7e41ab4ae64",
        ),
        DataLocation(
            url="http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
            filename="VOCtest_06-Nov-2007.tar",
            md5=True,
            checksum="b6e924de25625d8de591ea690078ad9f",
        ),
    ]

    # Integer class index -> class name for the twenty VOC classes.
    index2label: dict[int, str] = {
        0: "aeroplane",
        1: "bicycle",
        2: "bird",
        3: "boat",
        4: "bottle",
        5: "bus",
        6: "car",
        7: "cat",
        8: "chair",
        9: "cow",
        10: "diningtable",
        11: "dog",
        12: "horse",
        13: "motorbike",
        14: "person",
        15: "pottedplant",
        16: "sheep",
        17: "sofa",
        18: "train",
        19: "tvmonitor",
    }

    def __init__(
        self,
        root: str | Path,
        year: Literal["2007", "2008", "2009", "2010", "2011", "2012"] = "2012",
        image_set: Literal["train", "val", "test", "base"] = "train",
        download: bool = False,
        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
        verbose: bool = False,
    ) -> None:
        # Both attributes are assigned before delegating to the base initializer.
        # NOTE(review): presumably the base initializer reads them while locating
        # the data - confirm against BaseDataset.
        self.year = year
        self._resource_index = self._get_year_image_set_index(year, image_set)
        super().__init__(
            root,
            download,
            image_set,
            transforms,
            verbose,
        )

    def _get_dataset_dir(self) -> Path:
        """
        Return the ``VOC<year>`` directory shared by the detection and segmentation classes.

        If the root directory is already named ``VOC<year>`` it is used directly,
        otherwise ``<root>/VOC<year>`` is used. The directory is created when missing.
        """
        year_folder = f"VOC{self.year}"
        dataset_dir: Path = self._root if self._root.stem == year_folder else self._root / year_folder
        if not dataset_dir.exists():
            dataset_dir.mkdir(parents=True, exist_ok=True)
        return dataset_dir

    def _get_year_image_set_index(self, year: str, image_set: str) -> int:
        """
        Return the index into ``_resources`` for the requested year and image set.

        Raises
        ------
        ValueError
            If the test set is requested for any year other than 2007.
        """
        if image_set == "test":
            if year != "2007":
                raise ValueError(
                    f"The only test set available is for the year 2007, not {year}. "
                    "Either select the year 2007 or use a different image_set."
                )
            # The 2007 test archive is the last resource entry.
            return -1
        # Trainval archives are ordered 2012 (index 0) down to 2007 (index 5).
        return 2012 - int(year)

    def _get_image_sets(self) -> dict[str, list[str]]:
        """
        Build the list of image file paths for each available image set.

        Reads ``ImageSets/Main/<name>.txt`` under ``self.path``. The ``trainval``
        file is exposed under the key ``"base"``. Blank lines are ignored.
        """
        image_folder = self.path / "JPEGImages"
        image_set_list = ["train", "val", "trainval"] if self.image_set != "test" else ["test"]
        image_sets: dict[str, list[str]] = {}
        for image_set in image_set_list:
            text_file = self.path / "ImageSets" / "Main" / (image_set + ".txt")
            with open(text_file) as f:
                stems = (line.strip() for line in f)
                # Skip empty lines (e.g. a trailing newline at end of file) so they
                # do not turn into a bogus "JPEGImages/.jpg" entry.
                selected_images = [str(image_folder / (stem + ".jpg")) for stem in stems if stem]

            name = "base" if image_set == "trainval" else image_set
            image_sets[name] = selected_images
        return image_sets

    def _load_data_inner(self) -> tuple[list[str], list[str], dict[str, Any]]:
        """
        Collect image paths, annotation paths and per-image metadata for the selected image set.

        Returns
        -------
        tuple[list[str], list[str], dict[str, Any]]
            Image file paths, matching ``Annotations/<stem>.xml`` paths, and a
            metadata dict with parallel ``year``, ``image_id`` and ``mask_path`` lists.
        """
        file_meta: dict[str, list[str]] = {"year": [], "image_id": [], "mask_path": []}
        ann_folder = self.path / "Annotations"
        seg_folder = self.path / "SegmentationClass"

        # Load in the image sets
        image_sets = self._get_image_sets()

        # Get the data, annotations and metadata
        annotations: list[str] = []
        data = image_sets[self.image_set]
        for entry in data:
            file_stem = Path(entry).stem
            parts = file_stem.split("_")
            if len(parts) > 1:
                # 2008 and later: stems look like "<year>_<id>".
                image_year, image_id = parts[0], parts[1]
            else:
                # VOC2007 stems are bare ids (e.g. "000005") with no year prefix;
                # indexing parts[1] here used to raise IndexError.
                image_year, image_id = self.year, file_stem
            file_meta["year"].append(image_year)
            file_meta["image_id"].append(image_id)
            # Segmentation masks are stored as PNG, not under the JPEG image name.
            # NOTE(review): not every image listed in ImageSets/Main necessarily
            # has a mask file - confirm how the consumer handles missing paths.
            file_meta["mask_path"].append(str(seg_folder / (file_stem + ".png")))
            annotations.append(str(ann_folder / file_stem) + ".xml")

        return data, annotations, file_meta

    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]:
        """
        Parse one VOC annotation XML file.

        Parameters
        ----------
        annotation : str
            Path to the annotation XML file.

        Returns
        -------
        tuple[list[list[float]], list[int], dict[str, Any]]
            ``[xmin, ymin, xmax, ymax]`` boxes, integer class labels, and a metadata
            dict in which every value is a list with one entry per object
            (image-level fields are repeated for each object).

        Raises
        ------
        KeyError
            If an object name is not a key of ``self.label2index``.
        """
        root = parse(annotation).getroot()
        objects = root.findall("object")
        num_objects = len(objects)

        # Image-level fields are looked up once and then repeated per object,
        # instead of re-running every XPath query for each object.
        image_level: dict[str, Any] = {
            "folder": root.findtext("folder", default=""),
            "filename": root.findtext("filename", default=""),
            "database": root.findtext("source/database", default=""),
            "annotation_source": root.findtext("source/annotation", default=""),
            "image_source": root.findtext("source/image", default=""),
            "image_width": int(root.findtext("size/width", default="-1")),
            "image_height": int(root.findtext("size/height", default="-1")),
            "image_depth": int(root.findtext("size/depth", default="-1")),
            "segmented": int(root.findtext("segmented", default="-1")),
        }
        additional_meta: dict[str, Any] = {key: [value] * num_objects for key, value in image_level.items()}
        additional_meta.update({"pose": [], "truncated": [], "difficult": []})

        boxes: list[list[float]] = []
        label_str: list[str] = []
        for obj in objects:
            label_str.append(obj.findtext("name", default=""))
            additional_meta["pose"].append(obj.findtext("pose", default=""))
            additional_meta["truncated"].append(int(obj.findtext("truncated", default="-1")))
            additional_meta["difficult"].append(int(obj.findtext("difficult", default="-1")))
            boxes.append(
                [
                    float(obj.findtext("bndbox/xmin", default="0")),
                    float(obj.findtext("bndbox/ymin", default="0")),
                    float(obj.findtext("bndbox/xmax", default="0")),
                    float(obj.findtext("bndbox/ymax", default="0")),
                ]
            )
        labels = [self.label2index[label] for label in label_str]
        return boxes, labels, additional_meta
|
234
|
+
|
235
|
+
|
236
|
+
class VOCDetection(
    BaseVOCDataset[NDArray[Any], ObjectDetectionTarget[NDArray[Any]]],
    BaseODDataset[NDArray[Any]],
    BaseDatasetNumpyMixin,
):
    """
    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Images and detection targets are typed as NumPy arrays (``NDArray``).

    Parameters are listed in the order of the constructor signature.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of the dataset: either the ``VOC<year>`` folder itself
        or the directory that contains it.
    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
        The dataset year.
    image_set : "train", "val", "test", or "base", default "train"
        If "test", then dataset year must be "2007".
        If "base", then the combined dataset of "train" and "val" is returned.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in root directory.
        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
    transforms : Transform | Sequence[Transform] | None, default None
        Transform(s) to apply to the data.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : dict
        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
    """
|
272
|
+
|
273
|
+
|
274
|
+
class VOCDetectionTorch(
    BaseVOCDataset[torch.Tensor, ObjectDetectionTarget[torch.Tensor]],
    BaseODDataset[torch.Tensor],
    BaseDatasetTorchMixin,
):
    """
    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.

    Images and detection targets are typed as ``torch.Tensor``.

    Parameters are listed in the order of the constructor signature.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of the dataset: either the ``VOC<year>`` folder itself
        or the directory that contains it.
    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
        The dataset year.
    image_set : "train", "val", "test", or "base", default "train"
        If "test", then dataset year must be "2007".
        If "base", then the combined dataset of "train" and "val" is returned.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in root directory.
        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
    transforms : Transform | Sequence[Transform] | None, default None
        Transform(s) to apply to the data.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : dict
        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
    """
|
310
|
+
|
311
|
+
|
312
|
+
class VOCSegmentation(
    BaseVOCDataset[NDArray[Any], SegmentationTarget[NDArray[Any]]],
    BaseSegDataset[NDArray[Any]],
    BaseDatasetNumpyMixin,
):
    """
    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.

    Images and segmentation targets are typed as NumPy arrays (``NDArray``).

    Parameters are listed in the order of the constructor signature.

    Parameters
    ----------
    root : str or pathlib.Path
        Root directory of the dataset: either the ``VOC<year>`` folder itself
        or the directory that contains it.
    year : "2007", "2008", "2009", "2010", "2011" or "2012", default "2012"
        The dataset year.
    image_set : "train", "val", "test", or "base", default "train"
        If "test", then dataset year must be "2007".
        If "base", then the combined dataset of "train" and "val" is returned.
    download : bool, default False
        If True, downloads the dataset from the internet and puts it in root directory.
        Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
    transforms : Transform | Sequence[Transform] | None, default None
        Transform(s) to apply to the data.
    verbose : bool, default False
        If True, outputs print statements.

    Attributes
    ----------
    index2label : dict
        Dictionary which translates from class integers to the associated class strings.
    label2index : dict
        Dictionary which translates from class strings to the associated class integers.
    path : Path
        Location of the folder containing the data.
    metadata : dict
        Dictionary containing Dataset metadata, such as `id` which returns the dataset class name.
    """

    def _load_data(self) -> tuple[list[str], list[str], dict[str, list[Any]]]:
        """
        Load the data via the parent class and move the mask paths onto the instance.

        The ``"mask_path"`` entry is removed from the returned per-datum metadata
        and stored as ``self._masks``.
        """
        image_paths, annotation_paths, per_datum_meta = super()._load_data()
        # Mask paths live on the instance rather than in the datum metadata.
        self._masks = per_datum_meta.pop("mask_path")
        return image_paths, annotation_paths, per_datum_meta
|
@@ -0,0 +1,15 @@
|
|
1
|
+
"""Provides selection classes for selecting subsets of Computer Vision datasets."""
|
2
|
+
|
3
|
+
__all__ = [
|
4
|
+
"ClassFilter",
|
5
|
+
"Indices",
|
6
|
+
"Limit",
|
7
|
+
"Reverse",
|
8
|
+
"Shuffle",
|
9
|
+
]
|
10
|
+
|
11
|
+
from dataeval.utils.data.selections._classfilter import ClassFilter
|
12
|
+
from dataeval.utils.data.selections._indices import Indices
|
13
|
+
from dataeval.utils.data.selections._limit import Limit
|
14
|
+
from dataeval.utils.data.selections._reverse import Reverse
|
15
|
+
from dataeval.utils.data.selections._shuffle import Shuffle
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Sequence
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from dataeval.typing import Array, ImageClassificationDatum
|
10
|
+
from dataeval.utils._array import as_numpy
|
11
|
+
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
12
|
+
|
13
|
+
|
14
|
+
class ClassFilter(Selection[ImageClassificationDatum]):
    """
    Filter and balance the dataset by class.

    Parameters
    ----------
    classes : Sequence[int] or None, default None
        The classes to filter by. If None, all classes are included.
    balance : bool, default False
        Whether to balance the classes.

    Note
    ----
    If `balance` is True, the total number of instances of each class will
    be equalized. This may result in a lower total number of instances.
    A requested class with no instances therefore balances every class down
    to zero.
    """

    # Stage tag consumed by the selection machinery.
    # NOTE(review): stage ordering is defined in dataeval.utils.data._selection.
    stage = SelectionStage.FILTER

    def __init__(self, classes: Sequence[int] | None = None, balance: bool = False) -> None:
        self.classes = classes
        self.balance = balance

    def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
        """
        Replace ``dataset._selection`` with the class-filtered (and optionally balanced) subset.

        The label of each datum is the argmax of its target array. The relative
        order of the surviving entries is preserved.

        Raises
        ------
        TypeError
            If a datum's target is not an ``Array``.
        """
        if self.classes is None and not self.balance:
            return

        # Build the lookup set once; None means "accept every label"
        # (classes is None or empty).
        wanted = set(self.classes) if self.classes else None

        # When balancing over an explicit class list, scanning can stop as soon
        # as every class has reached its share of the size limit.
        quota = dataset._size_limit // len(self.classes) if self.classes and self.balance else 0
        class_indices: dict[int, list[int]] = {} if self.classes is None else {k: [] for k in self.classes}
        for position, idx in enumerate(dataset._selection):
            target = dataset._dataset[idx][1]
            if not isinstance(target, Array):
                # ObjectDetectionTarget and SegmentationTarget not supported yet
                raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
            label = int(np.argmax(as_numpy(target)))
            if wanted is None or label in wanted:
                class_indices.setdefault(label, []).append(position)
            if quota and all(len(found) >= quota for found in class_indices.values()):
                break

        if self.balance:
            # default=0 keeps an empty selection from raising ValueError in min().
            per_class_limit = min((len(found) for found in class_indices.values()), default=0)
        else:
            per_class_limit = dataset._size_limit
        subselection = sorted(i for found in class_indices.values() for i in found[:per_class_limit])
        dataset._selection = [dataset._selection[i] for i in subselection]
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any, Sequence
|
6
|
+
|
7
|
+
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
8
|
+
|
9
|
+
|
10
|
+
class Indices(Selection[Any]):
    """
    Selects specific indices from the dataset.

    Parameters
    ----------
    indices : Sequence[int]
        The indices to select from the dataset.

    Note
    ----
    The resulting selection follows the order of `indices` (duplicates
    included); indices that are not part of the current selection are dropped.
    """

    # Stage tag consumed by the selection machinery.
    # NOTE(review): stage ordering is defined in dataeval.utils.data._selection.
    stage = SelectionStage.FILTER

    def __init__(self, indices: Sequence[int]) -> None:
        self.indices = indices

    def __call__(self, dataset: Select[Any]) -> None:
        """Replace ``dataset._selection`` with the requested indices it currently contains."""
        # Build the membership set once instead of scanning the selection list
        # for every requested index.
        available = set(dataset._selection)
        dataset._selection = [index for index in self.indices if index in available]
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
8
|
+
|
9
|
+
|
10
|
+
class Limit(Selection[Any]):
    """
    Limit the size of the dataset.

    Parameters
    ----------
    size : int
        The maximum size of the dataset.

    Note
    ----
    Applying this selection only records `size` on the target as
    ``_size_limit``; this class does not truncate the selection itself.
    """

    # Stage tag consumed by the selection machinery.
    # NOTE(review): stage ordering is defined in dataeval.utils.data._selection.
    stage = SelectionStage.STATE

    def __init__(self, size: int) -> None:
        # Stored as given; no validation is performed here.
        self.size = size

    def __call__(self, dataset: Select[Any]) -> None:
        """Store the configured size on ``dataset`` as ``_size_limit``."""
        dataset._size_limit = self.size
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
8
|
+
|
9
|
+
|
10
|
+
class Reverse(Selection[Any]):
    """
    Reverse the selection order of the dataset.

    The reversal is applied in place to the target's ``_selection``.
    """

    # Stage tag consumed by the selection machinery.
    # NOTE(review): stage ordering is defined in dataeval.utils.data._selection.
    stage = SelectionStage.ORDER

    def __call__(self, dataset: Select[Any]) -> None:
        """Reverse ``dataset._selection`` in place."""
        dataset._selection.reverse()
|
@@ -0,0 +1,29 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
10
|
+
|
11
|
+
|
12
|
+
class Shuffle(Selection[Any]):
    """
    Shuffle the dataset using a seed.

    Parameters
    ----------
    seed : int
        Seed for the random number generator.

    Note
    ----
    A new generator is created from `seed` on every call, and the target's
    ``_selection`` is shuffled in place.
    """

    # Stage tag consumed by the selection machinery.
    # NOTE(review): stage ordering is defined in dataeval.utils.data._selection.
    stage = SelectionStage.ORDER

    def __init__(self, seed: int) -> None:
        self.seed = seed

    def __call__(self, dataset: Select[Any]) -> None:
        """Shuffle ``dataset._selection`` in place with a generator seeded by ``self.seed``."""
        np.random.default_rng(self.seed).shuffle(dataset._selection)
|