dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/data/__init__.py +19 -0
- dataeval/data/_embeddings.py +345 -0
- dataeval/{utils/data → data}/_images.py +2 -2
- dataeval/{utils/data → data}/_metadata.py +8 -7
- dataeval/{utils/data → data}/_selection.py +22 -9
- dataeval/{utils/data → data}/_split.py +1 -1
- dataeval/data/selections/__init__.py +19 -0
- dataeval/data/selections/_classbalance.py +37 -0
- dataeval/data/selections/_classfilter.py +109 -0
- dataeval/{utils/data → data}/selections/_indices.py +1 -1
- dataeval/{utils/data → data}/selections/_limit.py +1 -1
- dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
- dataeval/{utils/data → data}/selections/_reverse.py +1 -1
- dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +55 -203
- dataeval/detectors/drift/_cvm.py +19 -30
- dataeval/detectors/drift/_ks.py +18 -30
- dataeval/detectors/drift/_mmd.py +189 -53
- dataeval/detectors/drift/_uncertainty.py +52 -56
- dataeval/detectors/drift/updates.py +13 -12
- dataeval/detectors/linters/duplicates.py +6 -4
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/metadata/_distance.py +1 -1
- dataeval/metadata/_ood.py +4 -4
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/bias/_diversity.py +1 -1
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/stats/_base.py +7 -7
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +2 -2
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/_bias.py +1 -1
- dataeval/typing.py +53 -19
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +18 -7
- dataeval/utils/data/__init__.py +5 -20
- dataeval/utils/data/_dataset.py +6 -4
- dataeval/utils/data/collate.py +2 -0
- dataeval/utils/datasets/__init__.py +17 -0
- dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
- dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
- dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
- dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
- dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
- dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
- dataeval/utils/torch/_internal.py +12 -35
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
- dataeval-1.0.0.dist-info/RECORD +107 -0
- dataeval/detectors/drift/_torch.py +0 -222
- dataeval/utils/data/_embeddings.py +0 -186
- dataeval/utils/data/datasets/__init__.py +0 -17
- dataeval/utils/data/selections/__init__.py +0 -17
- dataeval/utils/data/selections/_classfilter.py +0 -59
- dataeval-0.84.0.dist-info/RECORD +0 -106
- /dataeval/{utils/data → data}/_targets.py +0 -0
- /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
- /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
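The headline change in this release is a namespace reorganization: the `dataeval.utils.data` modules move up to `dataeval.data`, the bundled datasets move from `dataeval.utils.data.datasets` to `dataeval.utils.datasets`, and `dataeval/detectors/drift/_torch.py` is removed. A minimal migration sketch based on the rename list above (the old 0.84.0 imports are shown commented out; the `MNIST` re-export from `dataeval.utils.datasets` is an assumption, not confirmed by this diff):

```python
# dataeval 0.84.0
# from dataeval.utils.data import Embeddings, Metadata, Select
# from dataeval.utils.data.datasets import MNIST

# dataeval 1.0.0
from dataeval.data import Embeddings, Metadata, Select
from dataeval.utils.datasets import MNIST  # assumed re-export; the module moved per the list above
```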
dataeval/data/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""Provides utility functions for interacting with Computer Vision datasets."""
+
+__all__ = [
+    "Embeddings",
+    "Images",
+    "Metadata",
+    "Select",
+    "SplitDatasetOutput",
+    "Targets",
+    "split_dataset",
+]
+
+from dataeval.data._embeddings import Embeddings
+from dataeval.data._images import Images
+from dataeval.data._metadata import Metadata
+from dataeval.data._selection import Select
+from dataeval.data._split import split_dataset
+from dataeval.data._targets import Targets
+from dataeval.outputs._utils import SplitDatasetOutput
dataeval/data/_embeddings.py
ADDED
@@ -0,0 +1,345 @@
+from __future__ import annotations
+
+__all__ = []
+
+import logging
+import math
+import os
+from pathlib import Path
+from typing import Any, Iterator, Sequence, cast
+
+import torch
+import xxhash as xxh
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader, Subset
+from tqdm import tqdm
+
+from dataeval.config import DeviceLike, get_device
+from dataeval.typing import AnnotatedDataset, AnnotatedModel, Array, ArrayLike, Dataset, Transform
+from dataeval.utils._array import as_numpy
+from dataeval.utils.torch.models import SupportsEncode
+
+_logger = logging.getLogger(__name__)
+
+
+class Embeddings:
+    """
+    Collection of image embeddings from a dataset.
+
+    Embeddings are accessed by index or slice and are only loaded on-demand.
+
+    Parameters
+    ----------
+    dataset : ImageClassificationDataset or ObjectDetectionDataset
+        Dataset to access original images from.
+    batch_size : int
+        Batch size to use when encoding images.
+    transforms : Transform or Sequence[Transform] or None, default None
+        Transforms to apply to images before encoding.
+    model : torch.nn.Module or None, default None
+        Model to use for encoding images.
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    cache : Path, str, or bool, default False
+        Whether to cache the embeddings to a file or in memory.
+        When a Path or string is provided, embeddings will be cached to disk.
+    verbose : bool, default False
+        Whether to print progress bar when encoding images.
+
+    Attributes
+    ----------
+    batch_size : int
+        Batch size to use when encoding images.
+    cache : Path or bool
+        The path to cache embeddings to file, or True if caching to memory.
+    device : torch.device
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool
+        Whether to print progress bar when encoding images.
+    """
+
+    device: torch.device
+    batch_size: int
+    verbose: bool
+
+    def __init__(
+        self,
+        dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike],
+        batch_size: int,
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
+        model: torch.nn.Module | None = None,
+        device: DeviceLike | None = None,
+        cache: Path | str | bool = False,
+        verbose: bool = False,
+    ) -> None:
+        self.device = get_device(device)
+        self.batch_size = batch_size if batch_size > 0 else 1
+        self.verbose = verbose
+
+        self._embeddings_only: bool = False
+        self._dataset = dataset
+        model = torch.nn.Flatten() if model is None else model
+        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
+        self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
+        self._encoder = model.encode if isinstance(model, SupportsEncode) else model
+        self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
+        self._cached_idx: set[int] = set()
+        self._embeddings: torch.Tensor = torch.empty(())
+
+        self._cache = cache if isinstance(cache, bool) else self._resolve_path(cache)
+
+    def __hash__(self) -> int:
+        if self._embeddings_only:
+            bid = as_numpy(self._embeddings).ravel().tobytes()
+        else:
+            did = self._dataset.metadata["id"] if isinstance(self._dataset, AnnotatedDataset) else str(self._dataset)
+            mid = self._model.metadata["id"] if isinstance(self._model, AnnotatedModel) else str(self._model)
+            tid = str.join("|", [str(t) for t in self._transforms or []])
+            bid = f"{did}{mid}{tid}".encode()
+
+        return int(xxh.xxh3_64_hexdigest(bid), 16)
+
+    @property
+    def cache(self) -> Path | bool:
+        return self._cache
+
+    @cache.setter
+    def cache(self, value: Path | str | bool) -> None:
+        if isinstance(value, bool) and not value:
+            self._cached_idx = set()
+            self._embeddings = torch.empty(())
+        elif isinstance(value, (Path, str)):
+            value = self._resolve_path(value)
+
+        if isinstance(value, Path) and value != getattr(self, "_cache", None):
+            self._save(value)
+
+        self._cache = value
+
+    def _resolve_path(self, path: Path | str) -> Path:
+        if isinstance(path, str):
+            path = Path(os.path.abspath(path))
+        if isinstance(path, Path) and (path.is_dir() or not path.suffix):
+            path = path / f"emb-{hash(self)}.pt"
+        return path
+
+    def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
+        """
+        Converts dataset to embeddings.
+
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
+
+        Returns
+        -------
+        torch.Tensor
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        if indices is not None:
+            return torch.vstack(list(self._batch(indices))).to(self.device)
+        else:
+            return self[:]
+
+    def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
+        """
+        Converts dataset to embeddings as numpy array.
+
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
+
+        Returns
+        -------
+        NDArray[Any]
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        return self.to_tensor(indices).cpu().numpy()
+
+    def new(self, dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike]) -> Embeddings:
+        """
+        Creates a new Embeddings object with the same parameters but a different dataset.
+
+        Parameters
+        ----------
+        dataset : ImageClassificationDataset or ObjectDetectionDataset
+            Dataset to access original images from.
+
+        Returns
+        -------
+        Embeddings
+        """
+        if self._embeddings_only:
+            raise ValueError("Embeddings object does not have a model.")
+        return Embeddings(
+            dataset, self.batch_size, self._transforms, self._model, self.device, bool(self.cache), self.verbose
+        )
+
+    @classmethod
+    def from_array(cls, array: ArrayLike, device: DeviceLike | None = None) -> Embeddings:
+        """
+        Instantiates a shallow Embeddings object using an array.
+
+        Parameters
+        ----------
+        array : ArrayLike
+            The array to convert to embeddings.
+        device : DeviceLike or None, default None
+            The hardware device to use if specified, otherwise uses the DataEval
+            default or torch default.
+
+        Returns
+        -------
+        Embeddings
+
+        Example
+        -------
+        >>> import numpy as np
+        >>> from dataeval.data import Embeddings
+        >>> array = np.random.randn(100, 3, 224, 224)
+        >>> embeddings = Embeddings.from_array(array)
+        >>> print(embeddings.to_tensor().shape)
+        torch.Size([100, 3, 224, 224])
+        """
+        embeddings = Embeddings([], 0, None, None, device, True, False)
+        array = array if isinstance(array, Array) else as_numpy(array)
+        embeddings._cached_idx = set(range(len(array)))
+        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+        embeddings._embeddings_only = True
+        return embeddings
+
+    def save(self, path: Path | str) -> None:
+        """
+        Saves the embeddings to disk.
+
+        Parameters
+        ----------
+        path : Path or str
+            The file path to save the embeddings to.
+        """
+        self._save(self._resolve_path(path), True)
+
+    def _save(self, path: Path, force: bool = False) -> None:
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        if self._embeddings_only or self.cache and not force:
+            embeddings = self._embeddings
+            cached_idx = self._cached_idx
+        else:
+            embeddings = self.to_tensor()
+            cached_idx = list(range(len(self)))
+        try:
+            cache_data = {
+                "embeddings": embeddings,
+                "cached_indices": cached_idx,
+                "device": self.device,
+            }
+            torch.save(cache_data, path)
+            _logger.log(logging.DEBUG, f"Saved embeddings cache to {path}")
+        except Exception as e:
+            _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
+
+    @classmethod
+    def load(cls, path: Path | str) -> Embeddings:
+        """
+        Loads the embeddings from disk.
+
+        Parameters
+        ----------
+        path : Path or str
+            The file path to load the embeddings from.
+        """
+        emb = Embeddings([], 0)
+        path = Path(os.path.abspath(path)) if isinstance(path, str) else path
+        if path.exists() and path.is_file():
+            try:
+                cache_data = torch.load(path, weights_only=False)
+                emb._embeddings_only = True
+                emb._embeddings = cache_data["embeddings"]
+                emb._cached_idx = cache_data["cached_indices"]
+                emb.device = cache_data["device"]
+                _logger.log(logging.DEBUG, f"Loaded embeddings cache from {path}")
+            except Exception as e:
+                _logger.log(logging.ERROR, f"Failed to load embeddings cache: {e}")
+                raise e
+        else:
+            raise FileNotFoundError(f"Specified cache file {path} was not found.")
+
+        return emb
+
+    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+        if self._transforms:
+            images = [transform(image) for transform in self._transforms for image in images]
+        return self._encoder(torch.stack(images).to(self.device))
+
+    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
+    def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
+        dataset = cast(torch.utils.data.Dataset, self._dataset)
+        total_batches = math.ceil(len(indices) / self.batch_size)
+
+        # If not caching, process all indices normally
+        if not self.cache:
+            for images in tqdm(
+                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+                total=total_batches,
+                desc="Batch embedding",
+                disable=not self.verbose,
+            ):
+                yield self._encode(images)
+            return
+
+        # If caching, process each batch of indices at a time, preserving original order
+        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+            batch = indices[i : i + self.batch_size]
+            uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+            if uncached:
+                # Process uncached indices as a single batch
+                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+                    embeddings = self._encode(images)
+
+                    if not self._embeddings.shape:
+                        full_shape = (len(self), *embeddings.shape[1:])
+                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+                    self._embeddings[uncached] = embeddings
+                    self._cached_idx.update(uncached)
+
+            if isinstance(self.cache, Path):
+                self._save(self.cache)
+
+            yield self._embeddings[batch]
+
+    def __getitem__(self, key: int | slice, /) -> torch.Tensor:
+        if not isinstance(key, slice) and not hasattr(key, "__int__"):
+            raise TypeError("Invalid argument type.")
+
+        indices = list(range(len(self))[key]) if isinstance(key, slice) else [int(key)]
+
+        if self._embeddings_only:
+            if not self._embeddings.shape:
+                raise ValueError("Embeddings not initialized.")
+            if not set(indices).issubset(self._cached_idx):
+                raise ValueError("Unable to generate new embeddings from a shallow instance.")
+            return self._embeddings[key]
+
+        result = torch.vstack(list(self._batch(indices))).to(self.device)
+        return result.squeeze(0) if len(indices) == 1 else result

+    def __iter__(self) -> Iterator[torch.Tensor]:
+        # process in batches while yielding individual embeddings
+        for batch in self._batch(range(len(self))):
+            yield from batch
+
+    def __len__(self) -> int:
+        return len(self._embeddings) if self._embeddings_only else len(self._dataset)
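The new `Embeddings` class lazily encodes images, caching results either in memory (`cache=True`) or on disk (`cache=<path>`). A short usage sketch distilled from the methods above; `train_ds` stands in for any image classification or object detection dataset:

```python
from dataeval.data import Embeddings

# model=None falls back to torch.nn.Flatten() per __init__ above;
# a directory cache path gets an emb-<hash>.pt filename appended by _resolve_path.
emb = Embeddings(train_ds, batch_size=64, cache="./cache", verbose=True)

first = emb[0]                   # encodes on demand, one batch at a time
all_embs = emb.to_tensor()       # full tensor; results land in the cache
emb.save("./embeddings.pt")      # explicit snapshot to disk
shallow = Embeddings.load("./embeddings.pt")  # embeddings-only instance, no model
```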
dataeval/{utils/data → data}/_images.py
RENAMED
@@ -4,13 +4,13 @@ __all__ = []
 
 from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
 
-from dataeval.typing import Array, Dataset
+from dataeval.typing import Array, ArrayLike, Dataset
 from dataeval.utils._array import as_numpy, channels_first_to_last
 
 if TYPE_CHECKING:
     from matplotlib.figure import Figure
 
-T = TypeVar("T",
+T = TypeVar("T", Array, ArrayLike)
 
 
 class Images(Generic[T]):
dataeval/{utils/data → data}/_metadata.py
RENAMED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -16,12 +16,12 @@ from dataeval.typing import (
 )
 from dataeval.utils._array import as_numpy, to_numpy
 from dataeval.utils._bin import bin_data, digitize_data, is_continuous
-from dataeval.utils.metadata import merge
+from dataeval.utils.data.metadata import merge
 
 if TYPE_CHECKING:
-    from dataeval.
+    from dataeval.data import Targets
 else:
-    from dataeval.
+    from dataeval.data._targets import Targets
 
 
 class Metadata:
@@ -208,8 +208,9 @@ class Metadata:
             raw.append(metadata)
 
             if is_od_target := isinstance(target, ObjectDetectionTarget):
-
-
+                target_labels = as_numpy(target.labels)
+                target_len = len(target_labels)
+                labels.extend(target_labels.tolist())
                 bboxes.extend(as_numpy(target.boxes).tolist())
                 scores.extend(as_numpy(target.scores).tolist())
                 srcidx.extend([i] * target_len)
@@ -360,7 +361,7 @@ class Metadata:
         self._merge()
         self._processed = False
         target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
-        if any(len(v) != target_len for v in factors.values()):
+        if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
             raise ValueError(
                 "The lists/arrays in the provided factors have a different length than the current metadata factors."
             )
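The revised length check lets factors be any array-like rather than only `Sized` containers, materializing non-`Sized` values through `as_numpy` before measuring them. A standalone illustration of the pattern (this helper is not dataeval API):

```python
from typing import Any, Sized

import numpy as np

def factor_length(v: Any) -> int:
    # len() directly for lists/arrays; materialize other array-likes first,
    # mirroring the isinstance(v, Sized) guard in the hunk above
    return len(v) if isinstance(v, Sized) else len(np.asarray(v))

assert factor_length([1, 2, 3]) == 3
assert factor_length(np.arange(4)) == 4
```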
dataeval/{utils/data → data}/_selection.py
RENAMED
@@ -25,6 +25,10 @@ class Selection(Generic[_TDatum]):
         return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
 
 
+class Subselection(Generic[_TDatum]):
+    def __call__(self, original: _TDatum) -> _TDatum: ...
+
+
 class Select(AnnotatedDataset[_TDatum]):
     """
     Wraps a dataset and applies selection criteria to it.
@@ -38,7 +42,7 @@ class Select(AnnotatedDataset[_TDatum]):
 
     Examples
     --------
-    >>> from dataeval.
+    >>> from dataeval.data.selections import ClassFilter, Limit
 
     >>> # Construct a sample dataset with size of 100 and class count of 10
     >>> # Elements at index `idx` are returned as tuples:
@@ -63,6 +67,7 @@ class Select(AnnotatedDataset[_TDatum]):
     _selection: list[int]
     _selections: Sequence[Selection[_TDatum]]
     _size_limit: int
+    _subselections: list[tuple[Subselection[_TDatum], set[int]]]
 
     def __init__(
         self,
@@ -73,7 +78,8 @@ class Select(AnnotatedDataset[_TDatum]):
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
-        self._selections = self.
+        self._selections = self._sort_selections(selections)
+        self._subselections = []
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -81,7 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
             _metadata["id"] = dataset.__class__.__name__
         self._metadata = DatasetMetadata(**_metadata)
 
-        self.
+        self._apply_selections()
 
     @property
     def metadata(self) -> DatasetMetadata:
@@ -94,24 +100,31 @@ class Select(AnnotatedDataset[_TDatum]):
         selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
         return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
-    def 
+    def _sort_selections(
+        self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None
+    ) -> list[Selection[_TDatum]]:
         if not selections:
             return []
 
-
-        grouped: dict[int, list[Selection]] = {}
-        for selection in 
+        selections_list = [selections] if isinstance(selections, Selection) else list(selections)
+        grouped: dict[int, list[Selection[_TDatum]]] = {}
+        for selection in selections_list:
             grouped.setdefault(selection.stage, []).append(selection)
         selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
         return selection_list
 
-    def 
+    def _apply_selections(self) -> None:
         for selection in self._selections:
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
+    def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
+        for subselection, indices in self._subselections:
+            datum = subselection(datum) if index in indices else datum
+        return datum
+
     def __getitem__(self, index: int) -> _TDatum:
-        return self._dataset[self._selection[index]]
+        return self._apply_subselection(self._dataset[self._selection[index]], index)
 
     def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):
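`Select` now also supports `Subselection` hooks that transform individual datum objects after index-level filtering. A hedged sketch of the selection pipeline itself, following the docstring example above (`base_ds` is a placeholder dataset; the `ClassFilter` and `Limit` constructor arguments are assumptions not shown in this diff):

```python
from dataeval.data import Select
from dataeval.data.selections import ClassFilter, Limit

# Selections are grouped and applied in SelectionStage order by _sort_selections,
# then _apply_selections trims the index list to _size_limit.
selected = Select(base_ds, selections=[ClassFilter(classes=[0, 1]), Limit(size=50)])
print(len(selected))  # at most 50 items, restricted to classes 0 and 1
```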
dataeval/{utils/data → data}/_split.py
RENAMED
@@ -12,10 +12,10 @@ from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, Str
 from sklearn.utils.multiclass import type_of_target
 
 from dataeval.config import EPSILON
+from dataeval.data._metadata import Metadata
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 from dataeval.typing import AnnotatedDataset
-from dataeval.utils.data._metadata import Metadata
 
 _logger = logging.getLogger(__name__)
 
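Only the imports of `split_dataset` change here; its interface appears untouched. For orientation, a hypothetical call using the 1.0.0 paths (only the import path is confirmed by this diff; calling with defaults is an assumption):

```python
from dataeval.data import split_dataset

splits = split_dataset(train_ds)  # returns a SplitDatasetOutput
```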
dataeval/data/selections/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""Provides selection classes for selecting subsets of Computer Vision datasets."""
+
+__all__ = [
+    "ClassBalance",
+    "ClassFilter",
+    "Indices",
+    "Limit",
+    "Prioritize",
+    "Reverse",
+    "Shuffle",
+]
+
+from dataeval.data.selections._classbalance import ClassBalance
+from dataeval.data.selections._classfilter import ClassFilter
+from dataeval.data.selections._indices import Indices
+from dataeval.data.selections._limit import Limit
+from dataeval.data.selections._prioritize import Prioritize
+from dataeval.data.selections._reverse import Reverse
+from dataeval.data.selections._shuffle import Shuffle
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from dataeval.data._selection import Select, Selection, SelectionStage
|
8
|
+
from dataeval.typing import Array, ImageClassificationDatum
|
9
|
+
from dataeval.utils._array import as_numpy
|
10
|
+
|
11
|
+
|
12
|
+
class ClassBalance(Selection[ImageClassificationDatum]):
|
13
|
+
"""
|
14
|
+
Balance the dataset by class.
|
15
|
+
|
16
|
+
Note
|
17
|
+
----
|
18
|
+
The total number of instances of each class will be equalized which may result
|
19
|
+
in a lower total number of instances than specified by the selection limit.
|
20
|
+
"""
|
21
|
+
|
22
|
+
stage = SelectionStage.FILTER
|
23
|
+
|
24
|
+
def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
|
25
|
+
class_indices: dict[int, list[int]] = {}
|
26
|
+
for i, idx in enumerate(dataset._selection):
|
27
|
+
target = dataset._dataset[idx][1]
|
28
|
+
if isinstance(target, Array):
|
29
|
+
label = int(np.argmax(as_numpy(target)))
|
30
|
+
else:
|
31
|
+
# ObjectDetectionTarget and SegmentationTarget not supported yet
|
32
|
+
raise TypeError("ClassFilter only supports classification targets as an array of confidence scores.")
|
33
|
+
class_indices.setdefault(label, []).append(i)
|
34
|
+
|
35
|
+
per_class_limit = min(min(len(c) for c in class_indices.values()), dataset._size_limit // len(class_indices))
|
36
|
+
subselection = sorted([i for v in class_indices.values() for i in v[:per_class_limit]])
|
37
|
+
dataset._selection = [dataset._selection[i] for i in subselection]
|