dataeval 0.76.0__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +52 -43
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -63
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +25 -25
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +198 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/METADATA +44 -15
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.0.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.0.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/utils/{dataset/split.py → data/_split.py}

@@ -4,7 +4,7 @@ __all__ = []
 
 import warnings
 from dataclasses import dataclass
-from typing import Any, Iterator,
+from typing import Any, Iterator, Protocol
 
 import numpy as np
 from numpy.typing import NDArray
@@ -13,10 +13,11 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
-from dataeval.
+from dataeval._output import Output, set_metadata
 
 
-
+@dataclass
+class TrainValSplit:
     """Tuple containing train and validation indices"""
 
     train: NDArray[np.intp]
@@ -274,8 +275,7 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
     for name, feature in features2group.items():
         if len(feature) != num_samples:
             raise ValueError(
-                f"Feature length does not match number of labels. "
-                f"Got {len(feature)} features and {num_samples} samples"
+                f"Feature length does not match number of labels. Got {len(feature)} features and {num_samples} samples"
             )
 
         if type_of_target(feature) == "continuous":
@@ -505,23 +505,22 @@ def split_dataset(
     if is_groupable(possible_groups, group_partitions):
         groups = possible_groups
 
-    test_indices: NDArray[np.intp]
     index = np.arange(label_length)
 
-
+    tvs = (
         single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
         if test_frac
-        else (index, np.array([], dtype=np.intp))
+        else TrainValSplit(index, np.array([], dtype=np.intp))
     )
 
-    tv_labels = labels[
-    tv_groups = groups[
+    tv_labels = labels[tvs.train]
+    tv_groups = groups[tvs.train] if groups is not None else None
 
     if num_folds == 1:
-        tv_splits = [single_split(
+        tv_splits = [single_split(tvs.train, tv_labels, val_frac, tv_groups, stratify)]
     else:
-        tv_splits = make_splits(
+        tv_splits = make_splits(tvs.train, tv_labels, num_folds, tv_groups, stratify)
 
-    folds: list[TrainValSplit] = [TrainValSplit(
+    folds: list[TrainValSplit] = [TrainValSplit(tvs.train[split.train], tvs.train[split.val]) for split in tv_splits]
 
-    return SplitDatasetOutput(
+    return SplitDatasetOutput(tvs.val, folds)
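The hunks above replace `split_dataset`'s old tuple plumbing with the new `TrainValSplit` dataclass. Below is a minimal standalone sketch of that pattern: the dataclass body is copied from the diff (the `val` field is inferred from its use as `tvs.val` and `split.val`), while the split sizes are illustrative placeholders.

```python
# Standalone sketch of the TrainValSplit pattern from the diff above.
# The 80/20 split and fold sizes are illustrative placeholders.
from dataclasses import dataclass

import numpy as np
from numpy.typing import NDArray


@dataclass
class TrainValSplit:
    """Tuple containing train and validation indices"""

    train: NDArray[np.intp]
    val: NDArray[np.intp]


index = np.arange(100)
tvs = TrainValSplit(index[:80], index[80:])

# Fold indices are computed relative to tvs.train, then mapped back to
# absolute indices exactly as the new split_dataset code does:
inner = TrainValSplit(np.arange(0, 60), np.arange(60, 80))
fold = TrainValSplit(tvs.train[inner.train], tvs.train[inner.val])
print(fold.train.shape, fold.val.shape)  # (60,) (20,)
```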
dataeval/utils/data/_targets.py

@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+__all__ = []
+
+from dataclasses import dataclass
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+def _len(arr: NDArray, dim: int) -> int:
+    return 0 if len(arr) == 0 else len(np.atleast_1d(arr) if dim == 1 else np.atleast_2d(arr))
+
+
+@dataclass(frozen=True)
+class Targets:
+    """
+    Dataclass defining targets for image classification or object detection.
+
+    Attributes
+    ----------
+    labels : NDArray[np.intp]
+        Labels (N,) for N images or objects
+    scores : NDArray[np.float32]
+        Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
+    bboxes : NDArray[np.float32] | None
+        Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
+    source : NDArray[np.intp] | None
+        Source image index (N,) for N objects
+    """
+
+    labels: NDArray[np.intp]
+    scores: NDArray[np.float32]
+    bboxes: NDArray[np.float32] | None
+    source: NDArray[np.intp] | None
+
+    def __post_init__(self) -> None:
+        if (self.bboxes is None) != (self.source is None):
+            raise ValueError("Either both bboxes and source must be provided or neither.")
+
+        labels = _len(self.labels, 1)
+        scores = _len(self.scores, 2) if self.bboxes is None else _len(self.scores, 1)
+        bboxes = labels if self.bboxes is None else _len(self.bboxes, 2)
+        source = labels if self.source is None else _len(self.source, 1)
+
+        if labels != scores or labels != bboxes or labels != source:
+            raise ValueError(
+                "Labels, scores, bboxes and source must be the same length (if provided).\n"
+                + f"  labels: {self.labels.shape}\n"
+                + f"  scores: {self.scores.shape}\n"
+                + f"  bboxes: {None if self.bboxes is None else self.bboxes.shape}\n"
+                + f"  source: {None if self.source is None else self.source.shape}\n"
+            )
+
+    def __len__(self) -> int:
+        return len(self.labels)
+
+    def at(self, idx: int) -> Targets:
+        if self.source is None or self.bboxes is None:
+            return Targets(
+                np.atleast_1d(self.labels[idx]),
+                np.atleast_2d(self.scores[idx]),
+                None,
+                None,
+            )
+        else:
+            mask = np.where(self.source == idx, True, False)
+            return Targets(
+                np.atleast_1d(self.labels[mask]),
+                np.atleast_1d(self.scores[mask]),
+                np.atleast_2d(self.bboxes[mask]),
+                np.atleast_1d(self.source[mask]),
+            )
dataeval/utils/data/_types.py

@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+__all__ = []
+
+import sys
+from dataclasses import dataclass
+from typing import Any, Generic, Protocol, TypedDict, TypeVar
+
+if sys.version_info >= (3, 11):
+    from typing import NotRequired, Required
+else:
+    from typing_extensions import NotRequired, Required
+
+from torch.utils.data import Dataset as _Dataset
+
+_TArray = TypeVar("_TArray")
+_TData = TypeVar("_TData", covariant=True)
+_TTarget = TypeVar("_TTarget", covariant=True)
+
+
+class DatasetMetadata(TypedDict):
+    id: Required[str]
+    index2label: NotRequired[dict[int, str]]
+    split: NotRequired[str]
+
+
+class Dataset(_Dataset[tuple[_TData, _TTarget, dict[str, Any]]]):
+    metadata: DatasetMetadata
+
+    def __getitem__(self, index: Any) -> tuple[_TData, _TTarget, dict[str, Any]]: ...
+    def __len__(self) -> int: ...
+
+
+class ImageClassificationDataset(Dataset[_TArray, _TArray]): ...
+
+
+@dataclass
+class ObjectDetectionTarget(Generic[_TArray]):
+    boxes: _TArray
+    labels: _TArray
+    scores: _TArray
+
+
+class ObjectDetectionDataset(Dataset[_TArray, ObjectDetectionTarget[_TArray]]): ...
+
+
+@dataclass
+class SegmentationTarget(Generic[_TArray]):
+    mask: _TArray
+    labels: _TArray
+    scores: _TArray
+
+
+class SegmentationDataset(Dataset[_TArray, SegmentationTarget[_TArray]]): ...
+
+
+class Transform(Generic[_TArray], Protocol):
+    def __call__(self, data: _TArray, /) -> _TArray: ...
dataeval/utils/data/collate.py

@@ -0,0 +1,103 @@
+"""
+Collate functions used with a PyTorch DataLoader to load data from MAITE compliant datasets.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Iterable, Sequence, TypeVar
+
+import numpy as np
+import torch
+from numpy.typing import NDArray
+
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import as_numpy
+
+T_in = TypeVar("T_in")
+T_tgt = TypeVar("T_tgt")
+T_md = TypeVar("T_md")
+
+
+def list_collate_fn(
+    batch_data_as_singles: Iterable[tuple[T_in, T_tgt, T_md]],
+) -> tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns three lists: the input batch, the target batch,
+    and the metadata batch. This is useful for loading data with torch.utils.data.DataLoader
+    when the target and metadata are not tensors.
+
+    Parameters
+    ----------
+    batch_data_as_singles : An iterable of (input, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[Sequence[T_in], Sequence[T_tgt], Sequence[T_md]]
+        A tuple of three lists: the input batch, the target batch, and the metadata batch.
+    """
+    input_batch: list[T_in] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(input_datum)
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return input_batch, target_batch, metadata_batch
+
+
+def numpy_collate_fn(
+    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
+) -> tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns the batched input as a single NumPy array with two
+    lists: the target batch, and the metadata batch. The inputs must be homogeneous arrays.
+
+    Parameters
+    ----------
+    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[NDArray[Any], Sequence[T_tgt], Sequence[T_md]]
+        A tuple of a NumPy array and two lists: the input batch, the target batch, and the metadata batch.
+    """
+    input_batch: list[NDArray[Any]] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(as_numpy(input_datum))
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return np.stack(input_batch) if input_batch else np.array([]), target_batch, metadata_batch
+
+
+def torch_collate_fn(
+    batch_data_as_singles: Iterable[tuple[ArrayLike, T_tgt, T_md]],
+) -> tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]:
+    """
+    A collate function that takes a batch of individual data points in the format
+    (input, target, metadata) and returns the batched input as a single torch Tensor with two
+    lists: the target batch, and the metadata batch. The inputs must be homogeneous arrays.
+
+    Parameters
+    ----------
+    batch_data_as_singles : An iterable of (ArrayLike, target, metadata) tuples.
+
+    Returns
+    -------
+    tuple[torch.Tensor, Sequence[T_tgt], Sequence[T_md]]
+        A tuple of a torch Tensor and two lists: the input batch, the target batch, and the metadata batch.
+    """
+    input_batch: list[torch.Tensor] = []
+    target_batch: list[T_tgt] = []
+    metadata_batch: list[T_md] = []
+    for input_datum, target_datum, metadata_datum in batch_data_as_singles:
+        input_batch.append(torch.as_tensor(input_datum))
+        target_batch.append(target_datum)
+        metadata_batch.append(metadata_datum)
+
+    return torch.stack(input_batch) if input_batch else torch.tensor([]), target_batch, metadata_batch
dataeval/utils/data/datasets/__init__.py

@@ -0,0 +1,17 @@
+"""Provides access to common Computer Vision datasets."""
+
+from dataeval.utils.data.datasets._cifar10 import CIFAR10
+from dataeval.utils.data.datasets._milco import MILCO
+from dataeval.utils.data.datasets._mnist import MNIST
+from dataeval.utils.data.datasets._ships import Ships
+from dataeval.utils.data.datasets._voc import VOCDetection, VOCDetectionTorch, VOCSegmentation
+
+__all__ = [
+    "MNIST",
+    "Ships",
+    "CIFAR10",
+    "MILCO",
+    "VOCDetection",
+    "VOCDetectionTorch",
+    "VOCSegmentation",
+]
dataeval/utils/data/datasets/_base.py

@@ -0,0 +1,254 @@
+from __future__ import annotations
+
+__all__ = []
+
+from abc import abstractmethod
+from pathlib import Path
+from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
+
+from dataeval.utils.data._types import (
+    Dataset,
+    DatasetMetadata,
+    ImageClassificationDataset,
+    ObjectDetectionDataset,
+    ObjectDetectionTarget,
+    SegmentationDataset,
+    SegmentationTarget,
+    Transform,
+)
+from dataeval.utils.data.datasets._fileio import _ensure_exists
+from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
+
+_TArray = TypeVar("_TArray")
+_TTarget = TypeVar("_TTarget")
+_TRawTarget = TypeVar("_TRawTarget", list[int], list[str])
+
+
+class DataLocation(NamedTuple):
+    url: str
+    filename: str
+    md5: bool
+    checksum: str
+
+
+class BaseDataset(Dataset[_TArray, _TTarget], Generic[_TArray, _TTarget, _TRawTarget]):
+    """
+    Base class for internet downloaded datasets.
+    """
+
+    # Each subclass should override the attributes below.
+    # Each resource tuple must contain:
+    #    'url': str, the URL to download from
+    #    'filename': str, the name of the file once downloaded
+    #    'md5': boolean, True if it's the checksum value is md5
+    #    'checksum': str, the associated checksum for the downloaded file
+    _resources: list[DataLocation]
+    _resource_index: int = 0
+    index2label: dict[int, str]
+
+    def __init__(
+        self,
+        root: str | Path,
+        download: bool = False,
+        image_set: Literal["train", "val", "test", "base"] = "train",
+        transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
+        verbose: bool = False,
+    ) -> None:
+        self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
+        transforms = transforms if transforms is not None else []
+        self.transforms: Sequence[Transform[_TArray]] = transforms if isinstance(transforms, Sequence) else [transforms]
+        self.image_set = image_set
+        self._verbose = verbose
+
+        # Internal Attributes
+        self._download = download
+        self._filepaths: list[str]
+        self._targets: _TRawTarget
+        self._datum_metadata: dict[str, list[Any]]
+        self._resource: DataLocation = self._resources[self._resource_index]
+        self._label2index = {v: k for k, v in self.index2label.items()}
+
+        self.metadata: DatasetMetadata = DatasetMetadata(
+            id=self._unique_id(),
+            index2label=self.index2label,
+            split=self.image_set,
+        )
+
+        # Load the data
+        self.path: Path = self._get_dataset_dir()
+        self._filepaths, self._targets, self._datum_metadata = self._load_data()
+        self.size: int = len(self._filepaths)
+
+    def __str__(self) -> str:
+        nt = "\n "
+        title = f"{self.__class__.__name__} Dataset"
+        sep = "-" * len(title)
+        attrs = [f"{k.capitalize()}: {v}" for k, v in self.__dict__.items() if not k.startswith("_")]
+        return f"{title}\n{sep}{nt}{nt.join(attrs)}"
+
+    @property
+    def label2index(self) -> dict[str, int]:
+        return self._label2index
+
+    def __iter__(self) -> Iterator[tuple[_TArray, _TTarget, dict[str, Any]]]:
+        for i in range(len(self)):
+            yield self[i]
+
+    def _get_dataset_dir(self) -> Path:
+        # Create a designated folder for this dataset (named after the class)
+        if self._root.stem in [
+            self.__class__.__name__.lower(),
+            self.__class__.__name__.upper(),
+            self.__class__.__name__,
+        ]:
+            dataset_dir: Path = self._root
+        else:
+            dataset_dir: Path = self._root / self.__class__.__name__.lower()
+        if not dataset_dir.exists():
+            dataset_dir.mkdir(parents=True, exist_ok=True)
+        return dataset_dir
+
+    def _unique_id(self) -> str:
+        unique_id = f"{self.__class__.__name__}_{self.image_set}"
+        return unique_id
+
+    def _load_data(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]:
+        """
+        Function to determine if data can be accessed or if it needs to be downloaded and/or extracted.
+        """
+        if self._verbose:
+            print(f"Determining if {self._resource.filename} needs to be downloaded.")
+
+        try:
+            result = self._load_data_inner()
+            if self._verbose:
+                print("No download needed, loaded data successfully.")
+        except FileNotFoundError:
+            _ensure_exists(*self._resource, self.path, self._root, self._download, self._verbose)
+            result = self._load_data_inner()
+        return result
+
+    @abstractmethod
+    def _load_data_inner(self) -> tuple[list[str], _TRawTarget, dict[str, Any]]: ...
+
+    def _transform(self, image: _TArray) -> _TArray:
+        """Function to transform the image prior to returning based on parameters passed in."""
+        for transform in self.transforms:
+            image = transform(image)
+        return image
+
+    def __len__(self) -> int:
+        return self.size
+
+
+class BaseICDataset(
+    BaseDataset[_TArray, _TArray, list[int]],
+    BaseDatasetMixin[_TArray],
+    ImageClassificationDataset[_TArray],
+):
+    """
+    Base class for image classification datasets.
+    """
+
+    def __getitem__(self, index: int) -> tuple[_TArray, _TArray, dict[str, Any]]:
+        """
+        Args
+        ----
+        index : int
+            Value of the desired data point
+
+        Returns
+        -------
+        tuple[TArray, TArray, dict[str, Any]]
+            Image, target, datum_metadata - where target is one-hot encoding of class.
+        """
+        # Get the associated label and score
+        label = self._targets[index]
+        score = self._one_hot_encode(label)
+        # Get the image
+        img = self._read_file(self._filepaths[index])
+        img = self._transform(img)
+
+        img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+
+        return img, score, img_metadata
+
+
+class BaseODDataset(
+    BaseDataset[_TArray, ObjectDetectionTarget[_TArray], list[str]],
+    BaseDatasetMixin[_TArray],
+    ObjectDetectionDataset[_TArray],
+):
+    """
+    Base class for object detection datasets.
+    """
+
+    def __getitem__(self, index: int) -> tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]:
+        """
+        Args
+        ----
+        index : int
+            Value of the desired data point
+
+        Returns
+        -------
+        tuple[TArray, ObjectDetectionTarget[TArray], dict[str, Any]]
+            Image, target, datum_metadata - target.boxes returns boxes in x0, y0, x1, y1 format
+        """
+        # Grab the bounding boxes and labels from the annotations
+        boxes, labels, additional_metadata = self._read_annotations(self._targets[index])
+        # Get the image
+        img = self._read_file(self._filepaths[index])
+        img = self._transform(img)
+
+        target = ObjectDetectionTarget(self._as_array(boxes), self._as_array(labels), self._one_hot_encode(labels))
+
+        img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+        img_metadata = img_metadata | additional_metadata
+
+        return img, target, img_metadata
+
+    @abstractmethod
+    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...
+
+
+class BaseSegDataset(
+    BaseDataset[_TArray, SegmentationTarget[_TArray], list[str]],
+    BaseDatasetMixin[_TArray],
+    SegmentationDataset[_TArray],
+):
+    """
+    Base class for segmentation datasets.
+    """
+
+    _masks: Sequence[str]
+
+    def __getitem__(self, index: int) -> tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]:
+        """
+        Args
+        ----
+        index : int
+            Value of the desired data point
+
+        Returns
+        -------
+        tuple[TArray, SegmentationTarget[TArray], dict[str, Any]]
+            Image, target, datum_metadata - target.mask returns the ground truth mask
+        """
+        # Grab the labels from the annotations
+        _, labels, additional_metadata = self._read_annotations(self._targets[index])
+        # Grab the ground truth masks
+        mask = self._read_file(self._masks[index])
+        # Get the image
+        img = self._read_file(self._filepaths[index])
+        img = self._transform(img)
+
+        target = SegmentationTarget(mask, self._as_array(labels), self._one_hot_encode(labels))
+
+        img_metadata = {key: val[index] for key, val in self._datum_metadata.items()}
+        img_metadata = img_metadata | additional_metadata
+
+        return img, target, img_metadata
+
+    @abstractmethod
+    def _read_annotations(self, annotation: str) -> tuple[list[list[float]], list[int], dict[str, Any]]: ...