dataeval 0.82.0__py3-none-any.whl → 0.83.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +78 -11
  3. dataeval/detectors/drift/_mmd.py +9 -9
  4. dataeval/detectors/drift/_torch.py +7 -7
  5. dataeval/detectors/drift/_uncertainty.py +4 -4
  6. dataeval/detectors/linters/duplicates.py +3 -3
  7. dataeval/detectors/linters/outliers.py +3 -3
  8. dataeval/detectors/ood/ae.py +5 -4
  9. dataeval/detectors/ood/base.py +2 -2
  10. dataeval/detectors/ood/mixin.py +1 -1
  11. dataeval/detectors/ood/vae.py +2 -1
  12. dataeval/metadata/__init__.py +2 -2
  13. dataeval/metadata/_distance.py +11 -44
  14. dataeval/metadata/_ood.py +152 -33
  15. dataeval/metrics/bias/_balance.py +9 -5
  16. dataeval/metrics/bias/_diversity.py +3 -0
  17. dataeval/metrics/bias/_parity.py +2 -0
  18. dataeval/metrics/estimators/_ber.py +2 -1
  19. dataeval/metrics/stats/_base.py +20 -21
  20. dataeval/metrics/stats/_boxratiostats.py +1 -1
  21. dataeval/metrics/stats/_dimensionstats.py +2 -2
  22. dataeval/metrics/stats/_hashstats.py +2 -2
  23. dataeval/metrics/stats/_imagestats.py +8 -8
  24. dataeval/metrics/stats/_pixelstats.py +2 -2
  25. dataeval/metrics/stats/_visualstats.py +2 -2
  26. dataeval/outputs/__init__.py +5 -0
  27. dataeval/outputs/_base.py +50 -21
  28. dataeval/outputs/_bias.py +1 -1
  29. dataeval/outputs/_linters.py +4 -2
  30. dataeval/outputs/_metadata.py +61 -0
  31. dataeval/outputs/_stats.py +12 -6
  32. dataeval/typing.py +40 -9
  33. dataeval/utils/_mst.py +1 -2
  34. dataeval/utils/data/_embeddings.py +23 -19
  35. dataeval/utils/data/_metadata.py +16 -7
  36. dataeval/utils/data/_selection.py +22 -15
  37. dataeval/utils/data/_split.py +3 -2
  38. dataeval/utils/data/datasets/_base.py +4 -2
  39. dataeval/utils/data/datasets/_cifar10.py +17 -9
  40. dataeval/utils/data/datasets/_milco.py +18 -12
  41. dataeval/utils/data/datasets/_mnist.py +24 -8
  42. dataeval/utils/data/datasets/_ships.py +18 -8
  43. dataeval/utils/data/datasets/_types.py +1 -5
  44. dataeval/utils/data/datasets/_voc.py +47 -24
  45. dataeval/utils/data/selections/__init__.py +2 -0
  46. dataeval/utils/data/selections/_classfilter.py +5 -3
  47. dataeval/utils/data/selections/_prioritize.py +296 -0
  48. dataeval/utils/data/selections/_shuffle.py +13 -4
  49. dataeval/utils/torch/_gmm.py +3 -2
  50. dataeval/utils/torch/_internal.py +5 -5
  51. dataeval/utils/torch/trainer.py +8 -8
  52. {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
  53. dataeval-0.83.0.dist-info/RECORD +105 -0
  54. dataeval/detectors/ood/metadata_ood_mi.py +0 -93
  55. dataeval-0.82.0.dist-info/RECORD +0 -104
  56. {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
  57. {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/outputs/_base.py CHANGED
@@ -4,11 +4,11 @@ __all__ = []
4
4
 
5
5
  import inspect
6
6
  import logging
7
- from collections.abc import Mapping
7
+ from collections.abc import Collection, Mapping, Sequence
8
8
  from dataclasses import dataclass
9
9
  from datetime import datetime, timezone
10
10
  from functools import partial, wraps
11
- from typing import Any, Callable, Iterator, TypeVar
11
+ from typing import Any, Callable, Generic, Iterator, TypeVar, overload
12
12
 
13
13
  import numpy as np
14
14
  from typing_extensions import ParamSpec
@@ -56,16 +56,13 @@ class ExecutionMetadata:
56
56
  )
57
57
 
58
58
 
59
- class Output:
60
- _meta: ExecutionMetadata | None = None
59
+ T = TypeVar("T", covariant=True)
61
60
 
62
- def __str__(self) -> str:
63
- return f"{self.__class__.__name__}: {str(self.dict())}"
64
61
 
65
- def dict(self) -> dict[str, Any]:
66
- return {k: v for k, v in self.__dict__.items() if k != "_meta"}
62
+ class GenericOutput(Generic[T]):
63
+ _meta: ExecutionMetadata | None = None
67
64
 
68
- @property
65
+ def data(self) -> T: ...
69
66
  def meta(self) -> ExecutionMetadata:
70
67
  """
71
68
  Metadata about the execution of the function or method for the Output class.
@@ -73,34 +70,66 @@ class Output:
73
70
  return self._meta or ExecutionMetadata.empty()
74
71
 
75
72
 
76
- TKey = TypeVar("TKey", str, int, float, set)
77
- TValue = TypeVar("TValue")
73
+ class Output(GenericOutput[dict[str, Any]]):
74
+ def data(self) -> dict[str, Any]:
75
+ return {k: v for k, v in self.__dict__.items() if k != "_meta"}
78
76
 
77
+ def __repr__(self) -> str:
78
+ return str(self)
79
79
 
80
- class MappingOutput(Mapping[TKey, TValue], Output):
80
+ def __str__(self) -> str:
81
+ return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.data().items()])})"
82
+
83
+
84
+ class BaseCollectionMixin(Collection[Any]):
81
85
  __slots__ = ["_data"]
82
86
 
87
+ def data(self) -> Any:
88
+ return self._data
89
+
90
+ def __len__(self) -> int:
91
+ return len(self._data)
92
+
93
+ def __repr__(self) -> str:
94
+ return f"{self.__class__.__name__}({repr(self._data)})"
95
+
96
+ def __str__(self) -> str:
97
+ return str(self._data)
98
+
99
+
100
+ TKey = TypeVar("TKey", str, int, float, set)
101
+ TValue = TypeVar("TValue")
102
+
103
+
104
+ class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Mapping[TKey, TValue]]):
83
105
  def __init__(self, data: Mapping[TKey, TValue]):
84
106
  self._data = data
85
107
 
86
108
  def __getitem__(self, key: TKey) -> TValue:
87
- return self._data.__getitem__(key)
109
+ return self._data[key]
88
110
 
89
111
  def __iter__(self) -> Iterator[TKey]:
90
- return self._data.__iter__()
112
+ return iter(self._data)
91
113
 
92
- def __len__(self) -> int:
93
- return self._data.__len__()
94
114
 
95
- def dict(self) -> dict[str, TValue]:
96
- return {str(k): v for k, v in self._data.items()}
115
+ class SequenceOutput(Sequence[TValue], BaseCollectionMixin, GenericOutput[Sequence[TValue]]):
116
+ def __init__(self, data: Sequence[TValue]):
117
+ self._data = data
118
+
119
+ @overload
120
+ def __getitem__(self, index: int) -> TValue: ...
121
+ @overload
122
+ def __getitem__(self, index: slice) -> Sequence[TValue]: ...
97
123
 
98
- def __str__(self) -> str:
99
- return str(self.dict())
124
+ def __getitem__(self, index: int | slice) -> TValue | Sequence[TValue]:
125
+ return self._data[index]
126
+
127
+ def __iter__(self) -> Iterator[TValue]:
128
+ return iter(self._data)
100
129
 
101
130
 
102
131
  P = ParamSpec("P")
103
- R = TypeVar("R", bound=Output)
132
+ R = TypeVar("R", bound=GenericOutput)
104
133
 
105
134
 
106
135
  def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
dataeval/outputs/_bias.py CHANGED
@@ -364,7 +364,7 @@ class DiversityOutput(Output):
364
364
  col_labels,
365
365
  xlabel="Factors",
366
366
  ylabel="Class",
367
- cbarlabel=f"Normalized {asdict(self.meta)['arguments']['method'].title()} Index",
367
+ cbarlabel=f"Normalized {asdict(self.meta())['arguments']['method'].title()} Index",
368
368
  )
369
369
 
370
370
  else:
dataeval/outputs/_linters.py CHANGED
@@ -24,7 +24,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
24
24
 
25
25
 
26
26
  @dataclass(frozen=True)
27
- class DuplicatesOutput(Generic[TIndexCollection], Output):
27
+ class DuplicatesOutput(Output, Generic[TIndexCollection]):
28
28
  """
29
29
  Output class for :class:`.Duplicates` lint detector.
30
30
 
@@ -35,6 +35,8 @@ class DuplicatesOutput(Generic[TIndexCollection], Output):
35
35
  near: list[list[int] | dict[int, list[int]]]
36
36
  Indices of images that are near matches
37
37
 
38
+ Notes
39
+ -----
38
40
  - For a single dataset, indices are returned as a list of index groups.
39
41
  - For multiple datasets, indices are returned as dictionaries where the key is the
40
42
  index of the dataset, and the value is the list index groups from that dataset.
@@ -99,7 +101,7 @@ def _create_pandas_dataframe(class_wise):
99
101
 
100
102
 
101
103
  @dataclass(frozen=True)
102
- class OutliersOutput(Generic[TIndexIssueMap], Output):
104
+ class OutliersOutput(Output, Generic[TIndexIssueMap]):
103
105
  """
104
106
  Output class for :class:`.Outliers` lint detector.
105
107
 
dataeval/outputs/_metadata.py ADDED
@@ -0,0 +1,61 @@
1
+ from __future__ import annotations
2
+
3
+ __all__ = []
4
+
5
+ from typing import NamedTuple
6
+
7
+ from dataeval.outputs._base import MappingOutput, SequenceOutput
8
+
9
+
10
+ class MostDeviatedFactorsOutput(SequenceOutput[tuple[str, float]]):
11
+ """
12
+ Output class for results of :func:`.most_deviated_factors` for OOD samples with metadata.
13
+
14
+ Attributes
15
+ ----------
16
+ value : tuple[str, float]
17
+ A tuple of the factor name and deviation of the highest metadata deviation
18
+ """
19
+
20
+
21
+ class MetadataDistanceValues(NamedTuple):
22
+ """
23
+ Statistics comparing metadata distance.
24
+
25
+ Attributes
26
+ ----------
27
+ statistic : float
28
+ the KS statistic
29
+ location : float
30
+ The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
31
+ to the median of the reference distribution.
32
+ dist : float
33
+ The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
34
+ pvalue : float
35
+ The p-value from the KS two-sample test
36
+ """
37
+
38
+ statistic: float
39
+ location: float
40
+ dist: float
41
+ pvalue: float
42
+
43
+
44
+ class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
45
+ """
46
+ Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
47
+
48
+ Attributes
49
+ ----------
50
+ key : str
51
+ Metadata feature names
52
+ value : :class:`.MetadataDistanceValues`
53
+ Output per feature name containing the statistic, statistic location, distance, and pvalue.
54
+ """
55
+
56
+
57
+ class OODPredictorOutput(MappingOutput[str, float]):
58
+ """
59
+ Output class for results of :func:`find_ood_predictors` for the
60
+ mutual information between factors and being out of distribution
61
+ """
dataeval/outputs/_stats.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  import contextlib
6
6
  from dataclasses import dataclass
7
- from typing import Iterable, Optional, Union
7
+ from typing import Any, Iterable, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  from numpy.typing import NDArray
@@ -63,7 +63,7 @@ class BaseStatsOutput(Output):
63
63
 
64
64
  def __post_init__(self) -> None:
65
65
  length = len(self.source_index)
66
- bad = {k: len(v) for k, v in self.dict().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
66
+ bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
67
67
  if bad:
68
68
  raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
69
69
 
@@ -105,7 +105,7 @@ class BaseStatsOutput(Output):
105
105
  def _get_channels(
106
106
  self, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
107
107
  ) -> tuple[int, list[bool] | None]:
108
- source_index = self.dict()[SOURCE_INDEX]
108
+ source_index = self.data()[SOURCE_INDEX]
109
109
  raw_channels = int(max([si.channel or 0 for si in source_index])) + 1
110
110
  if isinstance(channel_index, int):
111
111
  max_channels = 1 if channel_index < raw_channels else raw_channels
@@ -127,15 +127,21 @@ class BaseStatsOutput(Output):
127
127
 
128
128
  return max_channels, ch_mask
129
129
 
130
+ def factors(self) -> dict[str, NDArray[Any]]:
131
+ return {
132
+ k: v
133
+ for k, v in self.data().items()
134
+ if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
135
+ }
136
+
130
137
  def plot(
131
138
  self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
132
139
  ) -> None:
133
140
  max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
134
- d = {k: v for k, v in self.dict().items() if isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1}
135
141
  if max_channels == 1:
136
- histogram_plot(d, log)
142
+ histogram_plot(self.factors(), log)
137
143
  else:
138
- channel_histogram_plot(d, log, max_channels, ch_mask)
144
+ channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
139
145
 
140
146
 
141
147
  @dataclass(frozen=True)
dataeval/typing.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """
2
- Common type hints used for interoperability with DataEval.
2
+ Common type protocols used for interoperability with DataEval.
3
3
  """
4
4
 
5
5
  __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
16
16
  "SegmentationTarget",
17
17
  "SegmentationDatum",
18
18
  "SegmentationDataset",
19
+ "Transform",
19
20
  ]
20
21
 
21
22
 
@@ -66,6 +67,7 @@ class Array(Protocol):
66
67
  def __len__(self) -> int: ...
67
68
 
68
69
 
70
+ T = TypeVar("T")
69
71
  _T_co = TypeVar("_T_co", covariant=True)
70
72
  _ScalarType = Union[int, float, bool, str]
71
73
  ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
@@ -140,7 +142,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
140
142
 
141
143
  ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
142
144
  """
143
- A type definition for an image classification datum tuple.
145
+ Type alias for an image classification datum tuple.
144
146
 
145
147
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
146
148
  - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
@@ -150,7 +152,7 @@ A type definition for an image classification datum tuple.
150
152
 
151
153
  ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
152
154
  """
153
- A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
155
+ Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
154
156
  """
155
157
 
156
158
  # ========== OBJECT DETECTION DATASETS ==========
@@ -159,7 +161,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificatio
159
161
  @runtime_checkable
160
162
  class ObjectDetectionTarget(Protocol):
161
163
  """
162
- A protocol for targets in an Object Detection dataset.
164
+ Protocol for targets in an Object Detection dataset.
163
165
 
164
166
  Attributes
165
167
  ----------
@@ -180,7 +182,7 @@ class ObjectDetectionTarget(Protocol):
180
182
 
181
183
  ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
182
184
  """
183
- A type definition for an object detection datum tuple.
185
+ Type alias for an object detection datum tuple.
184
186
 
185
187
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
186
188
  - :class:`ObjectDetectionTarget` - Object detection target information for the image.
@@ -190,7 +192,7 @@ A type definition for an object detection datum tuple.
190
192
 
191
193
  ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
192
194
  """
193
- A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
195
+ Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
194
196
  """
195
197
 
196
198
 
@@ -200,7 +202,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDat
200
202
  @runtime_checkable
201
203
  class SegmentationTarget(Protocol):
202
204
  """
203
- A protocol for targets in a Segmentation dataset.
205
+ Protocol for targets in a Segmentation dataset.
204
206
 
205
207
  Attributes
206
208
  ----------
@@ -221,7 +223,7 @@ class SegmentationTarget(Protocol):
221
223
 
222
224
  SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
223
225
  """
224
- A type definition for an image classification datum tuple.
226
+ Type alias for an image classification datum tuple.
225
227
 
226
228
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
227
229
  - :class:`SegmentationTarget` - Segmentation target information for the image.
@@ -230,5 +232,34 @@ A type definition for an image classification datum tuple.
230
232
 
231
233
  SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
232
234
  """
233
- A type definition for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
235
+ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
234
236
  """
237
+
238
+
239
+ @runtime_checkable
240
+ class Transform(Generic[T], Protocol):
241
+ """
242
+ Protocol defining a transform function.
243
+
244
+ Requires a `__call__` method that returns transformed data.
245
+
246
+ Example
247
+ -------
248
+ >>> from typing import Any
249
+ >>> from numpy.typing import NDArray
250
+
251
+ >>> class MyTransform:
252
+ ... def __init__(self, divisor: float) -> None:
253
+ ... self.divisor = divisor
254
+ ...
255
+ ... def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
256
+ ... return data / self.divisor
257
+
258
+ >>> my_transform = MyTransform(divisor=255.0)
259
+ >>> isinstance(my_transform, Transform)
260
+ True
261
+ >>> my_transform(np.array([1, 2, 3]))
262
+ array([0.004, 0.008, 0.012])
263
+ """
264
+
265
+ def __call__(self, data: T, /) -> T: ...
dataeval/utils/_mst.py CHANGED
@@ -10,10 +10,9 @@ from scipy.sparse.csgraph import minimum_spanning_tree as mst
10
10
  from scipy.spatial.distance import pdist, squareform
11
11
  from sklearn.neighbors import NearestNeighbors
12
12
 
13
+ from dataeval.config import EPSILON
13
14
  from dataeval.utils._array import flatten
14
15
 
15
- EPSILON = 1e-5
16
-
17
16
 
18
17
  def minimum_spanning_tree(X: NDArray[Any]) -> Any:
19
18
  """
dataeval/utils/data/_embeddings.py CHANGED
@@ -9,7 +9,7 @@ import torch
9
9
  from torch.utils.data import DataLoader, Subset
10
10
  from tqdm import tqdm
11
11
 
12
- from dataeval.config import get_device
12
+ from dataeval.config import DeviceLike, get_device
13
13
  from dataeval.typing import Array, Dataset
14
14
  from dataeval.utils.torch.models import SupportsEncode
15
15
 
@@ -24,13 +24,14 @@ class Embeddings:
24
24
  ----------
25
25
  dataset : ImageClassificationDataset or ObjectDetectionDataset
26
26
  Dataset to access original images from.
27
- batch_size : int, optional
27
+ batch_size : int
28
28
  Batch size to use when encoding images.
29
- model : torch.nn.Module, optional
29
+ model : torch.nn.Module or None, default None
30
30
  Model to use for encoding images.
31
- device : torch.device, optional
32
- Device to use for encoding images.
33
- verbose : bool, optional
31
+ device : DeviceLike or None, default None
32
+ The hardware device to use if specified, otherwise uses the DataEval
33
+ default or torch default.
34
+ verbose : bool, default False
34
35
  Whether to print progress bar when encoding images.
35
36
  """
36
37
 
@@ -42,9 +43,8 @@ class Embeddings:
42
43
  self,
43
44
  dataset: Dataset[tuple[Array, Any, Any]],
44
45
  batch_size: int,
45
- indices: Sequence[int] | None = None,
46
46
  model: torch.nn.Module | None = None,
47
- device: torch.device | str | None = None,
47
+ device: DeviceLike | None = None,
48
48
  verbose: bool = False,
49
49
  ) -> None:
50
50
  self.device = get_device(device)
@@ -52,26 +52,32 @@ class Embeddings:
52
52
  self.verbose = verbose
53
53
 
54
54
  self._dataset = dataset
55
- self._indices = indices if indices is not None else range(len(dataset))
56
55
  model = torch.nn.Flatten() if model is None else model
57
56
  self._model = model.to(self.device).eval()
58
57
  self._encoder = model.encode if isinstance(model, SupportsEncode) else model
59
58
  self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
60
59
 
61
- def to_tensor(self) -> torch.Tensor:
60
+ def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
62
61
  """
63
- Converts entire dataset to embeddings.
62
+ Converts dataset to embeddings.
64
63
 
65
- Warning
66
- -------
67
- Will process the entire dataset in batches and return
68
- embeddings as a single Tensor in memory.
64
+ Parameters
65
+ ----------
66
+ indices : Sequence[int] or None, default None
67
+ The indices to convert to embeddings
69
68
 
70
69
  Returns
71
70
  -------
72
71
  torch.Tensor
72
+
73
+ Warning
74
+ -------
75
+ Processing large quantities of data can be resource intensive.
73
76
  """
74
- return self[:]
77
+ if indices is not None:
78
+ return torch.vstack(list(self._batch(indices))).to(self.device)
79
+ else:
80
+ return self[:]
75
81
 
76
82
  # Reduce overhead cost by not tracking tensor gradients
77
83
  @torch.no_grad
@@ -86,9 +92,7 @@ class Embeddings:
86
92
  embeddings = self._encoder(torch.stack(images).to(self.device))
87
93
  yield embeddings
88
94
 
89
- def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
90
- if isinstance(key, list):
91
- return torch.vstack(list(self._batch(key))).to(self.device)
95
+ def __getitem__(self, key: int | slice, /) -> torch.Tensor:
92
96
  if isinstance(key, slice):
93
97
  return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
94
98
  elif isinstance(key, int):
dataeval/utils/data/_metadata.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import warnings
6
- from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
6
+ from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
@@ -11,6 +11,7 @@ from numpy.typing import NDArray
11
11
  from dataeval.typing import (
12
12
  AnnotatedDataset,
13
13
  Array,
14
+ ArrayLike,
14
15
  ObjectDetectionTarget,
15
16
  )
16
17
  from dataeval.utils._array import as_numpy, to_numpy
@@ -276,16 +277,12 @@ class Metadata:
276
277
  if self._processed and not force:
277
278
  return
278
279
 
279
- # Trigger collate and merge if not yet done
280
- self._collate()
281
- self._merge()
280
+ # Create image indices from targets
281
+ self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
282
282
 
283
283
  # Validate the metadata dimensions
284
284
  self._validate()
285
285
 
286
- # Create image indices from targets
287
- self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
288
-
289
286
  # Include specified metadata keys
290
287
  if self.include:
291
288
  metadata = {i: self.merged[i] for i in self.include if i in self.merged}
@@ -358,3 +355,15 @@ class Metadata:
358
355
  )
359
356
  self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
360
357
  self._processed = True
358
+
359
+ def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
360
+ self._merge()
361
+ self._processed = False
362
+ target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
363
+ if any(len(v) != target_len for v in factors.values()):
364
+ raise ValueError(
365
+ "The lists/arrays in the provided factors have a different length than the current metadata factors."
366
+ )
367
+ merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
368
+ for k, v in factors.items():
369
+ merged[k] = v
dataeval/utils/data/_selection.py CHANGED
@@ -3,9 +3,9 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from enum import IntEnum
6
- from typing import Any, Generic, Iterator, Sequence, TypeVar
6
+ from typing import Generic, Iterator, Sequence, TypeVar
7
7
 
8
- from dataeval.typing import AnnotatedDataset, DatasetMetadata
8
+ from dataeval.typing import AnnotatedDataset, DatasetMetadata, Transform
9
9
 
10
10
  _TDatum = TypeVar("_TDatum")
11
11
 
@@ -35,6 +35,8 @@ class Select(AnnotatedDataset[_TDatum]):
35
35
  The dataset to wrap.
36
36
  selections : Selection or list[Selection], optional
37
37
  The selection criteria to apply to the dataset.
38
+ transforms : Transform or list[Transform], optional
39
+ The transforms to apply to the dataset.
38
40
 
39
41
  Examples
40
42
  --------
@@ -67,13 +69,17 @@ class Select(AnnotatedDataset[_TDatum]):
67
69
  def __init__(
68
70
  self,
69
71
  dataset: AnnotatedDataset[_TDatum],
70
- selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
72
+ selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
73
+ transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
71
74
  ) -> None:
75
+ self.__dict__.update(dataset.__dict__)
72
76
  self._dataset = dataset
73
77
  self._size_limit = len(dataset)
74
78
  self._selection = list(range(self._size_limit))
75
- self._selections = self._sort_selections(selections)
76
- self.__dict__.update(dataset.__dict__)
79
+ self._selections = self._sort(selections)
80
+ self._transforms = (
81
+ [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
82
+ )
77
83
 
78
84
  # Ensure metadata is populated correctly as DatasetMetadata TypedDict
79
85
  _metadata = getattr(dataset, "metadata", {})
@@ -81,8 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
81
87
  _metadata["id"] = dataset.__class__.__name__
82
88
  self._metadata = DatasetMetadata(**_metadata)
83
89
 
84
- if self._selections:
85
- self._apply_selections()
90
+ self._select()
86
91
 
87
92
  @property
88
93
  def metadata(self) -> DatasetMetadata:
@@ -92,10 +97,11 @@ class Select(AnnotatedDataset[_TDatum]):
92
97
  nt = "\n "
93
98
  title = f"{self.__class__.__name__} Dataset"
94
99
  sep = "-" * len(title)
95
- selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
96
- return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
100
+ selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
101
+ transforms = f"Transforms: [{', '.join([str(t) for t in self._transforms])}]"
102
+ return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
97
103
 
98
- def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
104
+ def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
99
105
  if not selections:
100
106
  return []
101
107
 
@@ -106,17 +112,18 @@ class Select(AnnotatedDataset[_TDatum]):
106
112
  selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
107
113
  return selection_list
108
114
 
109
- def _apply_selections(self) -> None:
115
+ def _select(self) -> None:
110
116
  for selection in self._selections:
111
117
  selection(self)
112
118
  self._selection = self._selection[: self._size_limit]
113
119
 
114
- def __getattr__(self, name: str, /) -> Any:
115
- selfattr = getattr(self._dataset, name, None)
116
- return selfattr if selfattr is not None else getattr(self._dataset, name)
120
+ def _transform(self, datum: _TDatum) -> _TDatum:
121
+ for t in self._transforms:
122
+ datum = t(datum)
123
+ return datum
117
124
 
118
125
  def __getitem__(self, index: int) -> _TDatum:
119
- return self._dataset[self._selection[index]]
126
+ return self._transform(self._dataset[self._selection[index]])
120
127
 
121
128
  def __iter__(self) -> Iterator[_TDatum]:
122
129
  for i in range(len(self)):
dataeval/utils/data/_split.py CHANGED
@@ -12,6 +12,7 @@ from sklearn.metrics import silhouette_score
12
12
  from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
13
13
  from sklearn.utils.multiclass import type_of_target
14
14
 
15
+ from dataeval.config import get_seed
15
16
  from dataeval.outputs._base import set_metadata
16
17
  from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
17
18
 
@@ -212,9 +213,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
212
213
  best_score = 0.50
213
214
  bin_index = np.zeros(len(array), dtype=np.intp)
214
215
  for k in range(2, 20):
215
- clusterer = KMeans(n_clusters=k)
216
+ clusterer = KMeans(n_clusters=k, random_state=get_seed())
216
217
  cluster_labels = clusterer.fit_predict(array)
217
- score = silhouette_score(array, cluster_labels, sample_size=25_000)
218
+ score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
218
219
  if score > best_score:
219
220
  best_score = score
220
221
  bin_index = cluster_labels.astype(np.intp)
dataeval/utils/data/datasets/_base.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  from abc import abstractmethod
6
6
  from pathlib import Path
7
- from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
7
+ from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
8
8
 
9
9
  from dataeval.utils.data.datasets._fileio import _ensure_exists
10
10
  from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
@@ -16,9 +16,11 @@ from dataeval.utils.data.datasets._types import (
16
16
  ObjectDetectionTarget,
17
17
  SegmentationDataset,
18
18
  SegmentationTarget,
19
- Transform,
20
19
  )
21
20
 
21
+ if TYPE_CHECKING:
22
+ from dataeval.typing import Transform
23
+
22
24
  _TArray = TypeVar("_TArray")
23
25
  _TTarget = TypeVar("_TTarget")
24
26
  _TRawTarget = TypeVar("_TRawTarget", list[int], list[str])