dataeval 0.84.1__py3-none-any.whl → 0.85.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/data/__init__.py +19 -0
- dataeval/{utils/data → data}/_embeddings.py +137 -17
- dataeval/{utils/data → data}/_metadata.py +3 -3
- dataeval/{utils/data → data}/_selection.py +22 -9
- dataeval/{utils/data → data}/_split.py +1 -1
- dataeval/data/selections/__init__.py +19 -0
- dataeval/{utils/data → data}/selections/_classbalance.py +1 -2
- dataeval/data/selections/_classfilter.py +109 -0
- dataeval/{utils/data → data}/selections/_indices.py +1 -1
- dataeval/{utils/data → data}/selections/_limit.py +1 -1
- dataeval/{utils/data → data}/selections/_prioritize.py +2 -2
- dataeval/{utils/data → data}/selections/_reverse.py +1 -1
- dataeval/{utils/data → data}/selections/_shuffle.py +1 -1
- dataeval/detectors/drift/_base.py +1 -1
- dataeval/detectors/drift/_cvm.py +2 -2
- dataeval/detectors/drift/_ks.py +2 -2
- dataeval/detectors/drift/_mmd.py +2 -2
- dataeval/detectors/linters/duplicates.py +1 -1
- dataeval/detectors/linters/outliers.py +1 -1
- dataeval/metadata/_distance.py +1 -1
- dataeval/metadata/_ood.py +4 -4
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/bias/_diversity.py +1 -1
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/stats/_labelstats.py +2 -2
- dataeval/outputs/_bias.py +1 -1
- dataeval/typing.py +31 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/data/__init__.py +5 -20
- dataeval/utils/data/collate.py +2 -0
- dataeval/utils/datasets/__init__.py +17 -0
- dataeval/utils/{data/datasets → datasets}/_base.py +3 -3
- dataeval/utils/{data/datasets → datasets}/_cifar10.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_milco.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_mnist.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_ships.py +2 -2
- dataeval/utils/{data/datasets → datasets}/_voc.py +3 -3
- {dataeval-0.84.1.dist-info → dataeval-0.85.0.dist-info}/METADATA +1 -1
- {dataeval-0.84.1.dist-info → dataeval-0.85.0.dist-info}/RECORD +48 -47
- dataeval/utils/data/datasets/__init__.py +0 -17
- dataeval/utils/data/selections/__init__.py +0 -19
- dataeval/utils/data/selections/_classfilter.py +0 -44
- /dataeval/{utils/data → data}/_images.py +0 -0
- /dataeval/{utils/data → data}/_targets.py +0 -0
- /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
- /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
- {dataeval-0.84.1.dist-info → dataeval-0.85.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.1.dist-info → dataeval-0.85.0.dist-info}/WHEEL +0 -0
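The bulk of this release is a package move: dataeval.utils.data is promoted to a top-level dataeval.data package, the bundled datasets move from dataeval.utils.data.datasets to dataeval.utils.datasets, and utils/metadata.py becomes utils/data/metadata.py. A minimal migration sketch, assuming only the re-exports visible in the new __init__ files below:

>>> # 0.84.1
>>> from dataeval.utils.data import Embeddings, Select
>>> # 0.85.0
>>> from dataeval.data import Embeddings, Select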
dataeval/__init__.py
CHANGED
dataeval/data/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""Provides utility functions for interacting with Computer Vision datasets."""
+
+__all__ = [
+    "Embeddings",
+    "Images",
+    "Metadata",
+    "Select",
+    "SplitDatasetOutput",
+    "Targets",
+    "split_dataset",
+]
+
+from dataeval.data._embeddings import Embeddings
+from dataeval.data._images import Images
+from dataeval.data._metadata import Metadata
+from dataeval.data._selection import Select
+from dataeval.data._split import split_dataset
+from dataeval.data._targets import Targets
+from dataeval.outputs._utils import SplitDatasetOutput
dataeval/{utils/data → data}/_embeddings.py
RENAMED
@@ -2,19 +2,25 @@ from __future__ import annotations

 __all__ = []

+import logging
 import math
+import os
+from pathlib import Path
 from typing import Any, Iterator, Sequence, cast

 import torch
+import xxhash as xxh
 from numpy.typing import NDArray
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm

 from dataeval.config import DeviceLike, get_device
-from dataeval.typing import Array, ArrayLike, Dataset, Transform
+from dataeval.typing import AnnotatedDataset, AnnotatedModel, Array, ArrayLike, Dataset, Transform
 from dataeval.utils._array import as_numpy
 from dataeval.utils.torch.models import SupportsEncode

+_logger = logging.getLogger(__name__)
+

 class Embeddings:
     """
@@ -35,10 +41,23 @@ class Embeddings:
     device : DeviceLike or None, default None
         The hardware device to use if specified, otherwise uses the DataEval
         default or torch default.
-    cache : bool, default False
-        Whether to cache the embeddings in memory.
+    cache : Path, str, or bool, default False
+        Whether to cache the embeddings to a file or in memory.
+        When a Path or string is provided, embeddings will be cached to disk.
     verbose : bool, default False
         Whether to print progress bar when encoding images.
+
+    Attributes
+    ----------
+    batch_size : int
+        Batch size to use when encoding images.
+    cache : Path or bool
+        The path to cache embeddings to file, or True if caching to memory.
+    device : torch.device
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool
+        Whether to print progress bar when encoding images.
     """

     device: torch.device
@@ -52,24 +71,59 @@ class Embeddings:
         transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
         model: torch.nn.Module | None = None,
         device: DeviceLike | None = None,
-        cache: bool = False,
+        cache: Path | str | bool = False,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
-        self.cache = cache
         self.batch_size = batch_size if batch_size > 0 else 1
         self.verbose = verbose

+        self._embeddings_only: bool = False
         self._dataset = dataset
-        self._length = len(dataset)
         model = torch.nn.Flatten() if model is None else model
         self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
         self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
         self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
-        self._cached_idx = set()
+        self._cached_idx: set[int] = set()
         self._embeddings: torch.Tensor = torch.empty(())
-
+
+        self._cache = cache if isinstance(cache, bool) else self._resolve_path(cache)
+
+    def __hash__(self) -> int:
+        if self._embeddings_only:
+            bid = as_numpy(self._embeddings).ravel().tobytes()
+        else:
+            did = self._dataset.metadata["id"] if isinstance(self._dataset, AnnotatedDataset) else str(self._dataset)
+            mid = self._model.metadata["id"] if isinstance(self._model, AnnotatedModel) else str(self._model)
+            tid = str.join("|", [str(t) for t in self._transforms or []])
+            bid = f"{did}{mid}{tid}".encode()
+
+        return int(xxh.xxh3_64_hexdigest(bid), 16)
+
+    @property
+    def cache(self) -> Path | bool:
+        return self._cache
+
+    @cache.setter
+    def cache(self, value: Path | str | bool) -> None:
+        if isinstance(value, bool) and not value:
+            self._cached_idx = set()
+            self._embeddings = torch.empty(())
+        elif isinstance(value, (Path, str)):
+            value = self._resolve_path(value)
+
+        if isinstance(value, Path) and value != getattr(self, "_cache", None):
+            self._save(value)
+
+        self._cache = value
+
+    def _resolve_path(self, path: Path | str) -> Path:
+        if isinstance(path, str):
+            path = Path(os.path.abspath(path))
+        if isinstance(path, Path) and (path.is_dir() or not path.suffix):
+            path = path / f"emb-{hash(self)}.pt"
+        return path

     def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
         """
@@ -125,8 +179,10 @@ class Embeddings:
         -------
         Embeddings
         """
+        if self._embeddings_only:
+            raise ValueError("Embeddings object does not have a model.")
         return Embeddings(
-            dataset, self.batch_size, self._transforms, self._model, self.device, self.cache, self.verbose
+            dataset, self.batch_size, self._transforms, self._model, self.device, bool(self.cache), self.verbose
         )

     @classmethod
@@ -149,7 +205,7 @@ class Embeddings:
         Example
         -------
         >>> import numpy as np
-        >>> from dataeval.utils.data import Embeddings
+        >>> from dataeval.data import Embeddings
         >>> array = np.random.randn(100, 3, 224, 224)
         >>> embeddings = Embeddings.from_array(array)
         >>> print(embeddings.to_tensor().shape)
@@ -157,12 +213,70 @@ class Embeddings:
         """
         embeddings = Embeddings([], 0, None, None, device, True, False)
         array = array if isinstance(array, Array) else as_numpy(array)
-        embeddings._length = len(array)
         embeddings._cached_idx = set(range(len(array)))
         embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
-        embeddings.
+        embeddings._embeddings_only = True
         return embeddings

+    def save(self, path: Path | str) -> None:
+        """
+        Saves the embeddings to disk.
+
+        Parameters
+        ----------
+        path : Path or str
+            The file path to save the embeddings to.
+        """
+        self._save(self._resolve_path(path), True)
+
+    def _save(self, path: Path, force: bool = False) -> None:
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        if self._embeddings_only or self.cache and not force:
+            embeddings = self._embeddings
+            cached_idx = self._cached_idx
+        else:
+            embeddings = self.to_tensor()
+            cached_idx = list(range(len(self)))
+        try:
+            cache_data = {
+                "embeddings": embeddings,
+                "cached_indices": cached_idx,
+                "device": self.device,
+            }
+            torch.save(cache_data, path)
+            _logger.log(logging.DEBUG, f"Saved embeddings cache from {path}")
+        except Exception as e:
+            _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
+
+    @classmethod
+    def load(cls, path: Path | str) -> Embeddings:
+        """
+        Loads the embeddings from disk.
+
+        Parameters
+        ----------
+        path : Path or str
+            The file path to load the embeddings from.
+        """
+        emb = Embeddings([], 0)
+        path = Path(os.path.abspath(path)) if isinstance(path, str) else path
+        if path.exists() and path.is_file():
+            try:
+                cache_data = torch.load(path, weights_only=False)
+                emb._embeddings_only = True
+                emb._embeddings = cache_data["embeddings"]
+                emb._cached_idx = cache_data["cached_indices"]
+                emb.device = cache_data["device"]
+                _logger.log(logging.DEBUG, f"Loaded embeddings cache from {path}")
+            except Exception as e:
+                _logger.log(logging.ERROR, f"Failed to load embeddings cache: {e}")
+                raise e
+        else:
+            raise FileNotFoundError(f"Specified cache file {path} was not found.")
+
+        return emb
+
     def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
         if self._transforms:
             images = [transform(image) for transform in self._transforms for image in images]
@@ -195,31 +309,37 @@ class Embeddings:
             embeddings = self._encode(images)

             if not self._embeddings.shape:
-                full_shape = (len(self._dataset), *embeddings.shape[1:])
+                full_shape = (len(self), *embeddings.shape[1:])
                 self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)

             self._embeddings[uncached] = embeddings
             self._cached_idx.update(uncached)

+            if isinstance(self.cache, Path):
+                self._save(self.cache)
+
             yield self._embeddings[batch]

     def __getitem__(self, key: int | slice, /) -> torch.Tensor:
         if not isinstance(key, slice) and not hasattr(key, "__int__"):
             raise TypeError("Invalid argument type.")

-        if
+        indices = list(range(len(self))[key]) if isinstance(key, slice) else [int(key)]
+
+        if self._embeddings_only:
             if not self._embeddings.shape:
                 raise ValueError("Embeddings not initialized.")
+            if not set(indices).issubset(self._cached_idx):
+                raise ValueError("Unable to generate new embeddings from a shallow instance.")
             return self._embeddings[key]

-        indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
         result = torch.vstack(list(self._batch(indices))).to(self.device)
         return result.squeeze(0) if len(indices) == 1 else result

     def __iter__(self) -> Iterator[torch.Tensor]:
         # process in batches while yielding individual embeddings
-        for batch in self._batch(range(self._length)):
+        for batch in self._batch(range(len(self))):
             yield from batch

     def __len__(self) -> int:
-        return self._length
+        return len(self._embeddings) if self._embeddings_only else len(self._dataset)
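Beyond the path change, Embeddings gains disk caching: cache now accepts a Path or str, each newly encoded batch is persisted via _save when cache is a Path, and save/load round-trip the tensor through a torch .pt file. A usage sketch of the new API (the array shape here is arbitrary):

>>> import numpy as np
>>> from dataeval.data import Embeddings
>>> embeddings = Embeddings.from_array(np.random.randn(100, 16))
>>> embeddings.save("embeddings.pt")  # resolved to an absolute path, then torch.save
>>> restored = Embeddings.load("embeddings.pt")  # embeddings-only instance, no model attached
>>> len(restored)
100

Passing a directory (or suffix-less path) as cache instead writes to an auto-named emb-{hash}.pt file inside it, keyed by the dataset/model/transform hash computed in __hash__.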
dataeval/{utils/data → data}/_metadata.py
RENAMED
@@ -16,12 +16,12 @@ from dataeval.typing import (
 )
 from dataeval.utils._array import as_numpy, to_numpy
 from dataeval.utils._bin import bin_data, digitize_data, is_continuous
-from dataeval.utils.metadata import merge
+from dataeval.utils.data.metadata import merge

 if TYPE_CHECKING:
-    from dataeval.utils.data import Targets
+    from dataeval.data import Targets
 else:
-    from dataeval.utils.data._targets import Targets
+    from dataeval.data._targets import Targets


 class Metadata:
dataeval/{utils/data → data}/_selection.py
RENAMED
@@ -25,6 +25,10 @@ class Selection(Generic[_TDatum]):
         return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"


+class Subselection(Generic[_TDatum]):
+    def __call__(self, original: _TDatum) -> _TDatum: ...
+
+
 class Select(AnnotatedDataset[_TDatum]):
     """
     Wraps a dataset and applies selection criteria to it.
@@ -38,7 +42,7 @@ class Select(AnnotatedDataset[_TDatum]):

     Examples
     --------
-    >>> from dataeval.utils.data.selections import ClassFilter, Limit
+    >>> from dataeval.data.selections import ClassFilter, Limit

     >>> # Construct a sample dataset with size of 100 and class count of 10
     >>> # Elements at index `idx` are returned as tuples:
@@ -63,6 +67,7 @@ class Select(AnnotatedDataset[_TDatum]):
     _selection: list[int]
     _selections: Sequence[Selection[_TDatum]]
     _size_limit: int
+    _subselections: list[tuple[Subselection[_TDatum], set[int]]]

     def __init__(
         self,
@@ -73,7 +78,8 @@ class Select(AnnotatedDataset[_TDatum]):
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
-        self._selections = self.
+        self._selections = self._sort_selections(selections)
+        self._subselections = []

         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -81,7 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
             _metadata["id"] = dataset.__class__.__name__
         self._metadata = DatasetMetadata(**_metadata)

-        self.
+        self._apply_selections()

     @property
     def metadata(self) -> DatasetMetadata:
@@ -94,24 +100,31 @@ class Select(AnnotatedDataset[_TDatum]):
         selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
         return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"

-    def
+    def _sort_selections(
+        self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None
+    ) -> list[Selection[_TDatum]]:
         if not selections:
             return []

-
-        grouped: dict[int, list[Selection]] = {}
-        for selection in
+        selections_list = [selections] if isinstance(selections, Selection) else list(selections)
+        grouped: dict[int, list[Selection[_TDatum]]] = {}
+        for selection in selections_list:
             grouped.setdefault(selection.stage, []).append(selection)
         selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
         return selection_list

-    def
+    def _apply_selections(self) -> None:
         for selection in self._selections:
             selection(self)
         self._selection = self._selection[: self._size_limit]

+    def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
+        for subselection, indices in self._subselections:
+            datum = subselection(datum) if index in indices else datum
+        return datum
+
     def __getitem__(self, index: int) -> _TDatum:
-        return self._dataset[self._selection[index]]
+        return self._apply_subselection(self._dataset[self._selection[index]], index)

     def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):
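The new Subselection hook lets a selection rewrite individual datums at retrieval time: Select.__getitem__ now routes each datum through _apply_subselection, which applies every registered (subselection, indices) pair whose index set contains the requested index. A minimal sketch of a custom subselection (the class name and datum layout here are hypothetical, not part of dataeval):

>>> from dataeval.data._selection import Subselection
>>> class RedactMetadata(Subselection):  # hypothetical example
...     def __call__(self, original):
...         image, target, metadata = original
...         return image, target, {}  # strip the metadata dict

A Selection registers it during __call__ with, e.g., dataset._subselections.append((RedactMetadata(), {0, 3})), after which only datums 0 and 3 are rewritten; ClassFilter below is the first in-tree user of this mechanism.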
dataeval/{utils/data → data}/_split.py
RENAMED
@@ -12,10 +12,10 @@ from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target

 from dataeval.config import EPSILON
+from dataeval.data._metadata import Metadata
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 from dataeval.typing import AnnotatedDataset
-from dataeval.utils.data._metadata import Metadata

 _logger = logging.getLogger(__name__)

dataeval/data/selections/__init__.py
ADDED
@@ -0,0 +1,19 @@
+"""Provides selection classes for selecting subsets of Computer Vision datasets."""
+
+__all__ = [
+    "ClassBalance",
+    "ClassFilter",
+    "Indices",
+    "Limit",
+    "Prioritize",
+    "Reverse",
+    "Shuffle",
+]
+
+from dataeval.data.selections._classbalance import ClassBalance
+from dataeval.data.selections._classfilter import ClassFilter
+from dataeval.data.selections._indices import Indices
+from dataeval.data.selections._limit import Limit
+from dataeval.data.selections._prioritize import Prioritize
+from dataeval.data.selections._reverse import Reverse
+from dataeval.data.selections._shuffle import Shuffle
dataeval/{utils/data → data}/selections/_classbalance.py
RENAMED
@@ -2,12 +2,11 @@ from __future__ import annotations

 __all__ = []

-
 import numpy as np

+from dataeval.data._selection import Select, Selection, SelectionStage
 from dataeval.typing import Array, ImageClassificationDatum
 from dataeval.utils._array import as_numpy
-from dataeval.utils.data._selection import Select, Selection, SelectionStage


 class ClassBalance(Selection[ImageClassificationDatum]):
dataeval/data/selections/_classfilter.py
ADDED
@@ -0,0 +1,109 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any, Generic, Iterable, Sequence, Sized, TypeVar, cast
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.data._selection import Select, Selection, SelectionStage, Subselection
+from dataeval.typing import Array, ObjectDetectionDatum, ObjectDetectionTarget, SegmentationDatum, SegmentationTarget
+from dataeval.utils._array import as_numpy
+from dataeval.utils.data.metadata import flatten
+
+
+class ClassFilter(Selection[Any]):
+    """
+    Filter the dataset by class.
+
+    Parameters
+    ----------
+    classes : Sequence[int]
+        The classes to filter by.
+    filter_detections : bool, default True
+        Whether to filter detections from targets for object detection and segmentation datasets.
+    """
+
+    stage = SelectionStage.FILTER
+
+    def __init__(self, classes: Sequence[int], filter_detections: bool = True) -> None:
+        self.classes = classes
+        self.filter_detections = filter_detections
+
+    def __call__(self, dataset: Select[Any]) -> None:
+        if not self.classes:
+            return
+
+        selection = []
+        subselection = set()
+        for idx in dataset._selection:
+            target = dataset._dataset[idx][1]
+            if isinstance(target, Array):
+                # Get the label for the image
+                label = int(np.argmax(as_numpy(target)))
+                # Check to see if the label is in the classes to filter for
+                if label in self.classes:
+                    # Include the image
+                    selection.append(idx)
+            elif isinstance(target, (ObjectDetectionTarget, SegmentationTarget)):
+                # Get the set of labels from the target
+                labels = set(target.labels if isinstance(target.labels, Iterable) else [target.labels])
+                # Check to see if any labels are in the classes to filter for
+                if labels.intersection(self.classes):
+                    # Include the image
+                    selection.append(idx)
+                    # If we are filtering out other labels and there are other labels, add a subselection filter
+                    if self.filter_detections and labels.difference(self.classes):
+                        subselection.add(idx)
+            else:
+                raise TypeError(f"ClassFilter does not support targets of type {type(target)}.")
+
+        dataset._selection = selection
+        dataset._subselections.append((ClassFilterSubSelection(self.classes), subselection))
+
+
+_T = TypeVar("_T")
+_TDatum = TypeVar("_TDatum", ObjectDetectionDatum, SegmentationDatum)
+_TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)
+
+
+def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
+    if isinstance(obj, Sized) and not isinstance(obj, (str, bytes, bytearray)) and len(obj) == len(mask):
+        if isinstance(obj, Array):
+            return obj[mask]
+        elif isinstance(obj, Sequence):
+            return cast(_T, [item for i, item in enumerate(obj) if mask[i]])
+    return obj
+
+
+class ClassFilterTarget(Generic[_TTarget]):
+    def __init__(self, target: _TTarget, mask: NDArray[np.bool_]) -> None:
+        self.__dict__.update(target.__dict__)
+        self._length = len(target.labels) if isinstance(target.labels, Sized) else int(bool(target.labels))
+        self._mask = mask
+        self._target = target
+
+    def __getattribute__(self, name: str) -> Any:
+        if name in ("_length", "_mask", "_target") or name.startswith("__") and name.endswith("__"):
+            return super().__getattribute__(name)
+
+        attr = getattr(self._target, name)
+        return _try_mask_object(attr, self._mask)
+
+
+class ClassFilterSubSelection(Subselection[Any]):
+    def __init__(self, classes: Sequence[int]) -> None:
+        self.classes = classes
+
+    def __call__(self, datum: _TDatum) -> _TDatum:
+        # build a mask for any arrays
+        image, target, metadata = datum
+
+        mask = np.isin(as_numpy(target.labels), self.classes)
+        flattened_metadata = flatten(metadata)[0]
+        filtered_metadata = {k: _try_mask_object(v, mask) for k, v in flattened_metadata.items()}
+
+        # return a masked datum
+        filtered_datum = image, ClassFilterTarget(target, mask), filtered_metadata
+        return cast(_TDatum, filtered_datum)
dataeval/{utils/data → data}/selections/_prioritize.py
RENAMED
@@ -14,8 +14,8 @@ from sklearn.cluster import KMeans
 from sklearn.metrics import pairwise_distances

 from dataeval.config import EPSILON, DeviceLike, get_seed
-from dataeval.utils.data import Embeddings, Select
-from dataeval.utils.data._selection import Selection, SelectionStage
+from dataeval.data import Embeddings, Select
+from dataeval.data._selection import Selection, SelectionStage

 _logger = logging.getLogger(__name__)

dataeval/{utils/data → data}/selections/_shuffle.py
RENAMED
@@ -8,9 +8,9 @@ import numpy as np
 from numpy.random import BitGenerator, Generator, SeedSequence
 from numpy.typing import NDArray

+from dataeval.data._selection import Select, Selection, SelectionStage
 from dataeval.typing import Array
 from dataeval.utils._array import as_numpy
-from dataeval.utils.data._selection import Select, Selection, SelectionStage


 class Shuffle(Selection[Any]):
dataeval/detectors/drift/_base.py
CHANGED
@@ -18,11 +18,11 @@ from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
 import numpy as np
 from numpy.typing import NDArray

+from dataeval.data import Embeddings
 from dataeval.outputs import DriftOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import Array
 from dataeval.utils._array import as_numpy, flatten
-from dataeval.utils.data import Embeddings

 R = TypeVar("R")

dataeval/detectors/drift/_cvm.py
CHANGED
@@ -16,9 +16,9 @@ import numpy as np
 from numpy.typing import NDArray
 from scipy.stats import cramervonmises_2samp

+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
 from dataeval.typing import Array
-from dataeval.utils.data._embeddings import Embeddings


 class DriftCVM(BaseDriftUnivariate):
@@ -52,7 +52,7 @@ class DriftCVM(BaseDriftUnivariate):

     Example
     -------
-    >>> from dataeval.utils.data import Embeddings
+    >>> from dataeval.data import Embeddings

     Use Embeddings to encode images before testing for drift

dataeval/detectors/drift/_ks.py
CHANGED
@@ -16,9 +16,9 @@ import numpy as np
 from numpy.typing import NDArray
 from scipy.stats import ks_2samp

+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
 from dataeval.typing import Array
-from dataeval.utils.data._embeddings import Embeddings


 class DriftKS(BaseDriftUnivariate):
@@ -54,7 +54,7 @@ class DriftKS(BaseDriftUnivariate):

     Example
     -------
-    >>> from dataeval.utils.data import Embeddings
+    >>> from dataeval.data import Embeddings

     Use Embeddings to encode images before testing for drift

dataeval/detectors/drift/_mmd.py
CHANGED
@@ -15,11 +15,11 @@ from typing import Any, Callable
 import torch

 from dataeval.config import DeviceLike, get_device
+from dataeval.data._embeddings import Embeddings
 from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, update_strategy
 from dataeval.outputs import DriftMMDOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.typing import Array
-from dataeval.utils.data._embeddings import Embeddings


 class DriftMMD(BaseDrift):
@@ -51,7 +51,7 @@ class DriftMMD(BaseDrift):

     Example
     -------
-    >>> from dataeval.utils.data import Embeddings
+    >>> from dataeval.data import Embeddings

     Use Embeddings to encode images before testing for drift

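The drift detectors now source Embeddings from dataeval.data, with no other behavioral change. A sketch of the updated workflow, assuming DriftMMD is exported from dataeval.detectors.drift and an encoder model is available (the constructor and predict signatures here are assumptions, not confirmed by this diff):

>>> from dataeval.data import Embeddings
>>> from dataeval.detectors.drift import DriftMMD
>>> ref = Embeddings(reference_dataset, batch_size=64, model=encoder, cache=True)
>>> drift = DriftMMD(ref)
>>> result = drift.predict(Embeddings(test_dataset, batch_size=64, model=encoder))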