dataeval 0.84.0__py3-none-any.whl → 0.85.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. dataeval/__init__.py +1 -1
  2. dataeval/data/__init__.py +19 -0
  3. dataeval/data/_embeddings.py +345 -0
  4. dataeval/{utils/data → data}/_images.py +2 -2
  5. dataeval/{utils/data → data}/_metadata.py +8 -7
  6. dataeval/{utils/data → data}/_selection.py +22 -9
  7. dataeval/{utils/data → data}/_split.py +1 -1
  8. dataeval/data/selections/__init__.py +19 -0
  9. dataeval/data/selections/_classbalance.py +37 -0
  10. dataeval/data/selections/_classfilter.py +109 -0
  11. dataeval/{utils/data → data}/selections/_indices.py +1 -1
  12. dataeval/{utils/data → data}/selections/_limit.py +1 -1
  13. dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
  14. dataeval/{utils/data → data}/selections/_reverse.py +1 -1
  15. dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
  16. dataeval/detectors/drift/__init__.py +2 -2
  17. dataeval/detectors/drift/_base.py +55 -203
  18. dataeval/detectors/drift/_cvm.py +19 -30
  19. dataeval/detectors/drift/_ks.py +18 -30
  20. dataeval/detectors/drift/_mmd.py +189 -53
  21. dataeval/detectors/drift/_uncertainty.py +52 -56
  22. dataeval/detectors/drift/updates.py +13 -12
  23. dataeval/detectors/linters/duplicates.py +6 -4
  24. dataeval/detectors/linters/outliers.py +3 -3
  25. dataeval/detectors/ood/ae.py +1 -1
  26. dataeval/metadata/_distance.py +1 -1
  27. dataeval/metadata/_ood.py +4 -4
  28. dataeval/metrics/bias/_balance.py +1 -1
  29. dataeval/metrics/bias/_diversity.py +1 -1
  30. dataeval/metrics/bias/_parity.py +1 -1
  31. dataeval/metrics/stats/_base.py +7 -7
  32. dataeval/metrics/stats/_dimensionstats.py +2 -2
  33. dataeval/metrics/stats/_hashstats.py +2 -2
  34. dataeval/metrics/stats/_imagestats.py +4 -4
  35. dataeval/metrics/stats/_labelstats.py +2 -2
  36. dataeval/metrics/stats/_pixelstats.py +2 -2
  37. dataeval/metrics/stats/_visualstats.py +2 -2
  38. dataeval/outputs/_bias.py +1 -1
  39. dataeval/typing.py +53 -19
  40. dataeval/utils/__init__.py +2 -2
  41. dataeval/utils/_array.py +18 -7
  42. dataeval/utils/data/__init__.py +5 -20
  43. dataeval/utils/data/_dataset.py +6 -4
  44. dataeval/utils/data/collate.py +2 -0
  45. dataeval/utils/datasets/__init__.py +17 -0
  46. dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
  47. dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
  48. dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
  49. dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
  50. dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
  51. dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
  52. dataeval/utils/torch/_internal.py +12 -35
  53. {dataeval-0.84.0.dist-info → dataeval-0.85.0.dist-info}/METADATA +2 -3
  54. dataeval-0.85.0.dist-info/RECORD +107 -0
  55. dataeval/detectors/drift/_torch.py +0 -222
  56. dataeval/utils/data/_embeddings.py +0 -186
  57. dataeval/utils/data/datasets/__init__.py +0 -17
  58. dataeval/utils/data/selections/__init__.py +0 -17
  59. dataeval/utils/data/selections/_classfilter.py +0 -59
  60. dataeval-0.84.0.dist-info/RECORD +0 -106
  61. /dataeval/{utils/data → data}/_targets.py +0 -0
  62. /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
  63. /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
  64. /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
  65. /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
  66. {dataeval-0.84.0.dist-info → dataeval-0.85.0.dist-info}/LICENSE.txt +0 -0
  67. {dataeval-0.84.0.dist-info → dataeval-0.85.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations
 
 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.84.0"
+__version__ = "0.85.0"
 
 import logging
 
dataeval/data/__init__.py ADDED
@@ -0,0 +1,19 @@
+"""Provides utility functions for interacting with Computer Vision datasets."""
+
+__all__ = [
+    "Embeddings",
+    "Images",
+    "Metadata",
+    "Select",
+    "SplitDatasetOutput",
+    "Targets",
+    "split_dataset",
+]
+
+from dataeval.data._embeddings import Embeddings
+from dataeval.data._images import Images
+from dataeval.data._metadata import Metadata
+from dataeval.data._selection import Select
+from dataeval.data._split import split_dataset
+from dataeval.data._targets import Targets
+from dataeval.outputs._utils import SplitDatasetOutput
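
For users, the practical effect of this new module is an import-path change: the dataset utilities formerly under dataeval.utils.data are now exposed from dataeval.data, as the renamed files below confirm. A quick sketch of the new surface (the old path is shown for comparison only):

# New in 0.85.0:
from dataeval.data import Embeddings, Images, Metadata, Select, Targets, split_dataset
# Previously in 0.84.0:
# from dataeval.utils.data import Embeddings, Images, Metadata, Select, Targets, split_dataset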
dataeval/data/_embeddings.py ADDED
@@ -0,0 +1,345 @@
+from __future__ import annotations
+
+__all__ = []
+
+import logging
+import math
+import os
+from pathlib import Path
+from typing import Any, Iterator, Sequence, cast
+
+import torch
+import xxhash as xxh
+from numpy.typing import NDArray
+from torch.utils.data import DataLoader, Subset
+from tqdm import tqdm
+
+from dataeval.config import DeviceLike, get_device
+from dataeval.typing import AnnotatedDataset, AnnotatedModel, Array, ArrayLike, Dataset, Transform
+from dataeval.utils._array import as_numpy
+from dataeval.utils.torch.models import SupportsEncode
+
+_logger = logging.getLogger(__name__)
+
+
+class Embeddings:
+    """
+    Collection of image embeddings from a dataset.
+
+    Embeddings are accessed by index or slice and are only loaded on-demand.
+
+    Parameters
+    ----------
+    dataset : ImageClassificationDataset or ObjectDetectionDataset
+        Dataset to access original images from.
+    batch_size : int
+        Batch size to use when encoding images.
+    transforms : Transform or Sequence[Transform] or None, default None
+        Transforms to apply to images before encoding.
+    model : torch.nn.Module or None, default None
+        Model to use for encoding images.
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    cache : Path, str, or bool, default False
+        Whether to cache the embeddings to a file or in memory.
+        When a Path or string is provided, embeddings will be cached to disk.
+    verbose : bool, default False
+        Whether to print progress bar when encoding images.
+
+    Attributes
+    ----------
+    batch_size : int
+        Batch size to use when encoding images.
+    cache : Path or bool
+        The path to cache embeddings to file, or True if caching to memory.
+    device : torch.device
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool
+        Whether to print progress bar when encoding images.
+    """
+
+    device: torch.device
+    batch_size: int
+    verbose: bool
+
+    def __init__(
+        self,
+        dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike],
+        batch_size: int,
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
+        model: torch.nn.Module | None = None,
+        device: DeviceLike | None = None,
+        cache: Path | str | bool = False,
+        verbose: bool = False,
+    ) -> None:
+        self.device = get_device(device)
+        self.batch_size = batch_size if batch_size > 0 else 1
+        self.verbose = verbose
+
+        self._embeddings_only: bool = False
+        self._dataset = dataset
+        model = torch.nn.Flatten() if model is None else model
+        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
+        self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
+        self._encoder = model.encode if isinstance(model, SupportsEncode) else model
+        self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
+        self._cached_idx: set[int] = set()
+        self._embeddings: torch.Tensor = torch.empty(())
+
+        self._cache = cache if isinstance(cache, bool) else self._resolve_path(cache)
+
+    def __hash__(self) -> int:
+        if self._embeddings_only:
+            bid = as_numpy(self._embeddings).ravel().tobytes()
+        else:
+            did = self._dataset.metadata["id"] if isinstance(self._dataset, AnnotatedDataset) else str(self._dataset)
+            mid = self._model.metadata["id"] if isinstance(self._model, AnnotatedModel) else str(self._model)
+            tid = str.join("|", [str(t) for t in self._transforms or []])
+            bid = f"{did}{mid}{tid}".encode()
+
+        return int(xxh.xxh3_64_hexdigest(bid), 16)
+
+    @property
+    def cache(self) -> Path | bool:
+        return self._cache
+
+    @cache.setter
+    def cache(self, value: Path | str | bool) -> None:
+        if isinstance(value, bool) and not value:
+            self._cached_idx = set()
+            self._embeddings = torch.empty(())
+        elif isinstance(value, (Path, str)):
+            value = self._resolve_path(value)
+
+        if isinstance(value, Path) and value != getattr(self, "_cache", None):
+            self._save(value)
+
+        self._cache = value
+
+    def _resolve_path(self, path: Path | str) -> Path:
+        if isinstance(path, str):
+            path = Path(os.path.abspath(path))
+        if isinstance(path, Path) and (path.is_dir() or not path.suffix):
+            path = path / f"emb-{hash(self)}.pt"
+        return path
+
+    def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
+        """
+        Converts dataset to embeddings.
+
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
+
+        Returns
+        -------
+        torch.Tensor
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        if indices is not None:
+            return torch.vstack(list(self._batch(indices))).to(self.device)
+        else:
+            return self[:]
+
+    def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
+        """
+        Converts dataset to embeddings as numpy array.
+
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
+
+        Returns
+        -------
+        NDArray[Any]
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        return self.to_tensor(indices).cpu().numpy()
+
+    def new(self, dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike]) -> Embeddings:
+        """
+        Creates a new Embeddings object with the same parameters but a different dataset.
+
+        Parameters
+        ----------
+        dataset : ImageClassificationDataset or ObjectDetectionDataset
+            Dataset to access original images from.
+
+        Returns
+        -------
+        Embeddings
+        """
+        if self._embeddings_only:
+            raise ValueError("Embeddings object does not have a model.")
+        return Embeddings(
+            dataset, self.batch_size, self._transforms, self._model, self.device, bool(self.cache), self.verbose
+        )
+
+    @classmethod
+    def from_array(cls, array: ArrayLike, device: DeviceLike | None = None) -> Embeddings:
+        """
+        Instantiates a shallow Embeddings object using an array.
+
+        Parameters
+        ----------
+        array : ArrayLike
+            The array to convert to embeddings.
+        device : DeviceLike or None, default None
+            The hardware device to use if specified, otherwise uses the DataEval
+            default or torch default.
+
+        Returns
+        -------
+        Embeddings
+
+        Example
+        -------
+        >>> import numpy as np
+        >>> from dataeval.data import Embeddings
+        >>> array = np.random.randn(100, 3, 224, 224)
+        >>> embeddings = Embeddings.from_array(array)
+        >>> print(embeddings.to_tensor().shape)
+        torch.Size([100, 3, 224, 224])
+        """
+        embeddings = Embeddings([], 0, None, None, device, True, False)
+        array = array if isinstance(array, Array) else as_numpy(array)
+        embeddings._cached_idx = set(range(len(array)))
+        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+        embeddings._embeddings_only = True
+        return embeddings
+
+    def save(self, path: Path | str) -> None:
+        """
+        Saves the embeddings to disk.
+
+        Parameters
+        ----------
+        path : Path or str
+            The file path to save the embeddings to.
+        """
+        self._save(self._resolve_path(path), True)
+
+    def _save(self, path: Path, force: bool = False) -> None:
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        if self._embeddings_only or self.cache and not force:
+            embeddings = self._embeddings
+            cached_idx = self._cached_idx
+        else:
+            embeddings = self.to_tensor()
+            cached_idx = list(range(len(self)))
+        try:
+            cache_data = {
+                "embeddings": embeddings,
+                "cached_indices": cached_idx,
+                "device": self.device,
+            }
+            torch.save(cache_data, path)
+            _logger.log(logging.DEBUG, f"Saved embeddings cache to {path}")
+        except Exception as e:
+            _logger.log(logging.ERROR, f"Failed to save embeddings cache: {e}")
+
+    @classmethod
+    def load(cls, path: Path | str) -> Embeddings:
+        """
+        Loads the embeddings from disk.
+
+        Parameters
+        ----------
+        path : Path or str
+            The file path to load the embeddings from.
+        """
+        emb = Embeddings([], 0)
+        path = Path(os.path.abspath(path)) if isinstance(path, str) else path
+        if path.exists() and path.is_file():
+            try:
+                cache_data = torch.load(path, weights_only=False)
+                emb._embeddings_only = True
+                emb._embeddings = cache_data["embeddings"]
+                emb._cached_idx = cache_data["cached_indices"]
+                emb.device = cache_data["device"]
+                _logger.log(logging.DEBUG, f"Loaded embeddings cache from {path}")
+            except Exception as e:
+                _logger.log(logging.ERROR, f"Failed to load embeddings cache: {e}")
+                raise e
+        else:
+            raise FileNotFoundError(f"Specified cache file {path} was not found.")
+
+        return emb
+
+    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+        if self._transforms:
+            images = [transform(image) for transform in self._transforms for image in images]
+        return self._encoder(torch.stack(images).to(self.device))
+
+    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
+    def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
+        dataset = cast(torch.utils.data.Dataset, self._dataset)
+        total_batches = math.ceil(len(indices) / self.batch_size)
+
+        # If not caching, process all indices normally
+        if not self.cache:
+            for images in tqdm(
+                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+                total=total_batches,
+                desc="Batch embedding",
+                disable=not self.verbose,
+            ):
+                yield self._encode(images)
+            return
+
+        # If caching, process each batch of indices at a time, preserving original order
+        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+            batch = indices[i : i + self.batch_size]
+            uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+            if uncached:
+                # Process uncached indices as a single batch
+                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+                    embeddings = self._encode(images)
+
+                    if not self._embeddings.shape:
+                        full_shape = (len(self), *embeddings.shape[1:])
+                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+                    self._embeddings[uncached] = embeddings
+                    self._cached_idx.update(uncached)
+
+                if isinstance(self.cache, Path):
+                    self._save(self.cache)
+
+            yield self._embeddings[batch]
+
+    def __getitem__(self, key: int | slice, /) -> torch.Tensor:
+        if not isinstance(key, slice) and not hasattr(key, "__int__"):
+            raise TypeError("Invalid argument type.")
+
+        indices = list(range(len(self))[key]) if isinstance(key, slice) else [int(key)]
+
+        if self._embeddings_only:
+            if not self._embeddings.shape:
+                raise ValueError("Embeddings not initialized.")
+            if not set(indices).issubset(self._cached_idx):
+                raise ValueError("Unable to generate new embeddings from a shallow instance.")
+            return self._embeddings[key]
+
+        result = torch.vstack(list(self._batch(indices))).to(self.device)
+        return result.squeeze(0) if len(indices) == 1 else result
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        # process in batches while yielding individual embeddings
+        for batch in self._batch(range(len(self))):
+            yield from batch
+
+    def __len__(self) -> int:
+        return len(self._embeddings) if self._embeddings_only else len(self._dataset)
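
Taken together, the new Embeddings class replaces the removed dataeval/utils/data/_embeddings.py (file 56 above) with a lazy, cache-aware API: items are encoded only when indexed, and results are kept in memory (cache=True) or persisted to disk (cache=Path). A minimal usage sketch; the toy dataset, encoder, and file name below are illustrative assumptions, not part of this diff:

import torch
from dataeval.data import Embeddings

# Toy in-memory dataset: 16 CHW float images. Any object with
# __getitem__/__len__ yielding arrays (or (array, target, meta) tuples) works.
images = [torch.randn(3, 8, 8) for _ in range(16)]

# Any torch.nn.Module can serve as the encoder; the default is torch.nn.Flatten().
encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 32))

embs = Embeddings(images, batch_size=4, model=encoder, cache=True)
print(embs[:4].shape)         # torch.Size([4, 32]); encoded on demand, then cached
print(embs.to_numpy().shape)  # (16, 32); subsequent reads hit the in-memory cache

# Round-trip through the on-disk format (hypothetical file name):
embs.save("toy_embeddings.pt")
restored = Embeddings.load("toy_embeddings.pt")
print(len(restored))          # 16; a shallow, embeddings-only instance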
dataeval/{utils/data → data}/_images.py RENAMED
@@ -4,13 +4,13 @@ __all__ = []
 
 from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
 
-from dataeval.typing import Array, Dataset
+from dataeval.typing import Array, ArrayLike, Dataset
 from dataeval.utils._array import as_numpy, channels_first_to_last
 
 if TYPE_CHECKING:
     from matplotlib.figure import Figure
 
-T = TypeVar("T", bound=Array)
+T = TypeVar("T", Array, ArrayLike)
 
 
 class Images(Generic[T]):
dataeval/{utils/data → data}/_metadata.py RENAMED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -16,12 +16,12 @@ from dataeval.typing import (
 )
 from dataeval.utils._array import as_numpy, to_numpy
 from dataeval.utils._bin import bin_data, digitize_data, is_continuous
-from dataeval.utils.metadata import merge
+from dataeval.utils.data.metadata import merge
 
 if TYPE_CHECKING:
-    from dataeval.utils.data import Targets
+    from dataeval.data import Targets
 else:
-    from dataeval.utils.data._targets import Targets
+    from dataeval.data._targets import Targets
 
 
 class Metadata:
@@ -208,8 +208,9 @@ class Metadata:
             raw.append(metadata)
 
             if is_od_target := isinstance(target, ObjectDetectionTarget):
-                target_len = len(target.labels)
-                labels.extend(as_numpy(target.labels).tolist())
+                target_labels = as_numpy(target.labels)
+                target_len = len(target_labels)
+                labels.extend(target_labels.tolist())
                 bboxes.extend(as_numpy(target.boxes).tolist())
                 scores.extend(as_numpy(target.scores).tolist())
                 srcidx.extend([i] * target_len)
@@ -360,7 +361,7 @@ class Metadata:
         self._merge()
         self._processed = False
         target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
-        if any(len(v) != target_len for v in factors.values()):
+        if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
             raise ValueError(
                 "The lists/arrays in the provided factors have a different length than the current metadata factors."
             )
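
The last hunk's guard change is worth a note: factor values are no longer required to be Sized, so array-like values without __len__ are measured by materializing them through as_numpy first. A standalone sketch of the pattern; the LazyColumn class and helper name are hypothetical illustrations, not part of this diff:

from typing import Any, Mapping, Sized

import numpy as np

class LazyColumn:
    # Hypothetical ArrayLike: materializes via __array__ but defines no __len__.
    def __init__(self, data: list[float]) -> None:
        self._data = data

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        return np.asarray(self._data)

def factor_lengths_match(factors: Mapping[str, Any], target_len: int) -> bool:
    # Mirrors the new guard: len() for Sized values, numpy materialization otherwise.
    return all(len(v if isinstance(v, Sized) else np.asarray(v)) == target_len for v in factors.values())

print(factor_lengths_match({"a": [1, 2, 3], "b": LazyColumn([4.0, 5.0, 6.0])}, 3))  # True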
dataeval/{utils/data → data}/_selection.py RENAMED
@@ -25,6 +25,10 @@ class Selection(Generic[_TDatum]):
         return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
 
 
+class Subselection(Generic[_TDatum]):
+    def __call__(self, original: _TDatum) -> _TDatum: ...
+
+
 class Select(AnnotatedDataset[_TDatum]):
     """
     Wraps a dataset and applies selection criteria to it.
@@ -38,7 +42,7 @@ class Select(AnnotatedDataset[_TDatum]):
 
     Examples
    --------
-    >>> from dataeval.utils.data.selections import ClassFilter, Limit
+    >>> from dataeval.data.selections import ClassFilter, Limit
 
     >>> # Construct a sample dataset with size of 100 and class count of 10
     >>> # Elements at index `idx` are returned as tuples:
@@ -63,6 +67,7 @@ class Select(AnnotatedDataset[_TDatum]):
     _selection: list[int]
     _selections: Sequence[Selection[_TDatum]]
     _size_limit: int
+    _subselections: list[tuple[Subselection[_TDatum], set[int]]]
 
     def __init__(
         self,
@@ -73,7 +78,8 @@ class Select(AnnotatedDataset[_TDatum]):
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
-        self._selections = self._sort(selections)
+        self._selections = self._sort_selections(selections)
+        self._subselections = []
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -81,7 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
             _metadata["id"] = dataset.__class__.__name__
         self._metadata = DatasetMetadata(**_metadata)
 
-        self._select()
+        self._apply_selections()
 
     @property
     def metadata(self) -> DatasetMetadata:
@@ -94,24 +100,31 @@ class Select(AnnotatedDataset[_TDatum]):
         selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
         return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
-    def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
+    def _sort_selections(
+        self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None
+    ) -> list[Selection[_TDatum]]:
         if not selections:
             return []
 
-        selections = [selections] if isinstance(selections, Selection) else selections
-        grouped: dict[int, list[Selection]] = {}
-        for selection in selections:
+        selections_list = [selections] if isinstance(selections, Selection) else list(selections)
+        grouped: dict[int, list[Selection[_TDatum]]] = {}
+        for selection in selections_list:
             grouped.setdefault(selection.stage, []).append(selection)
         selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
         return selection_list
 
-    def _select(self) -> None:
+    def _apply_selections(self) -> None:
         for selection in self._selections:
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
+    def _apply_subselection(self, datum: _TDatum, index: int) -> _TDatum:
+        for subselection, indices in self._subselections:
+            datum = subselection(datum) if index in indices else datum
+        return datum
+
     def __getitem__(self, index: int) -> _TDatum:
-        return self._dataset[self._selection[index]]
+        return self._apply_subselection(self._dataset[self._selection[index]], index)
 
     def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):
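
The renames above (_sort to _sort_selections, _select to _apply_selections) are behavior-preserving; the substantive addition is the Subselection hook, which rewrites individual datums on access after index-level selection has run. A minimal sketch of how Select composes stage-ordered selections, assuming Limit takes a size and Reverse takes no arguments (the toy dataset is illustrative):

from dataeval.data import Select
from dataeval.data.selections import Limit, Reverse

# Toy dataset of (image, label, metadata) tuples; any indexable sequence works.
dataset = [(f"image_{i}", i % 10, {}) for i in range(100)]

# Selections are grouped by `stage` and applied in stage order, so the result
# does not depend on the order they are passed in.
subset = Select(dataset, selections=[Reverse(), Limit(10)])
print(len(subset))  # 10
print(subset[0])    # the last datum of the original dataset, reversed then truncated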
dataeval/{utils/data → data}/_split.py RENAMED
@@ -12,10 +12,10 @@ from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, Str
 from sklearn.utils.multiclass import type_of_target
 
 from dataeval.config import EPSILON
+from dataeval.data._metadata import Metadata
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 from dataeval.typing import AnnotatedDataset
-from dataeval.utils.data._metadata import Metadata
 
 _logger = logging.getLogger(__name__)
 
dataeval/data/selections/__init__.py ADDED
@@ -0,0 +1,19 @@
+"""Provides selection classes for selecting subsets of Computer Vision datasets."""
+
+__all__ = [
+    "ClassBalance",
+    "ClassFilter",
+    "Indices",
+    "Limit",
+    "Prioritize",
+    "Reverse",
+    "Shuffle",
+]
+
+from dataeval.data.selections._classbalance import ClassBalance
+from dataeval.data.selections._classfilter import ClassFilter
+from dataeval.data.selections._indices import Indices
+from dataeval.data.selections._limit import Limit
+from dataeval.data.selections._prioritize import Prioritize
+from dataeval.data.selections._reverse import Reverse
+from dataeval.data.selections._shuffle import Shuffle
dataeval/data/selections/_classbalance.py ADDED
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+__all__ = []
+
+import numpy as np
+
+from dataeval.data._selection import Select, Selection, SelectionStage
+from dataeval.typing import Array, ImageClassificationDatum
+from dataeval.utils._array import as_numpy
+
+
+class ClassBalance(Selection[ImageClassificationDatum]):
+    """
+    Balance the dataset by class.
+
+    Note
+    ----
+    The total number of instances of each class will be equalized, which may result
+    in a lower total number of instances than specified by the selection limit.
+    """
+
+    stage = SelectionStage.FILTER
+
+    def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
+        class_indices: dict[int, list[int]] = {}
+        for i, idx in enumerate(dataset._selection):
+            target = dataset._dataset[idx][1]
+            if isinstance(target, Array):
+                label = int(np.argmax(as_numpy(target)))
+            else:
+                # ObjectDetectionTarget and SegmentationTarget not supported yet
+                raise TypeError("ClassBalance only supports classification targets as an array of confidence scores.")
+            class_indices.setdefault(label, []).append(i)
+
+        per_class_limit = min(min(len(c) for c in class_indices.values()), dataset._size_limit // len(class_indices))
+        subselection = sorted([i for v in class_indices.values() for i in v[:per_class_limit]])
+        dataset._selection = [dataset._selection[i] for i in subselection]
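
A usage sketch for the new ClassBalance selection under stated assumptions: each datum is an (image, one-hot scores, metadata) tuple, and an imbalanced 5/10/15 split across three classes is equalized to the minimum class count. The toy dataset and helper are illustrative, not part of this diff:

import numpy as np

from dataeval.data import Select
from dataeval.data.selections import ClassBalance

def one_hot(label: int, classes: int = 3) -> np.ndarray:
    scores = np.zeros(classes)
    scores[label] = 1.0
    return scores

# 5 instances of class 0, 10 of class 1, 15 of class 2.
labels = [0] * 5 + [1] * 10 + [2] * 15
dataset = [(np.zeros((3, 8, 8)), one_hot(label), {}) for label in labels]

balanced = Select(dataset, selections=[ClassBalance()])
print(len(balanced))  # 15: per-class count equalized to min(5, 30 // 3) = 5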