dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +8 -64
- dataeval/detectors/drift/_mmd.py +12 -38
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +6 -5
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -2
- dataeval/detectors/linters/duplicates.py +14 -46
- dataeval/detectors/linters/outliers.py +25 -159
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +6 -5
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +3 -4
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/metadata/__init__.py +2 -1
- dataeval/metadata/_distance.py +134 -0
- dataeval/metadata/_ood.py +30 -49
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/_balance.py +17 -149
- dataeval/metrics/bias/_coverage.py +4 -106
- dataeval/metrics/bias/_diversity.py +12 -107
- dataeval/metrics/bias/_parity.py +7 -71
- dataeval/metrics/estimators/__init__.py +5 -4
- dataeval/metrics/estimators/_ber.py +2 -20
- dataeval/metrics/estimators/_clusterer.py +1 -61
- dataeval/metrics/estimators/_divergence.py +2 -19
- dataeval/metrics/estimators/_uap.py +2 -16
- dataeval/metrics/stats/__init__.py +15 -12
- dataeval/metrics/stats/_base.py +41 -128
- dataeval/metrics/stats/_boxratiostats.py +13 -13
- dataeval/metrics/stats/_dimensionstats.py +17 -58
- dataeval/metrics/stats/_hashstats.py +19 -35
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +42 -121
- dataeval/metrics/stats/_pixelstats.py +19 -51
- dataeval/metrics/stats/_visualstats.py +19 -51
- dataeval/outputs/__init__.py +57 -0
- dataeval/outputs/_base.py +182 -0
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +186 -0
- dataeval/outputs/_metadata.py +54 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +393 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +187 -7
- dataeval/utils/_method.py +1 -5
- dataeval/utils/_plot.py +2 -2
- dataeval/utils/data/__init__.py +5 -1
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +12 -14
- dataeval/utils/data/_images.py +30 -27
- dataeval/utils/data/_metadata.py +28 -11
- dataeval/utils/data/_selection.py +25 -22
- dataeval/utils/data/_split.py +5 -29
- dataeval/utils/data/_targets.py +14 -2
- dataeval/utils/data/datasets/_base.py +5 -5
- dataeval/utils/data/datasets/_cifar10.py +1 -1
- dataeval/utils/data/datasets/_milco.py +1 -1
- dataeval/utils/data/datasets/_mnist.py +1 -1
- dataeval/utils/data/datasets/_ships.py +1 -1
- dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
- dataeval/utils/data/datasets/_voc.py +1 -1
- dataeval/utils/data/selections/_classfilter.py +4 -5
- dataeval/utils/data/selections/_indices.py +2 -2
- dataeval/utils/data/selections/_limit.py +2 -2
- dataeval/utils/data/selections/_reverse.py +2 -2
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +6 -342
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
- dataeval-0.82.1.dist-info/RECORD +105 -0
- dataeval/_output.py +0 -137
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/metrics/stats/_datasetstats.py +0 -198
- dataeval-0.81.0.dist-info/RECORD +0 -94
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/utils/data/_dataset.py
ADDED
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar
+
+from dataeval.typing import (
+    Array,
+    ArrayLike,
+    DatasetMetadata,
+    ImageClassificationDataset,
+    ObjectDetectionDataset,
+)
+from dataeval.utils._array import as_numpy
+
+
+def _validate_data(
+    datum_type: Literal["ic", "od"],
+    images: Array | Sequence[Array],
+    labels: Sequence[int] | Sequence[Sequence[int]],
+    bboxes: Sequence[Sequence[Sequence[float]]] | None,
+    metadata: Sequence[dict[str, Any]] | None,
+) -> None:
+    # Validate inputs
+    dataset_len = len(images)
+
+    if not isinstance(images, (Sequence, Array)) or len(images[0].shape) != 3:
+        raise ValueError("Images must be a sequence or array of 3 dimensional arrays (H, W, C).")
+    if len(labels) != dataset_len:
+        raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
+    if bboxes is not None and len(bboxes) != dataset_len:
+        raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
+    if metadata is not None and len(metadata) != dataset_len:
+        raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")
+
+    if datum_type == "ic":
+        if not isinstance(labels, Sequence) or not isinstance(labels[0], int):
+            raise TypeError("Labels must be a sequence of integers for image classification.")
+    elif datum_type == "od":
+        if not isinstance(labels, Sequence) or not isinstance(labels[0], Sequence) or not isinstance(labels[0][0], int):
+            raise TypeError("Labels must be a sequence of sequences of integers for object detection.")
+        if (
+            bboxes is None
+            or not isinstance(bboxes, (Sequence, Array))
+            or not isinstance(bboxes[0], (Sequence, Array))
+            or not isinstance(bboxes[0][0], (Sequence, Array))
+            or not len(bboxes[0][0]) == 4
+        ):
+            raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+
+
+def _find_max(arr: ArrayLike) -> Any:
+    if isinstance(arr[0], (Iterable, Sequence, Array)):
+        return max([_find_max(x) for x in arr])  # type: ignore
+    else:
+        return max(arr)
+
+
+_TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
+
+
+class BaseAnnotatedDataset(Generic[_TLabels]):
+    def __init__(
+        self,
+        datum_type: Literal["ic", "od"],
+        images: Array | Sequence[Array],
+        labels: _TLabels,
+        metadata: Sequence[dict[str, Any]] | None,
+        classes: Sequence[str] | None,
+        name: str | None = None,
+    ) -> None:
+        self._classes = classes if classes is not None else [str(i) for i in range(_find_max(labels) + 1)]
+        self._index2label = dict(enumerate(self._classes))
+        self._images = images
+        self._labels = labels
+        self._metadata = metadata
+        self._id = name or f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"
+
+    @property
+    def metadata(self) -> DatasetMetadata:
+        return DatasetMetadata(id=self._id, index2label=self._index2label)
+
+    def __len__(self) -> int:
+        return len(self._images)
+
+
+class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
+    def __init__(
+        self,
+        images: Array | Sequence[Array],
+        labels: Sequence[int],
+        metadata: Sequence[dict[str, Any]] | None,
+        classes: Sequence[str] | None,
+        name: str | None = None,
+    ) -> None:
+        super().__init__("ic", images, labels, metadata, classes)
+        if name is not None:
+            self.__name__ = name
+            self.__class__.__name__ = name
+            self.__class__.__qualname__ = name
+
+    def __getitem__(self, idx: int, /) -> tuple[Array, Array, dict[str, Any]]:
+        one_hot = [0.0] * len(self._index2label)
+        one_hot[self._labels[idx]] = 1.0
+        return (
+            self._images[idx],
+            as_numpy(one_hot),
+            self._metadata[idx] if self._metadata is not None else {},
+        )
+
+
+class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
+    class ObjectDetectionTarget:
+        def __init__(self, labels: Sequence[int], bboxes: Sequence[Sequence[float]]) -> None:
+            self._labels = labels
+            self._bboxes = bboxes
+            self._scores = [1.0] * len(labels)
+
+        @property
+        def labels(self) -> Sequence[int]:
+            return self._labels
+
+        @property
+        def boxes(self) -> Sequence[Sequence[float]]:
+            return self._bboxes
+
+        @property
+        def scores(self) -> Sequence[float]:
+            return self._scores
+
+    def __init__(
+        self,
+        images: Array | Sequence[Array],
+        labels: Sequence[Sequence[int]],
+        bboxes: Sequence[Sequence[Sequence[float]]],
+        metadata: Sequence[dict[str, Any]] | None,
+        classes: Sequence[str] | None,
+        name: str | None = None,
+    ) -> None:
+        super().__init__("od", images, labels, metadata, classes)
+        if name is not None:
+            self.__name__ = name
+            self.__class__.__name__ = name
+            self.__class__.__qualname__ = name
+        self._bboxes = bboxes
+
+    @property
+    def metadata(self) -> DatasetMetadata:
+        return DatasetMetadata(id=self._id, index2label=self._index2label)
+
+    def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
+        return (
+            self._images[idx],
+            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx]),
+            self._metadata[idx] if self._metadata is not None else {},
+        )
+
+
+def to_image_classification_dataset(
+    images: Array | Sequence[Array],
+    labels: Sequence[int],
+    metadata: Sequence[dict[str, Any]] | None,
+    classes: Sequence[str] | None,
+    name: str | None = None,
+) -> ImageClassificationDataset:
+    """
+    Helper function to create custom ImageClassificationDataset classes.
+
+    Parameters
+    ----------
+    images : Array | Sequence[Array]
+        The images to use in the dataset.
+    labels : Sequence[int]
+        The labels to use in the dataset.
+    metadata : Sequence[dict[str, Any]] | None
+        The metadata to use in the dataset.
+    classes : Sequence[str] | None
+        The classes to use in the dataset.
+
+    Returns
+    -------
+    ImageClassificationDataset
+    """
+    _validate_data("ic", images, labels, None, metadata)
+    return CustomImageClassificationDataset(images, labels, metadata, classes, name)
+
+
+def to_object_detection_dataset(
+    images: Array | Sequence[Array],
+    labels: Sequence[Sequence[int]],
+    bboxes: Sequence[Sequence[Sequence[float]]],
+    metadata: Sequence[dict[str, Any]] | None,
+    classes: Sequence[str] | None,
+    name: str | None = None,
+) -> ObjectDetectionDataset:
+    """
+    Helper function to create custom ObjectDetectionDataset classes.
+
+    Parameters
+    ----------
+    images : Array | Sequence[Array]
+        The images to use in the dataset.
+    labels : Sequence[Sequence[int]]
+        The labels to use in the dataset.
+    bboxes : Sequence[Sequence[Sequence[float]]]
+        The bounding boxes (x0, y0, x1, y1) to use in the dataset.
+    metadata : Sequence[dict[str, Any]] | None
+        The metadata to use in the dataset.
+    classes : Sequence[str] | None
+        The classes to use in the dataset.
+
+    Returns
+    -------
+    ObjectDetectionDataset
+    """
+    _validate_data("od", images, labels, bboxes, metadata)
+    return CustomObjectDetectionDataset(images, labels, bboxes, metadata, classes, name)
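The new _dataset.py module above provides factory helpers for wrapping in-memory images and labels as DataEval datasets. A minimal usage sketch (not part of the diff), assuming the helpers are re-exported from dataeval.utils.data via the __init__.py change in the file summary:

import numpy as np

from dataeval.utils.data import to_image_classification_dataset  # re-export path assumed

rng = np.random.default_rng(0)
images = [rng.random((16, 16, 3)) for _ in range(3)]  # three (H, W, C) images

dataset = to_image_classification_dataset(
    images, [0, 1, 0], metadata=None, classes=["cat", "dog"], name="demo"
)
image, one_hot, datum_metadata = dataset[0]  # labels are returned one-hot encoded
print(len(dataset), one_hot)  # 3 [1. 0.]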
dataeval/utils/data/_embeddings.py
CHANGED
@@ -9,9 +9,8 @@ import torch
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
-from dataeval.config import get_device
-from dataeval.typing import
-from dataeval.utils.data._types import Dataset
+from dataeval.config import DeviceLike, get_device
+from dataeval.typing import Array, Dataset
 from dataeval.utils.torch.models import SupportsEncode
 
 
@@ -25,13 +24,14 @@ class Embeddings:
     ----------
     dataset : ImageClassificationDataset or ObjectDetectionDataset
         Dataset to access original images from.
-    batch_size : int
+    batch_size : int
         Batch size to use when encoding images.
-    model : torch.nn.Module,
+    model : torch.nn.Module or None, default None
         Model to use for encoding images.
-    device :
-
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool, default False
         Whether to print progress bar when encoding images.
     """
 
@@ -41,11 +41,10 @@ class Embeddings:
 
     def __init__(
        self,
-        dataset: Dataset[
+        dataset: Dataset[tuple[Array, Any, Any]],
         batch_size: int,
-        indices: Sequence[int] | None = None,
         model: torch.nn.Module | None = None,
-        device:
+        device: DeviceLike | None = None,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
@@ -53,7 +52,6 @@ class Embeddings:
         self.verbose = verbose
 
         self._dataset = dataset
-        self._indices = indices if indices is not None else range(len(dataset))
         model = torch.nn.Flatten() if model is None else model
         self._model = model.to(self.device).eval()
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
@@ -78,7 +76,7 @@
     @torch.no_grad
     def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
         # manual batching
-        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)
+        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)  # type: ignore
         for i, images in (
             tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
             if self.verbose
@@ -87,7 +85,7 @@
             embeddings = self._encoder(torch.stack(images).to(self.device))
             yield embeddings
 
-    def __getitem__(self, key: int | slice | list[int]) -> torch.Tensor:
+    def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
         if isinstance(key, list):
             return torch.vstack(list(self._batch(key))).to(self.device)
         if isinstance(key, slice):
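Embeddings.__init__ drops the indices parameter and now types its input as a dataset of (image, target, metadata) tuples; list indexing encodes on demand in batches. A hedged sketch (the Embeddings re-export from dataeval.utils.data is assumed):

import torch

from dataeval.utils.data import Embeddings  # re-export path assumed

class TinyDataset:
    """Stand-in dataset yielding (image, target, metadata) tuples."""
    def __init__(self, n: int = 8) -> None:
        self._data = torch.rand(n, 3, 16, 16)
    def __len__(self) -> int:
        return len(self._data)
    def __getitem__(self, idx: int):
        return self._data[idx], 0, {"id": idx}

# With model=None the encoder falls back to torch.nn.Flatten(), so each
# embedding is the flattened image (3 * 16 * 16 = 768 values).
embeddings = Embeddings(TinyDataset(), batch_size=4)
print(embeddings[[0, 1, 2]].shape)  # expected: torch.Size([3, 768])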
dataeval/utils/data/_images.py
CHANGED
@@ -2,13 +2,14 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Generic, Iterator, Sequence, overload
+from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
 
-from dataeval.typing import
-from dataeval.utils.data._types import Dataset
+from dataeval.typing import Dataset
 
+T = TypeVar("T")
 
-class Images(Generic[TArray]):
+
+class Images(Generic[T]):
     """
     Collection of image data from a dataset.
 
@@ -16,17 +17,15 @@ class Images(Generic[TArray]):
 
     Parameters
     ----------
-    dataset :
+    dataset : Dataset[tuple[T, ...]] or Dataset[T]
         Dataset to access images from.
     """
 
-    def __init__(
-        self,
-        dataset: Dataset[TArray, Any],
-    ) -> None:
+    def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
+        self._is_tuple_datum = isinstance(dataset[0], tuple)
         self._dataset = dataset
 
-    def to_list(self) -> Sequence[
+    def to_list(self) -> Sequence[T]:
         """
         Converts entire dataset to a sequence of images.
 
@@ -37,29 +36,33 @@ class Images(Generic[TArray]):
 
         Returns
         -------
-        list[
+        list[T]
         """
         return self[:]
 
     @overload
-    def __getitem__(self, key:
-
+    def __getitem__(self, key: int, /) -> T: ...
     @overload
-    def __getitem__(self, key:
-
-    def __getitem__(self, key: int | slice
-        if
-
-
-
-
-
-
-
-
-
+    def __getitem__(self, key: slice, /) -> Sequence[T]: ...
+
+    def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
+        if self._is_tuple_datum:
+            dataset = cast(Dataset[tuple[T, Any, Any]], self._dataset)
+            if isinstance(key, slice):
+                return [dataset[k][0] for k in range(len(self._dataset))[key]]
+            elif isinstance(key, int):
+                return dataset[key][0]
+        else:
+            dataset = cast(Dataset[T], self._dataset)
+            if isinstance(key, slice):
+                return [dataset[k] for k in range(len(self._dataset))[key]]
+            elif isinstance(key, int):
+                return dataset[key]
+        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+
+    def __iter__(self) -> Iterator[T]:
         for i in range(len(self._dataset)):
-            yield self
+            yield self[i]
 
     def __len__(self) -> int:
         return len(self._dataset)
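Images is now generic over the datum type T and checks once, at construction, whether the wrapped dataset yields bare images or (image, target, metadata) tuples. A small sketch of the tuple case (re-export path assumed):

import numpy as np

from dataeval.utils.data import Images  # re-export path assumed

data = [np.zeros((8, 8, 3)) for _ in range(4)]

class TupleDataset:
    """Datum is an (image, target, metadata) tuple; Images extracts element 0."""
    def __len__(self) -> int:
        return len(data)
    def __getitem__(self, idx: int):
        return data[idx], 0, {}

images = Images(TupleDataset())
assert images[0].shape == (8, 8, 3)   # int key returns a single image
assert len(images[1:3]) == 2          # slice key returns a list of images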
dataeval/utils/data/_metadata.py
CHANGED
@@ -3,18 +3,19 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
 
 import numpy as np
 from numpy.typing import NDArray
 
-from dataeval.typing import
-
-
-
-    Dataset,
+from dataeval.typing import (
+    AnnotatedDataset,
+    Array,
+    ArrayLike,
     ObjectDetectionTarget,
 )
+from dataeval.utils._array import as_numpy, to_numpy
+from dataeval.utils._bin import bin_data, digitize_data, is_continuous
 from dataeval.utils.metadata import merge
 
 if TYPE_CHECKING:
@@ -65,7 +66,7 @@ class Metadata:
 
     def __init__(
         self,
-        dataset:
+        dataset: AnnotatedDataset[tuple[Any, Any, dict[str, Any]]],
         *,
         continuous_factor_bins: Mapping[str, int | Sequence[float]] | None = None,
         auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
@@ -276,12 +277,12 @@ class Metadata:
         if self._processed and not force:
             return
 
-        # Validate the metadata dimensions
-        self._validate()
-
         # Create image indices from targets
         self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
 
+        # Validate the metadata dimensions
+        self._validate()
+
         # Include specified metadata keys
         if self.include:
             metadata = {i: self.merged[i] for i in self.include if i in self.merged}
@@ -341,7 +342,11 @@ class Metadata:
 
         # Split out the dictionaries into the keys and values
         self._discrete_factor_names = list(discrete_metadata.keys())
-        self._discrete_data =
+        self._discrete_data = (
+            np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
+            if discrete_metadata
+            else np.array([], dtype=np.int64)
+        )
         self._continuous_factor_names = list(continuous_metadata.keys())
         self._continuous_data = (
             np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
@@ -350,3 +355,15 @@ class Metadata:
         )
         self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
         self._processed = True
+
+    def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
+        self._merge()
+        self._processed = False
+        target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
+        if any(len(v) != target_len for v in factors.values()):
+            raise ValueError(
+                "The lists/arrays in the provided factors have a different length than the current metadata factors."
+            )
+        merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
+        for k, v in factors.items():
+            merged[k] = v
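The new add_factors method attaches extra per-item factors to already-merged metadata, validating one value per target before merging. A hedged sketch built on the _dataset.py helper introduced earlier in this diff (re-export paths assumed; exact Metadata behavior may differ):

import numpy as np

from dataeval.utils.data import Metadata, to_image_classification_dataset  # re-exports assumed

images = [np.zeros((8, 8, 3)) for _ in range(4)]
dataset = to_image_classification_dataset(
    images, [0, 1, 1, 0], [{"idx": i} for i in range(4)], classes=["a", "b"]
)

metadata = Metadata(dataset)
# One value per target is required; a length mismatch raises ValueError.
metadata.add_factors({"brightness": np.array([0.1, 0.9, 0.5, 0.3])})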
dataeval/utils/data/_selection.py
CHANGED
@@ -3,12 +3,11 @@ from __future__ import annotations
 __all__ = []
 
 from enum import IntEnum
-from typing import
+from typing import Generic, Iterator, Sequence, TypeVar
 
-from dataeval.
+from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
-
-_TTarget = TypeVar("_TTarget")
+_TDatum = TypeVar("_TDatum", covariant=True)
 
 
 class SelectionStage(IntEnum):
@@ -17,16 +16,16 @@ class SelectionStage(IntEnum):
     ORDER = 2
 
 
-class Selection(Generic[
+class Selection(Generic[_TDatum]):
     stage: SelectionStage
 
-    def __call__(self, dataset: Select[
+    def __call__(self, dataset: Select[_TDatum]) -> None: ...
 
     def __str__(self) -> str:
         return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
 
 
-class Select(
+class Select(AnnotatedDataset[_TDatum]):
     """
     Wraps a dataset and applies selection criteria to it.
 
@@ -60,35 +59,43 @@ class Select(Generic[_TData, _TTarget], Dataset[_TData, _TTarget]):
     (data_20, 0, {'id': 20})
     """
 
-    _dataset:
+    _dataset: AnnotatedDataset[_TDatum]
     _selection: list[int]
-    _selections: Sequence[Selection[
+    _selections: Sequence[Selection[_TDatum]]
     _size_limit: int
 
     def __init__(
         self,
-        dataset:
-        selections: Selection[
+        dataset: AnnotatedDataset[_TDatum],
+        selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
     ) -> None:
+        self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
         self._selections = self._sort_selections(selections)
-
+
+        # Ensure metadata is populated correctly as DatasetMetadata TypedDict
+        _metadata = getattr(dataset, "metadata", {})
+        if "id" not in _metadata:
+            _metadata["id"] = dataset.__class__.__name__
+        self._metadata = DatasetMetadata(**_metadata)
 
         if self._selections:
             self._apply_selections()
 
+    @property
+    def metadata(self) -> DatasetMetadata:
+        return self._metadata
+
     def __str__(self) -> str:
         nt = "\n "
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
         selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
-        return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
+        return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
-    def _sort_selections(
-        self, selections: Selection[_TData, _TTarget] | Sequence[Selection[_TData, _TTarget]] | None
-    ) -> list[Selection]:
+    def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
             return []
 
@@ -104,14 +111,10 @@ class Select(Generic[_TData, _TTarget], Dataset[_TData, _TTarget]):
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
-    def __getattr__(self, name: str) -> Any:
-        selfattr = getattr(self._dataset, name, None)
-        return selfattr if selfattr is not None else getattr(self._dataset, name)
-
-    def __getitem__(self, index: int) -> tuple[_TData, _TTarget, dict[str, Any]]:
+    def __getitem__(self, index: int) -> _TDatum:
         return self._dataset[self._selection[index]]
 
-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):
             yield self[i]
 
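Select now subclasses AnnotatedDataset[_TDatum], copies the wrapped dataset's attributes, and guarantees a DatasetMetadata mapping with an "id" key; __str__ also reports the selected size. A hedged sketch (re-export paths and the Limit/Reverse signatures are assumptions):

import numpy as np

from dataeval.utils.data import Select, to_image_classification_dataset  # re-exports assumed
from dataeval.utils.data.selections import Limit, Reverse  # signatures assumed

images = [np.zeros((8, 8, 3)) for _ in range(10)]
dataset = to_image_classification_dataset(images, [i % 2 for i in range(10)], None, ["a", "b"])

# Selections are applied in stage order regardless of list order.
subset = Select(dataset, [Reverse(), Limit(5)])
print(len(subset), subset.metadata["id"])  # 5 and the wrapped dataset's id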
dataeval/utils/data/_split.py
CHANGED
@@ -3,7 +3,6 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from dataclasses import dataclass
 from typing import Any, Iterator, Protocol
 
 import numpy as np
@@ -13,32 +12,9 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
-from dataeval.
-
-
-@dataclass
-class TrainValSplit:
-    """Tuple containing train and validation indices"""
-
-    train: NDArray[np.intp]
-    val: NDArray[np.intp]
-
-
-@dataclass(frozen=True)
-class SplitDatasetOutput(Output):
-    """
-    Output class containing test indices and a list of TrainValSplits.
-
-    Attributes
-    ----------
-    test: NDArray[np.intp]
-        Indices for the test set
-    folds: list[TrainValSplit]
-        List where each index contains the indices for the train and validation splits
-    """
-
-    test: NDArray[np.intp]
-    folds: list[TrainValSplit]
+from dataeval.config import get_seed
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 
 
 class KFoldSplitter(Protocol):
@@ -237,9 +213,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
     best_score = 0.50
     bin_index = np.zeros(len(array), dtype=np.intp)
     for k in range(2, 20):
-        clusterer = KMeans(n_clusters=k)
+        clusterer = KMeans(n_clusters=k, random_state=get_seed())
         cluster_labels = clusterer.fit_predict(array)
-        score = silhouette_score(array, cluster_labels, sample_size=25_000)
+        score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
         if score > best_score:
             best_score = score
             bin_index = cluster_labels.astype(np.intp)
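Threading get_seed() into KMeans and silhouette_score makes bin_kmeans deterministic under a fixed global seed. The effect, sketched with plain scikit-learn (the literal seed below stands in for whatever dataeval.config.get_seed() returns):

import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
array = rng.normal(size=(100, 2))

seed = 42  # stand-in for dataeval.config.get_seed()
labels_a = KMeans(n_clusters=3, n_init=10, random_state=seed).fit_predict(array)
labels_b = KMeans(n_clusters=3, n_init=10, random_state=seed).fit_predict(array)
assert (labels_a == labels_b).all()  # identical cluster assignments per run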
dataeval/utils/data/_targets.py
CHANGED
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Iterator
+
 __all__ = []
 
 from dataclasses import dataclass
@@ -52,10 +54,16 @@ class Targets:
             + f" source: {None if self.source is None else self.source.shape}\n"
         )
 
+        if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
+            raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
+
     def __len__(self) -> int:
-
+        if self.source is None:
+            return len(self.labels)
+        else:
+            return len(np.unique(self.source))
 
-    def __getitem__(self, idx: int) -> Targets:
+    def __getitem__(self, idx: int, /) -> Targets:
         if self.source is None or self.bboxes is None:
             return Targets(
                 np.atleast_1d(self.labels[idx]),
@@ -71,3 +79,7 @@ class Targets:
             np.atleast_2d(self.bboxes[mask]),
             np.atleast_1d(self.source[mask]),
         )
+
+    def __iter__(self) -> Iterator[Targets]:
+        for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
+            yield self[i]
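With these changes Targets behaves as a per-image collection whenever a source array maps detections back to their images: __len__ counts unique sources and __iter__ yields one Targets per image. The grouping logic, sketched standalone in NumPy:

import numpy as np

labels = np.array([0, 1, 1])
source = np.array([0, 0, 1])  # detections 0-1 belong to image 0, detection 2 to image 1

for image_id in np.unique(source):  # what __iter__ loops over when source is set
    mask = source == image_id
    print(image_id, labels[mask])   # 0 [0 1], then 1 [1]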
dataeval/utils/data/datasets/_base.py
CHANGED
@@ -6,8 +6,10 @@ from abc import abstractmethod
 from pathlib import Path
 from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
 
-from dataeval.utils.data._types import (
-
+from dataeval.utils.data.datasets._fileio import _ensure_exists
+from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
+from dataeval.utils.data.datasets._types import (
+    AnnotatedDataset,
     DatasetMetadata,
     ImageClassificationDataset,
     ObjectDetectionDataset,
@@ -16,8 +18,6 @@ from dataeval.utils.data._types import (
     SegmentationTarget,
     Transform,
 )
-from dataeval.utils.data.datasets._fileio import _ensure_exists
-from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
 
 _TArray = TypeVar("_TArray")
 _TTarget = TypeVar("_TTarget")
@@ -31,7 +31,7 @@ class DataLocation(NamedTuple):
     checksum: str
 
 
-class BaseDataset(
+class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Generic[_TArray, _TTarget, _TRawTarget]):
     """
     Base class for internet downloaded datasets.
     """
|