dataeval 0.84.1__py3-none-any.whl → 0.86.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. dataeval/__init__.py +1 -1
  2. dataeval/data/__init__.py +19 -0
  3. dataeval/{utils/data → data}/_embeddings.py +137 -17
  4. dataeval/{utils/data → data}/_metadata.py +20 -8
  5. dataeval/{utils/data → data}/_selection.py +22 -9
  6. dataeval/{utils/data → data}/_split.py +1 -1
  7. dataeval/data/selections/__init__.py +19 -0
  8. dataeval/{utils/data → data}/selections/_classbalance.py +1 -2
  9. dataeval/data/selections/_classfilter.py +110 -0
  10. dataeval/{utils/data → data}/selections/_indices.py +1 -1
  11. dataeval/{utils/data → data}/selections/_limit.py +1 -1
  12. dataeval/{utils/data → data}/selections/_prioritize.py +2 -2
  13. dataeval/{utils/data → data}/selections/_reverse.py +1 -1
  14. dataeval/{utils/data → data}/selections/_shuffle.py +1 -1
  15. dataeval/detectors/drift/__init__.py +4 -1
  16. dataeval/detectors/drift/_base.py +1 -1
  17. dataeval/detectors/drift/_cvm.py +2 -2
  18. dataeval/detectors/drift/_ks.py +2 -2
  19. dataeval/detectors/drift/_mmd.py +2 -2
  20. dataeval/detectors/drift/_mvdc.py +92 -0
  21. dataeval/detectors/drift/_nml/__init__.py +6 -0
  22. dataeval/detectors/drift/_nml/_base.py +68 -0
  23. dataeval/detectors/drift/_nml/_chunk.py +404 -0
  24. dataeval/detectors/drift/_nml/_domainclassifier.py +192 -0
  25. dataeval/detectors/drift/_nml/_result.py +98 -0
  26. dataeval/detectors/drift/_nml/_thresholds.py +280 -0
  27. dataeval/detectors/linters/duplicates.py +1 -1
  28. dataeval/detectors/linters/outliers.py +1 -1
  29. dataeval/metadata/_distance.py +1 -1
  30. dataeval/metadata/_ood.py +4 -4
  31. dataeval/metrics/bias/_balance.py +1 -1
  32. dataeval/metrics/bias/_diversity.py +1 -1
  33. dataeval/metrics/bias/_parity.py +1 -1
  34. dataeval/metrics/stats/_labelstats.py +2 -2
  35. dataeval/outputs/__init__.py +2 -1
  36. dataeval/outputs/_bias.py +2 -4
  37. dataeval/outputs/_drift.py +68 -0
  38. dataeval/outputs/_linters.py +1 -6
  39. dataeval/outputs/_stats.py +1 -6
  40. dataeval/typing.py +31 -0
  41. dataeval/utils/__init__.py +2 -2
  42. dataeval/utils/data/__init__.py +5 -20
  43. dataeval/utils/data/collate.py +2 -0
  44. dataeval/utils/datasets/__init__.py +17 -0
  45. dataeval/utils/{data/datasets → datasets}/_base.py +3 -3
  46. dataeval/utils/{data/datasets → datasets}/_cifar10.py +2 -2
  47. dataeval/utils/{data/datasets → datasets}/_milco.py +2 -2
  48. dataeval/utils/{data/datasets → datasets}/_mnist.py +2 -2
  49. dataeval/utils/{data/datasets → datasets}/_ships.py +2 -2
  50. dataeval/utils/{data/datasets → datasets}/_voc.py +3 -3
  51. {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/METADATA +3 -2
  52. dataeval-0.86.0.dist-info/RECORD +114 -0
  53. dataeval/utils/data/datasets/__init__.py +0 -17
  54. dataeval/utils/data/selections/__init__.py +0 -19
  55. dataeval/utils/data/selections/_classfilter.py +0 -44
  56. dataeval-0.84.1.dist-info/RECORD +0 -106
  57. /dataeval/{utils/data → data}/_images.py +0 -0
  58. /dataeval/{utils/data → data}/_targets.py +0 -0
  59. /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
  60. /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
  61. /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
  62. /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
  63. {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/LICENSE.txt +0 -0
  64. {dataeval-0.84.1.dist-info → dataeval-0.86.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
8
8
  from __future__ import annotations
9
9
 
10
10
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
- __version__ = "0.84.1"
11
+ __version__ = "0.86.0"
12
12
 
13
13
  import logging
14
14
 
@@ -0,0 +1,19 @@
1
+ """Provides utility functions for interacting with Computer Vision datasets."""
2
+
3
+ __all__ = [
4
+ "Embeddings",
5
+ "Images",
6
+ "Metadata",
7
+ "Select",
8
+ "SplitDatasetOutput",
9
+ "Targets",
10
+ "split_dataset",
11
+ ]
12
+
13
+ from dataeval.data._embeddings import Embeddings
14
+ from dataeval.data._images import Images
15
+ from dataeval.data._metadata import Metadata
16
+ from dataeval.data._selection import Select
17
+ from dataeval.data._split import split_dataset
18
+ from dataeval.data._targets import Targets
19
+ from dataeval.outputs._utils import SplitDatasetOutput
@@ -2,19 +2,25 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ import logging
5
6
  import math
7
+ import os
8
+ from pathlib import Path
6
9
  from typing import Any, Iterator, Sequence, cast
7
10
 
8
11
  import torch
12
+ import xxhash as xxh
9
13
  from numpy.typing import NDArray
10
14
  from torch.utils.data import DataLoader, Subset
11
15
  from tqdm import tqdm
12
16
 
13
17
  from dataeval.config import DeviceLike, get_device
14
- from dataeval.typing import Array, ArrayLike, Dataset, Transform
18
+ from dataeval.typing import AnnotatedDataset, AnnotatedModel, Array, ArrayLike, Dataset, Transform
15
19
  from dataeval.utils._array import as_numpy
16
20
  from dataeval.utils.torch.models import SupportsEncode
17
21
 
22
+ _logger = logging.getLogger(__name__)
23
+
18
24
 
19
25
  class Embeddings:
20
26
  """
@@ -35,10 +41,23 @@ class Embeddings:
35
41
  device : DeviceLike or None, default None
36
42
  The hardware device to use if specified, otherwise uses the DataEval
37
43
  default or torch default.
38
- cache : bool, default False
39
- Whether to cache the embeddings in memory.
44
+ cache : Path, str, or bool, default False
45
+ Whether to cache the embeddings to a file or in memory.
46
+ When a Path or string is provided, embeddings will be cached to disk.
40
47
  verbose : bool, default False
41
48
  Whether to print progress bar when encoding images.
49
+
50
+ Attributes
51
+ ----------
52
+ batch_size : int
53
+ Batch size to use when encoding images.
54
+ cache : Path or bool
55
+ The path to cache embeddings to file, or True if caching to memory.
56
+ device : torch.device
57
+ The hardware device to use if specified, otherwise uses the DataEval
58
+ default or torch default.
59
+ verbose : bool
60
+ Whether to print progress bar when encoding images.
42
61
  """
43
62
 
44
63
  device: torch.device
@@ -52,24 +71,59 @@ class Embeddings:
52
71
  transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
53
72
  model: torch.nn.Module | None = None,
54
73
  device: DeviceLike | None = None,
55
- cache: bool = False,
74
+ cache: Path | str | bool = False,
56
75
  verbose: bool = False,
57
76
  ) -> None:
58
77
  self.device = get_device(device)
59
- self.cache = cache
60
78
  self.batch_size = batch_size if batch_size > 0 else 1
61
79
  self.verbose = verbose
62
80
 
81
+ self._embeddings_only: bool = False
63
82
  self._dataset = dataset
64
- self._length = len(dataset)
65
83
  model = torch.nn.Flatten() if model is None else model
66
84
  self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
67
85
  self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
68
86
  self._encoder = model.encode if isinstance(model, SupportsEncode) else model
69
87
  self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
70
- self._cached_idx = set()
88
+ self._cached_idx: set[int] = set()
71
89
  self._embeddings: torch.Tensor = torch.empty(())
72
- self._shallow: bool = False
90
+
91
+ self._cache = cache if isinstance(cache, bool) else self._resolve_path(cache)
92
+
93
+ def __hash__(self) -> int:
94
+ if self._embeddings_only:
95
+ bid = as_numpy(self._embeddings).ravel().tobytes()
96
+ else:
97
+ did = self._dataset.metadata["id"] if isinstance(self._dataset, AnnotatedDataset) else str(self._dataset)
98
+ mid = self._model.metadata["id"] if isinstance(self._model, AnnotatedModel) else str(self._model)
99
+ tid = str.join("|", [str(t) for t in self._transforms or []])
100
+ bid = f"{did}{mid}{tid}".encode()
101
+
102
+ return int(xxh.xxh3_64_hexdigest(bid), 16)
103
+
104
+ @property
105
+ def cache(self) -> Path | bool:
106
+ return self._cache
107
+
108
+ @cache.setter
109
+ def cache(self, value: Path | str | bool) -> None:
110
+ if isinstance(value, bool) and not value:
111
+ self._cached_idx = set()
112
+ self._embeddings = torch.empty(())
113
+ elif isinstance(value, (Path, str)):
114
+ value = self._resolve_path(value)
115
+
116
+ if isinstance(value, Path) and value != getattr(self, "_cache", None):
117
+ self._save(value)
118
+
119
+ self._cache = value
120
+
121
+ def _resolve_path(self, path: Path | str) -> Path:
122
+ if isinstance(path, str):
123
+ path = Path(os.path.abspath(path))
124
+ if isinstance(path, Path) and (path.is_dir() or not path.suffix):
125
+ path = path / f"emb-{hash(self)}.pt"
126
+ return path
73
127
 
74
128
  def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
75
129
  """
@@ -125,8 +179,10 @@ class Embeddings:
125
179
  -------
126
180
  Embeddings
127
181
  """
182
+ if self._embeddings_only:
183
+ raise ValueError("Embeddings object does not have a model.")
128
184
  return Embeddings(
129
- dataset, self.batch_size, self._transforms, self._model, self.device, self.cache, self.verbose
185
+ dataset, self.batch_size, self._transforms, self._model, self.device, bool(self.cache), self.verbose
130
186
  )
131
187
 
132
188
  @classmethod
@@ -149,7 +205,7 @@ class Embeddings:
149
205
  Example
150
206
  -------
151
207
  >>> import numpy as np
152
- >>> from dataeval.utils.data._embeddings import Embeddings
208
+ >>> from dataeval.data import Embeddings
153
209
  >>> array = np.random.randn(100, 3, 224, 224)
154
210
  >>> embeddings = Embeddings.from_array(array)
155
211
  >>> print(embeddings.to_tensor().shape)
@@ -157,12 +213,70 @@ class Embeddings:
157
213
  """
158
214
  embeddings = Embeddings([], 0, None, None, device, True, False)
159
215
  array = array if isinstance(array, Array) else as_numpy(array)
160
- embeddings._length = len(array)
161
216
  embeddings._cached_idx = set(range(len(array)))
162
217
  embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
163
- embeddings._shallow = True
218
+ embeddings._embeddings_only = True
164
219
  return embeddings
165
220
 
221
+ def save(self, path: Path | str) -> None:
222
+ """
223
+ Saves the embeddings to disk.
224
+
225
+ Parameters
226
+ ----------
227
+ path : Path or str
228
+ The file path to save the embeddings to.
229
+ """
230
+ self._save(self._resolve_path(path), True)
231
+
232
+ def _save(self, path: Path, force: bool = False) -> None:
233
+ path.parent.mkdir(parents=True, exist_ok=True)
234
+
235
+ if self._embeddings_only or self.cache and not force:
236
+ embeddings = self._embeddings
237
+ cached_idx = self._cached_idx
238
+ else:
239
+ embeddings = self.to_tensor()
240
+ cached_idx = list(range(len(self)))
241
+ try:
242
+ cache_data = {
243
+ "embeddings": embeddings,
244
+ "cached_indices": cached_idx,
245
+ "device": self.device,
246
+ }
247
+ torch.save(cache_data, path)
248
+ _logger.log(logging.DEBUG, f"Saved embeddings cache from {path}")
249
+ except Exception as e:
250
+ _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
251
+
252
+ @classmethod
253
+ def load(cls, path: Path | str) -> Embeddings:
254
+ """
255
+ Loads the embeddings from disk.
256
+
257
+ Parameters
258
+ ----------
259
+ path : Path or str
260
+ The file path to load the embeddings from.
261
+ """
262
+ emb = Embeddings([], 0)
263
+ path = Path(os.path.abspath(path)) if isinstance(path, str) else path
264
+ if path.exists() and path.is_file():
265
+ try:
266
+ cache_data = torch.load(path, weights_only=False)
267
+ emb._embeddings_only = True
268
+ emb._embeddings = cache_data["embeddings"]
269
+ emb._cached_idx = cache_data["cached_indices"]
270
+ emb.device = cache_data["device"]
271
+ _logger.log(logging.DEBUG, f"Loaded embeddings cache from {path}")
272
+ except Exception as e:
273
+ _logger.log(logging.ERROR, f"Failed to load embeddings cache: {e}")
274
+ raise e
275
+ else:
276
+ raise FileNotFoundError(f"Specified cache file {path} was not found.")
277
+
278
+ return emb
279
+
166
280
  def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
167
281
  if self._transforms:
168
282
  images = [transform(image) for transform in self._transforms for image in images]
@@ -195,31 +309,37 @@ class Embeddings:
195
309
  embeddings = self._encode(images)
196
310
 
197
311
  if not self._embeddings.shape:
198
- full_shape = (len(self._dataset), *embeddings.shape[1:])
312
+ full_shape = (len(self), *embeddings.shape[1:])
199
313
  self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
200
314
 
201
315
  self._embeddings[uncached] = embeddings
202
316
  self._cached_idx.update(uncached)
203
317
 
318
+ if isinstance(self.cache, Path):
319
+ self._save(self.cache)
320
+
204
321
  yield self._embeddings[batch]
205
322
 
206
323
  def __getitem__(self, key: int | slice, /) -> torch.Tensor:
207
324
  if not isinstance(key, slice) and not hasattr(key, "__int__"):
208
325
  raise TypeError("Invalid argument type.")
209
326
 
210
- if self._shallow:
327
+ indices = list(range(len(self))[key]) if isinstance(key, slice) else [int(key)]
328
+
329
+ if self._embeddings_only:
211
330
  if not self._embeddings.shape:
212
331
  raise ValueError("Embeddings not initialized.")
332
+ if not set(indices).issubset(self._cached_idx):
333
+ raise ValueError("Unable to generate new embeddings from a shallow instance.")
213
334
  return self._embeddings[key]
214
335
 
215
- indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
216
336
  result = torch.vstack(list(self._batch(indices))).to(self.device)
217
337
  return result.squeeze(0) if len(indices) == 1 else result
218
338
 
219
339
  def __iter__(self) -> Iterator[torch.Tensor]:
220
340
  # process in batches while yielding individual embeddings
221
- for batch in self._batch(range(self._length)):
341
+ for batch in self._batch(range(len(self))):
222
342
  yield from batch
223
343
 
224
344
  def __len__(self) -> int:
225
- return self._length
345
+ return len(self._embeddings) if self._embeddings_only else len(self._dataset)
@@ -16,12 +16,12 @@ from dataeval.typing import (
16
16
  )
17
17
  from dataeval.utils._array import as_numpy, to_numpy
18
18
  from dataeval.utils._bin import bin_data, digitize_data, is_continuous
19
- from dataeval.utils.metadata import merge
19
+ from dataeval.utils.data.metadata import merge
20
20
 
21
21
  if TYPE_CHECKING:
22
- from dataeval.utils.data import Targets
22
+ from dataeval.data import Targets
23
23
  else:
24
- from dataeval.utils.data._targets import Targets
24
+ from dataeval.data._targets import Targets
25
25
 
26
26
 
27
27
  class Metadata:
@@ -191,6 +191,11 @@ class Metadata:
191
191
  self._process()
192
192
  return self._image_indices
193
193
 
194
+ @property
195
+ def image_count(self) -> int:
196
+ self._process()
197
+ return int(self._image_indices.max() + 1)
198
+
194
199
  def _collate(self, force: bool = False):
195
200
  if self._collated and not force:
196
201
  return
@@ -359,12 +364,19 @@ class Metadata:
359
364
 
360
365
  def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
361
366
  self._merge()
362
- self._processed = False
363
- target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
364
- if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
367
+
368
+ targets = len(self.targets.source) if self.targets.source is not None else len(self.targets)
369
+ images = self.image_count
370
+ lengths = {k: len(v if isinstance(v, Sized) else np.atleast_1d(as_numpy(v))) for k, v in factors.items()}
371
+ targets_match = all(f == targets for f in lengths.values())
372
+ images_match = targets_match if images == targets else all(f == images for f in lengths.values())
373
+ if not targets_match and not images_match:
365
374
  raise ValueError(
366
375
  "The lists/arrays in the provided factors have a different length than the current metadata factors."
367
376
  )
368
- merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
377
+ merged = cast(dict[str, ArrayLike], self._merged[0] if self._merged is not None else {})
369
378
  for k, v in factors.items():
370
- merged[k] = v
379
+ v = as_numpy(v)
380
+ merged[k] = v if (self.targets.source is None or lengths[k] == targets) else v[self.targets.source]
381
+
382
+ self._processed = False
@@ -25,6 +25,10 @@ class Selection(Generic[_TDatum]):
25
25
  return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
26
26
 
27
27
 
28
+ class Subselection(Generic[_TDatum]):
29
+ def __call__(self, original: _TDatum) -> _TDatum: ...
30
+
31
+
28
32
  class Select(AnnotatedDataset[_TDatum]):
29
33
  """
30
34
  Wraps a dataset and applies selection criteria to it.
@@ -38,7 +42,7 @@ class Select(AnnotatedDataset[_TDatum]):
38
42
 
39
43
  Examples
40
44
  --------
41
- >>> from dataeval.utils.data.selections import ClassFilter, Limit
45
+ >>> from dataeval.data.selections import ClassFilter, Limit
42
46
 
43
47
  >>> # Construct a sample dataset with size of 100 and class count of 10
44
48
  >>> # Elements at index `idx` are returned as tuples:
@@ -63,6 +67,7 @@ class Select(AnnotatedDataset[_TDatum]):
63
67
  _selection: list[int]
64
68
  _selections: Sequence[Selection[_TDatum]]
65
69
  _size_limit: int
70
+ _subselections: list[tuple[Subselection[_TDatum], set[int]]]
66
71
 
67
72
  def __init__(
68
73
  self,
@@ -73,7 +78,8 @@ class Select(AnnotatedDataset[_TDatum]):
73
78
  self._dataset = dataset
74
79
  self._size_limit = len(dataset)
75
80
  self._selection = list(range(self._size_limit))
76
- self._selections = self._sort(selections)
81
+ self._selections = self._sort_selections(selections)
82
+ self._subselections = []
77
83
 
78
84
  # Ensure metadata is populated correctly as DatasetMetadata TypedDict
79
85
  _metadata = getattr(dataset, "metadata", {})
@@ -81,7 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
81
87
  _metadata["id"] = dataset.__class__.__name__
82
88
  self._metadata = DatasetMetadata(**_metadata)
83
89
 
84
- self._select()
90
+ self._apply_selections()
85
91
 
86
92
  @property
87
93
  def metadata(self) -> DatasetMetadata:
@@ -94,24 +100,31 @@ class Select(AnnotatedDataset[_TDatum]):
94
100
  selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
95
101
  return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
96
102
 
97
- def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
103
+ def _sort_selections(
104
+ self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None
105
+ ) -> list[Selection[_TDatum]]:
98
106
  if not selections:
99
107
  return []
100
108
 
101
- selections = [selections] if isinstance(selections, Selection) else selections
102
- grouped: dict[int, list[Selection]] = {}
103
- for selection in selections:
109
+ selections_list = [selections] if isinstance(selections, Selection) else list(selections)
110
+ grouped: dict[int, list[Selection[_TDatum]]] = {}
111
+ for selection in selections_list:
104
112
  grouped.setdefault(selection.stage, []).append(selection)
105
113
  selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
106
114
  return selection_list
107
115
 
108
- def _select(self) -> None:
116
+ def _apply_selections(self) -> None:
109
117
  for selection in self._selections:
110
118
  selection(self)
111
119
  self._selection = self._selection[: self._size_limit]
112
120
 
121
+ def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
122
+ for subselection, indices in self._subselections:
123
+ datum = subselection(datum) if self._selection[index] in indices else datum
124
+ return datum
125
+
113
126
  def __getitem__(self, index: int) -> _TDatum:
114
- return self._dataset[self._selection[index]]
127
+ return self._apply_subselection(self._dataset[self._selection[index]], index)
115
128
 
116
129
  def __iter__(self) -> Iterator[_TDatum]:
117
130
  for i in range(len(self)):
@@ -12,10 +12,10 @@ from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, Str
12
12
  from sklearn.utils.multiclass import type_of_target
13
13
 
14
14
  from dataeval.config import EPSILON
15
+ from dataeval.data._metadata import Metadata
15
16
  from dataeval.outputs._base import set_metadata
16
17
  from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
17
18
  from dataeval.typing import AnnotatedDataset
18
- from dataeval.utils.data._metadata import Metadata
19
19
 
20
20
  _logger = logging.getLogger(__name__)
21
21
 
@@ -0,0 +1,19 @@
1
+ """Provides selection classes for selecting subsets of Computer Vision datasets."""
2
+
3
+ __all__ = [
4
+ "ClassBalance",
5
+ "ClassFilter",
6
+ "Indices",
7
+ "Limit",
8
+ "Prioritize",
9
+ "Reverse",
10
+ "Shuffle",
11
+ ]
12
+
13
+ from dataeval.data.selections._classbalance import ClassBalance
14
+ from dataeval.data.selections._classfilter import ClassFilter
15
+ from dataeval.data.selections._indices import Indices
16
+ from dataeval.data.selections._limit import Limit
17
+ from dataeval.data.selections._prioritize import Prioritize
18
+ from dataeval.data.selections._reverse import Reverse
19
+ from dataeval.data.selections._shuffle import Shuffle
@@ -2,12 +2,11 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
-
6
5
  import numpy as np
7
6
 
7
+ from dataeval.data._selection import Select, Selection, SelectionStage
8
8
  from dataeval.typing import Array, ImageClassificationDatum
9
9
  from dataeval.utils._array import as_numpy
10
- from dataeval.utils.data._selection import Select, Selection, SelectionStage
11
10
 
12
11
 
13
12
  class ClassBalance(Selection[ImageClassificationDatum]):
@@ -0,0 +1,110 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import Any, Generic, Iterable, Sequence, Sized, TypeVar, cast
6
+
7
+ import numpy as np
8
+ from numpy.typing import NDArray
9
+
10
+ from dataeval.data._selection import Select, Selection, SelectionStage, Subselection
11
+ from dataeval.typing import Array, ObjectDetectionDatum, ObjectDetectionTarget, SegmentationDatum, SegmentationTarget
12
+ from dataeval.utils._array import as_numpy
13
+
14
+
15
+ class ClassFilter(Selection[Any]):
16
+ """
17
+ Filter the dataset by class.
18
+
19
+ Parameters
20
+ ----------
21
+ classes : Sequence[int]
22
+ The classes to filter by.
23
+ filter_detections : bool, default True
24
+ Whether to filter detections from targets for object detection and segmentation datasets.
25
+ """
26
+
27
+ stage = SelectionStage.FILTER
28
+
29
+ def __init__(self, classes: Sequence[int], filter_detections: bool = True) -> None:
30
+ self.classes = classes
31
+ self.filter_detections = filter_detections
32
+
33
+ def __call__(self, dataset: Select[Any]) -> None:
34
+ if not self.classes:
35
+ return
36
+
37
+ selection = []
38
+ subselection = set()
39
+ for idx in dataset._selection:
40
+ target = dataset._dataset[idx][1]
41
+ if isinstance(target, Array):
42
+ # Get the label for the image
43
+ label = int(np.argmax(as_numpy(target)))
44
+ # Check to see if the label is in the classes to filter for
45
+ if label in self.classes:
46
+ # Include the image
47
+ selection.append(idx)
48
+ elif isinstance(target, (ObjectDetectionTarget, SegmentationTarget)):
49
+ # Get the set of labels from the target
50
+ labels = set(target.labels if isinstance(target.labels, Iterable) else [target.labels])
51
+ # Check to see if any labels are in the classes to filter for
52
+ if labels.intersection(self.classes):
53
+ # Include the image
54
+ selection.append(idx)
55
+ # If we are filtering out other labels and there are other labels, add a subselection filter
56
+ if self.filter_detections and labels.difference(self.classes):
57
+ subselection.add(idx)
58
+ else:
59
+ raise TypeError(f"ClassFilter does not support targets of type {type(target)}.")
60
+
61
+ dataset._selection = selection
62
+ dataset._subselections.append((ClassFilterSubSelection(self.classes), subselection))
63
+
64
+
65
+ _T = TypeVar("_T")
66
+ _TDatum = TypeVar("_TDatum", ObjectDetectionDatum, SegmentationDatum)
67
+ _TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)
68
+
69
+
70
+ def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
71
+ if isinstance(obj, Sized) and not isinstance(obj, (str, bytes, bytearray)) and len(obj) == len(mask):
72
+ if isinstance(obj, Array):
73
+ return obj[mask]
74
+ elif isinstance(obj, Sequence):
75
+ return cast(_T, [item for i, item in enumerate(obj) if mask[i]])
76
+ return obj
77
+
78
+
79
+ class ClassFilterTarget(Generic[_TTarget]):
80
+ def __init__(self, target: _TTarget, mask: NDArray[np.bool_]) -> None:
81
+ self.__dict__.update(target.__dict__)
82
+ self._length = len(target.labels) if isinstance(target.labels, Sized) else int(bool(target.labels))
83
+ self._mask = mask
84
+ self._target = target
85
+
86
+ def __getattribute__(self, name: str) -> Any:
87
+ if name in ("_length", "_mask", "_target") or name.startswith("__") and name.endswith("__"):
88
+ return super().__getattribute__(name)
89
+
90
+ attr = getattr(self._target, name)
91
+ return _try_mask_object(attr, self._mask)
92
+
93
+
94
+ class ClassFilterSubSelection(Subselection[Any]):
95
+ def __init__(self, classes: Sequence[int]) -> None:
96
+ self.classes = classes
97
+
98
+ def _filter(self, d: dict[str, Any], mask: NDArray[np.bool_]) -> dict[str, Any]:
99
+ return {k: self._filter(v, mask) if isinstance(v, dict) else _try_mask_object(v, mask) for k, v in d.items()}
100
+
101
+ def __call__(self, datum: _TDatum) -> _TDatum:
102
+ # build a mask for any arrays
103
+ image, target, metadata = datum
104
+
105
+ mask = np.isin(as_numpy(target.labels), self.classes)
106
+ filtered_metadata = self._filter(metadata, mask)
107
+
108
+ # return a masked datum
109
+ filtered_datum = image, ClassFilterTarget(target, mask), filtered_metadata
110
+ return cast(_TDatum, filtered_datum)
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  from typing import Any, Sequence
6
6
 
7
- from dataeval.utils.data._selection import Select, Selection, SelectionStage
7
+ from dataeval.data._selection import Select, Selection, SelectionStage
8
8
 
9
9
 
10
10
  class Indices(Selection[Any]):
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  from typing import Any
6
6
 
7
- from dataeval.utils.data._selection import Select, Selection, SelectionStage
7
+ from dataeval.data._selection import Select, Selection, SelectionStage
8
8
 
9
9
 
10
10
  class Limit(Selection[Any]):
@@ -14,8 +14,8 @@ from sklearn.cluster import KMeans
14
14
  from sklearn.metrics import pairwise_distances
15
15
 
16
16
  from dataeval.config import EPSILON, DeviceLike, get_seed
17
- from dataeval.utils.data import Embeddings, Select
18
- from dataeval.utils.data._selection import Selection, SelectionStage
17
+ from dataeval.data import Embeddings, Select
18
+ from dataeval.data._selection import Selection, SelectionStage
19
19
 
20
20
  _logger = logging.getLogger(__name__)
21
21
 
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  from typing import Any
6
6
 
7
- from dataeval.utils.data._selection import Select, Selection, SelectionStage
7
+ from dataeval.data._selection import Select, Selection, SelectionStage
8
8
 
9
9
 
10
10
  class Reverse(Selection[Any]):
@@ -8,9 +8,9 @@ import numpy as np
8
8
  from numpy.random import BitGenerator, Generator, SeedSequence
9
9
  from numpy.typing import NDArray
10
10
 
11
+ from dataeval.data._selection import Select, Selection, SelectionStage
11
12
  from dataeval.typing import Array
12
13
  from dataeval.utils._array import as_numpy
13
- from dataeval.utils.data._selection import Select, Selection, SelectionStage
14
14
 
15
15
 
16
16
  class Shuffle(Selection[Any]):
@@ -7,6 +7,8 @@ __all__ = [
7
7
  "DriftKS",
8
8
  "DriftMMD",
9
9
  "DriftMMDOutput",
10
+ "DriftMVDC",
11
+ "DriftMVDCOutput",
10
12
  "DriftOutput",
11
13
  "DriftUncertainty",
12
14
  "UpdateStrategy",
@@ -18,5 +20,6 @@ from dataeval.detectors.drift._base import UpdateStrategy
18
20
  from dataeval.detectors.drift._cvm import DriftCVM
19
21
  from dataeval.detectors.drift._ks import DriftKS
20
22
  from dataeval.detectors.drift._mmd import DriftMMD
23
+ from dataeval.detectors.drift._mvdc import DriftMVDC
21
24
  from dataeval.detectors.drift._uncertainty import DriftUncertainty
22
- from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
25
+ from dataeval.outputs._drift import DriftMMDOutput, DriftMVDCOutput, DriftOutput
@@ -18,11 +18,11 @@ from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
18
18
  import numpy as np
19
19
  from numpy.typing import NDArray
20
20
 
21
+ from dataeval.data import Embeddings
21
22
  from dataeval.outputs import DriftOutput
22
23
  from dataeval.outputs._base import set_metadata
23
24
  from dataeval.typing import Array
24
25
  from dataeval.utils._array import as_numpy, flatten
25
- from dataeval.utils.data import Embeddings
26
26
 
27
27
  R = TypeVar("R")
28
28