dataeval 0.83.0__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +3 -3
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +55 -203
  5. dataeval/detectors/drift/_cvm.py +19 -30
  6. dataeval/detectors/drift/_ks.py +18 -30
  7. dataeval/detectors/drift/_mmd.py +189 -53
  8. dataeval/detectors/drift/_uncertainty.py +52 -56
  9. dataeval/detectors/drift/updates.py +13 -12
  10. dataeval/detectors/linters/duplicates.py +5 -3
  11. dataeval/detectors/linters/outliers.py +2 -2
  12. dataeval/detectors/ood/ae.py +1 -1
  13. dataeval/metrics/bias/__init__.py +11 -1
  14. dataeval/metrics/bias/_completeness.py +130 -0
  15. dataeval/metrics/stats/_base.py +28 -32
  16. dataeval/metrics/stats/_dimensionstats.py +2 -2
  17. dataeval/metrics/stats/_hashstats.py +2 -2
  18. dataeval/metrics/stats/_imagestats.py +4 -4
  19. dataeval/metrics/stats/_labelstats.py +4 -45
  20. dataeval/metrics/stats/_pixelstats.py +2 -2
  21. dataeval/metrics/stats/_visualstats.py +2 -2
  22. dataeval/outputs/__init__.py +2 -1
  23. dataeval/outputs/_bias.py +31 -22
  24. dataeval/outputs/_stats.py +2 -3
  25. dataeval/typing.py +25 -22
  26. dataeval/utils/_array.py +43 -7
  27. dataeval/utils/data/_dataset.py +8 -4
  28. dataeval/utils/data/_embeddings.py +141 -24
  29. dataeval/utils/data/_images.py +38 -15
  30. dataeval/utils/data/_metadata.py +5 -4
  31. dataeval/utils/data/_selection.py +3 -15
  32. dataeval/utils/data/_split.py +76 -129
  33. dataeval/utils/data/datasets/_base.py +7 -4
  34. dataeval/utils/data/datasets/_cifar10.py +9 -9
  35. dataeval/utils/data/datasets/_milco.py +42 -14
  36. dataeval/utils/data/datasets/_mnist.py +9 -5
  37. dataeval/utils/data/datasets/_ships.py +8 -4
  38. dataeval/utils/data/datasets/_voc.py +40 -19
  39. dataeval/utils/data/selections/__init__.py +2 -0
  40. dataeval/utils/data/selections/_classbalance.py +38 -0
  41. dataeval/utils/data/selections/_classfilter.py +14 -29
  42. dataeval/utils/data/selections/_prioritize.py +1 -1
  43. dataeval/utils/data/selections/_shuffle.py +2 -2
  44. dataeval/utils/metadata.py +1 -1
  45. dataeval/utils/torch/_internal.py +12 -35
  46. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
  47. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +49 -48
  48. dataeval/detectors/drift/_torch.py +0 -222
  49. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
  50. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
dataeval/typing.py CHANGED
@@ -21,9 +21,10 @@ __all__ = [
 
 
 import sys
-from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
+from typing import Any, Generic, Iterator, Protocol, TypedDict, TypeVar, runtime_checkable
 
-from typing_extensions import NotRequired, Required
+import numpy.typing
+from typing_extensions import NotRequired, ReadOnly, Required
 
 if sys.version_info >= (3, 10):
     from typing import TypeAlias
@@ -31,6 +32,16 @@ else:
     from typing_extensions import TypeAlias
 
 
+ArrayLike: TypeAlias = numpy.typing.ArrayLike
+"""
+Type alias for a `Union` representing objects that can be coerced into an array.
+
+See Also
+--------
+`NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
+"""
+
+
 @runtime_checkable
 class Array(Protocol):
     """
@@ -67,16 +78,8 @@ class Array(Protocol):
     def __len__(self) -> int: ...
 
 
-T = TypeVar("T")
+_T = TypeVar("_T")
 _T_co = TypeVar("_T_co", covariant=True)
-_ScalarType = Union[int, float, bool, str]
-ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
-"""
-Type alias for array-like objects used for interoperability with DataEval.
-
-This includes native Python sequences, as well as objects that conform to
-the :class:`Array` protocol.
-"""
 
 
 class DatasetMetadata(TypedDict, total=False):
@@ -91,8 +94,8 @@ class DatasetMetadata(TypedDict, total=False):
         A lookup table converting label value to class name
     """
 
-    id: Required[str]
-    index2label: NotRequired[dict[int, str]]
+    id: Required[ReadOnly[str]]
+    index2label: NotRequired[ReadOnly[dict[int, str]]]
 
 
 @runtime_checkable
@@ -140,12 +143,12 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
 # ========== IMAGE CLASSIFICATION DATASETS ==========
 
 
-ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
+ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, dict[str, Any]]
 """
 Type alias for an image classification datum tuple.
 
-- :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
-- :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
 - dict[str, Any] - Datum level metadata.
 """
 
@@ -180,11 +183,11 @@ class ObjectDetectionTarget(Protocol):
     def scores(self) -> ArrayLike: ...
 
 
-ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
+ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, dict[str, Any]]
 """
 Type alias for an object detection datum tuple.
 
-- :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`ObjectDetectionTarget` - Object detection target information for the image.
 - dict[str, Any] - Datum level metadata.
 """
@@ -221,11 +224,11 @@ class SegmentationTarget(Protocol):
     def scores(self) -> ArrayLike: ...
 
 
-SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
+SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, dict[str, Any]]
"""
Type alias for an image classification datum tuple.
 
-- :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`SegmentationTarget` - Segmentation target information for the image.
 - dict[str, Any] - Datum level metadata.
 """
@@ -237,7 +240,7 @@ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
 
 
 @runtime_checkable
-class Transform(Generic[T], Protocol):
+class Transform(Generic[_T], Protocol):
     """
     Protocol defining a transform function.
 
@@ -262,4 +265,4 @@ class Transform(Generic[T], Protocol):
     array([0.004, 0.008, 0.012])
     """
 
-    def __call__(self, data: T, /) -> T: ...
+    def __call__(self, data: _T, /) -> _T: ...
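The net effect for callers: `ArrayLike` now accepts anything NumPy can coerce (nested sequences, scalars, objects exposing `__array__`), and the `T` to `_T` rename keeps the `Transform` protocol's typevar private without changing its behavior. A minimal sketch of a structural match against the runtime-checkable protocol (the `Scale` class is hypothetical, not part of DataEval):

```python
import numpy as np

from dataeval.typing import ArrayLike, Transform


class Scale:
    """Hypothetical transform: multiplies image data by a fixed factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, data: np.ndarray, /) -> np.ndarray:
        return data * self.factor


scale = Scale(0.5)
print(isinstance(scale, Transform))  # True - structural match on __call__
nested: ArrayLike = [[1.0, 2.0], [3.0, 4.0]]  # plain sequences still qualify
print(scale(np.asarray(nested)))
```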
dataeval/utils/_array.py CHANGED
@@ -13,7 +13,7 @@ import torch
 from numpy.typing import NDArray
 
 from dataeval._log import LogMessage
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array, ArrayLike
 
 _logger = logging.getLogger(__name__)
 
@@ -92,7 +92,7 @@ def ensure_embeddings(
 @overload
 def ensure_embeddings(
     embeddings: T,
-    dtype: None,
+    dtype: None = None,
     unit_interval: Literal[True, False, "force"] = False,
 ) -> T: ...
 
@@ -152,18 +152,54 @@ def ensure_embeddings(
     return arr
 
 
-def flatten(array: ArrayLike) -> NDArray[Any]:
+@overload
+def flatten(array: torch.Tensor) -> torch.Tensor: ...
+@overload
+def flatten(array: ArrayLike) -> NDArray[Any]: ...
+
+
+def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
     """
     Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
 
     Parameters
     ----------
-    X : NDArray, shape - (N, ... )
+    array : ArrayLike
         Input array
 
     Returns
    -------
-    NDArray, shape - (N, -1)
+    np.ndarray or torch.Tensor, shape: (N, -1)
+    """
+    if isinstance(array, np.ndarray):
+        nparr = as_numpy(array)
+        return nparr.reshape((nparr.shape[0], -1))
+    elif isinstance(array, torch.Tensor):
+        return torch.flatten(array, start_dim=1)
+    else:
+        raise TypeError(f"Unsupported array type {type(array)}.")
+
+
+_TArray = TypeVar("_TArray", bound=Array)
+
+
+def channels_first_to_last(array: _TArray) -> _TArray:
     """
-    nparr = as_numpy(array)
-    return nparr.reshape((nparr.shape[0], -1))
+    Converts array from channels first to channels last format
+
+    Parameters
+    ----------
+    array : ArrayLike
+        Input array
+
+    Returns
+    -------
+    ArrayLike
+        Converted array
+    """
+    if isinstance(array, np.ndarray):
+        return np.transpose(array, (1, 2, 0))
+    elif isinstance(array, torch.Tensor):
+        return torch.permute(array, (1, 2, 0))
+    else:
+        raise TypeError(f"Unsupported array type {type(array)}.")
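`flatten` now preserves the input container type via overloads, and `channels_first_to_last` converts (C, H, W) data for display. A short sketch of both; these helpers live in the private `dataeval.utils._array` module, so the import path is not a stable API:

```python
import numpy as np
import torch

from dataeval.utils._array import channels_first_to_last, flatten

batch_np = np.zeros((8, 3, 32, 32))
batch_pt = torch.zeros((8, 3, 32, 32))
print(flatten(batch_np).shape)  # (8, 3072) - numpy in, numpy out
print(flatten(batch_pt).shape)  # torch.Size([8, 3072]) - tensor in, tensor out

image = torch.zeros((3, 32, 32))  # channels-first (C, H, W)
print(channels_first_to_last(image).shape)  # torch.Size([32, 32, 3])
```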
dataeval/utils/data/_dataset.py CHANGED
@@ -47,13 +47,17 @@ def _validate_data(
         or not len(bboxes[0][0]) == 4
     ):
         raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+    else:
+        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
 
 
 def _find_max(arr: ArrayLike) -> Any:
-    if isinstance(arr[0], (Iterable, Sequence, Array)):
-        return max([_find_max(x) for x in arr])  # type: ignore
-    else:
-        return max(arr)
+    if isinstance(arr, (Iterable, Sequence, Array)):
+        if isinstance(arr[0], (Iterable, Sequence, Array)):
+            return max([_find_max(x) for x in arr])  # type: ignore
+        else:
+            return max(arr)
+    return arr
 
 
 _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
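The reworked `_find_max` checks the container before indexing `arr[0]`, so scalar inputs are returned as-is instead of raising. A behavior sketch, assuming the helper stays at the private path `dataeval/utils/data/_dataset.py` shown in the file list:

```python
from dataeval.utils.data._dataset import _find_max

print(_find_max(5))                 # 5 - scalars now pass through (new)
print(_find_max([3, 1, 2]))         # 3 - flat sequences take max directly
print(_find_max([[0, 4], [2, 9]]))  # 9 - nested sequences recurse
```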
dataeval/utils/data/_embeddings.py CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
 __all__ = []
 
 import math
-from typing import Any, Iterator, Sequence
+from typing import Any, Iterator, Sequence, cast
 
 import torch
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
 from dataeval.config import DeviceLike, get_device
-from dataeval.typing import Array, Dataset
+from dataeval.typing import Array, ArrayLike, Dataset, Transform
+from dataeval.utils._array import as_numpy
 from dataeval.utils.torch.models import SupportsEncode
 
 
@@ -26,11 +28,15 @@ class Embeddings:
         Dataset to access original images from.
     batch_size : int
         Batch size to use when encoding images.
+    transforms : Transform or Sequence[Transform] or None, default None
+        Transforms to apply to images before encoding.
     model : torch.nn.Module or None, default None
         Model to use for encoding images.
     device : DeviceLike or None, default None
         The hardware device to use if specified, otherwise uses the DataEval
         default or torch default.
+    cache : bool, default False
+        Whether to cache the embeddings in memory.
     verbose : bool, default False
         Whether to print progress bar when encoding images.
     """
@@ -41,21 +47,29 @@ class Embeddings:
 
     def __init__(
         self,
-        dataset: Dataset[tuple[Array, Any, Any]],
+        dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike],
         batch_size: int,
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
         model: torch.nn.Module | None = None,
         device: DeviceLike | None = None,
+        cache: bool = False,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
-        self.batch_size = batch_size
+        self.cache = cache
+        self.batch_size = batch_size if batch_size > 0 else 1
         self.verbose = verbose
 
         self._dataset = dataset
+        self._length = len(dataset)
         model = torch.nn.Flatten() if model is None else model
-        self._model = model.to(self.device).eval()
+        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
+        self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
-        self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
+        self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
+        self._cached_idx = set()
+        self._embeddings: torch.Tensor = torch.empty(())
+        self._shallow: bool = False
 
     def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
         """
@@ -79,30 +93,133 @@ class Embeddings:
         else:
             return self[:]
 
-    # Reduce overhead cost by not tracking tensor gradients
-    @torch.no_grad
+    def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
+        """
+        Converts dataset to embeddings as numpy array.
+
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
+
+        Returns
+        -------
+        NDArray[Any]
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        return self.to_tensor(indices).cpu().numpy()
+
+    def new(self, dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike]) -> Embeddings:
+        """
+        Creates a new Embeddings object with the same parameters but a different dataset.
+
+        Parameters
+        ----------
+        dataset : ImageClassificationDataset or ObjectDetectionDataset
+            Dataset to access original images from.
+
+        Returns
+        -------
+        Embeddings
+        """
+        return Embeddings(
+            dataset, self.batch_size, self._transforms, self._model, self.device, self.cache, self.verbose
+        )
+
+    @classmethod
+    def from_array(cls, array: ArrayLike, device: DeviceLike | None = None) -> Embeddings:
+        """
+        Instantiates a shallow Embeddings object using an array.
+
+        Parameters
+        ----------
+        array : ArrayLike
+            The array to convert to embeddings.
+        device : DeviceLike or None, default None
+            The hardware device to use if specified, otherwise uses the DataEval
+            default or torch default.
+
+        Returns
+        -------
+        Embeddings
+
+        Example
+        -------
+        >>> import numpy as np
+        >>> from dataeval.utils.data._embeddings import Embeddings
+        >>> array = np.random.randn(100, 3, 224, 224)
+        >>> embeddings = Embeddings.from_array(array)
+        >>> print(embeddings.to_tensor().shape)
+        torch.Size([100, 3, 224, 224])
+        """
+        embeddings = Embeddings([], 0, None, None, device, True, False)
+        array = array if isinstance(array, Array) else as_numpy(array)
+        embeddings._length = len(array)
+        embeddings._cached_idx = set(range(len(array)))
+        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+        embeddings._shallow = True
+        return embeddings
+
+    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+        if self._transforms:
+            images = [transform(image) for transform in self._transforms for image in images]
+        return self._encoder(torch.stack(images).to(self.device))
+
+    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
     def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
-        # manual batching
-        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)  # type: ignore
-        for i, images in (
-            tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
-            if self.verbose
-            else enumerate(dataloader)
-        ):
-            embeddings = self._encoder(torch.stack(images).to(self.device))
-            yield embeddings
+        dataset = cast(torch.utils.data.Dataset, self._dataset)
+        total_batches = math.ceil(len(indices) / self.batch_size)
+
+        # If not caching, process all indices normally
+        if not self.cache:
+            for images in tqdm(
+                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+                total=total_batches,
+                desc="Batch embedding",
+                disable=not self.verbose,
+            ):
+                yield self._encode(images)
+            return
+
+        # If caching, process each batch of indices at a time, preserving original order
+        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+            batch = indices[i : i + self.batch_size]
+            uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+            if uncached:
+                # Process uncached indices as a single batch
+                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+                    embeddings = self._encode(images)
+
+                    if not self._embeddings.shape:
+                        full_shape = (len(self._dataset), *embeddings.shape[1:])
+                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+                    self._embeddings[uncached] = embeddings
+                    self._cached_idx.update(uncached)
+
+            yield self._embeddings[batch]
 
     def __getitem__(self, key: int | slice, /) -> torch.Tensor:
-        if isinstance(key, slice):
-            return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
-        elif isinstance(key, int):
-            return self._encoder(torch.as_tensor(self._dataset[key][0]).to(self.device))
-        raise TypeError("Invalid argument type.")
+        if not isinstance(key, slice) and not hasattr(key, "__int__"):
+            raise TypeError("Invalid argument type.")
+
+        if self._shallow:
+            if not self._embeddings.shape:
+                raise ValueError("Embeddings not initialized.")
+            return self._embeddings[key]
+
+        indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
+        result = torch.vstack(list(self._batch(indices))).to(self.device)
+        return result.squeeze(0) if len(indices) == 1 else result
 
     def __iter__(self) -> Iterator[torch.Tensor]:
         # process in batches while yielding individual embeddings
-        for batch in self._batch(range(len(self._dataset))):
+        for batch in self._batch(range(self._length)):
            yield from batch
 
     def __len__(self) -> int:
-        return len(self._dataset)
+        return self._length
dataeval/utils/data/_images.py CHANGED
@@ -2,11 +2,15 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
 
-from dataeval.typing import Dataset
+from dataeval.typing import Array, ArrayLike, Dataset
+from dataeval.utils._array import as_numpy, channels_first_to_last
 
-T = TypeVar("T")
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+
+T = TypeVar("T", Array, ArrayLike)
 
 
 class Images(Generic[T]):
@@ -21,7 +25,10 @@ class Images(Generic[T]):
         Dataset to access images from.
     """
 
-    def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
+    def __init__(
+        self,
+        dataset: Dataset[tuple[T, Any, Any] | T],
+    ) -> None:
         self._is_tuple_datum = isinstance(dataset[0], tuple)
         self._dataset = dataset
 
@@ -40,25 +47,41 @@ class Images(Generic[T]):
         """
         return self[:]
 
+    def plot(
+        self,
+        indices: Sequence[int],
+        images_per_row: int = 3,
+        figsize: tuple[int, int] = (10, 10),
+    ) -> Figure:
+        import matplotlib.pyplot as plt
+
+        num_images = len(indices)
+        num_rows = (num_images + images_per_row - 1) // images_per_row
+        fig, axes = plt.subplots(num_rows, images_per_row, figsize=figsize)
+        for i, ax in enumerate(axes.flatten()):
+            image = channels_first_to_last(as_numpy(self[i]))
+            ax.imshow(image)
+            ax.axis("off")
+        plt.tight_layout()
+        return fig
+
     @overload
     def __getitem__(self, key: int, /) -> T: ...
     @overload
     def __getitem__(self, key: slice, /) -> Sequence[T]: ...
 
     def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
+        if isinstance(key, slice):
+            return [self._get_image(k) for k in range(len(self._dataset))[key]]
+        elif hasattr(key, "__int__"):
+            return self._get_image(int(key))
+        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+
+    def _get_image(self, index: int) -> T:
         if self._is_tuple_datum:
-            dataset = cast(Dataset[tuple[T, Any, Any]], self._dataset)
-            if isinstance(key, slice):
-                return [dataset[k][0] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key][0]
+            return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
         else:
-            dataset = cast(Dataset[T], self._dataset)
-            if isinstance(key, slice):
-                return [dataset[k] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key]
-        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+            return cast(Dataset[T], self._dataset)[index]
 
     def __iter__(self) -> Iterator[T]:
         for i in range(len(self._dataset)):
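The new `Images.plot` imports matplotlib lazily and routes each image through `channels_first_to_last` before display. A sketch assuming matplotlib is installed and a bare-array dataset of (C, H, W) images; the random data is illustrative:

```python
import numpy as np

from dataeval.utils.data._images import Images

dataset = [np.random.rand(3, 16, 16) for _ in range(6)]  # six (C, H, W) images
images = Images(dataset)
fig = images.plot(indices=list(range(6)), images_per_row=3)
fig.savefig("image_grid.png")  # or fig.show() in an interactive session
```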
dataeval/utils/data/_metadata.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -208,8 +208,9 @@ class Metadata:
             raw.append(metadata)
 
             if is_od_target := isinstance(target, ObjectDetectionTarget):
-                target_len = len(target.labels)
-                labels.extend(as_numpy(target.labels).tolist())
+                target_labels = as_numpy(target.labels)
+                target_len = len(target_labels)
+                labels.extend(target_labels.tolist())
                 bboxes.extend(as_numpy(target.boxes).tolist())
                 scores.extend(as_numpy(target.scores).tolist())
                 srcidx.extend([i] * target_len)
@@ -360,7 +361,7 @@ class Metadata:
         self._merge()
         self._processed = False
         target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
-        if any(len(v) != target_len for v in factors.values()):
+        if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
             raise ValueError(
                 "The lists/arrays in the provided factors have a different length than the current metadata factors."
             )
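The `Sized` guard means factor values are only coerced through `as_numpy` when they expose no `__len__` of their own. A standalone sketch of that validation pattern, not the library's internal code:

```python
from typing import Any, Sized

import numpy as np


def factor_length(value: Any) -> int:
    # Sized values report their own length; others are coerced to an array first.
    return len(value) if isinstance(value, Sized) else len(np.asarray(value))


target_len = 3
factors = {"altitude": [10, 20, 30], "angle": np.array([0.1, 0.2, 0.3])}
assert all(factor_length(v) == target_len for v in factors.values())
```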
dataeval/utils/data/_selection.py CHANGED
@@ -5,7 +5,7 @@ __all__ = []
 from enum import IntEnum
 from typing import Generic, Iterator, Sequence, TypeVar
 
-from dataeval.typing import AnnotatedDataset, DatasetMetadata, Transform
+from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
 _TDatum = TypeVar("_TDatum")
 
@@ -35,8 +35,6 @@ class Select(AnnotatedDataset[_TDatum]):
         The dataset to wrap.
     selections : Selection or list[Selection], optional
         The selection criteria to apply to the dataset.
-    transforms : Transform or list[Transform], optional
-        The transforms to apply to the dataset.
 
     Examples
     --------
@@ -70,16 +68,12 @@ class Select(AnnotatedDataset[_TDatum]):
         self,
         dataset: AnnotatedDataset[_TDatum],
         selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
-        transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
     ) -> None:
         self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
         self._selections = self._sort(selections)
-        self._transforms = (
-            [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
-        )
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -98,8 +92,7 @@ class Select(AnnotatedDataset[_TDatum]):
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
         selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
-        transforms = f"Transforms: [{', '.join([str(t) for t in self._transforms])}]"
-        return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
+        return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
     def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
@@ -117,13 +110,8 @@ class Select(AnnotatedDataset[_TDatum]):
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
-    def _transform(self, datum: _TDatum) -> _TDatum:
-        for t in self._transforms:
-            datum = t(datum)
-        return datum
-
     def __getitem__(self, index: int) -> _TDatum:
-        return self._transform(self._dataset[self._selection[index]])
+        return self._dataset[self._selection[index]]
 
     def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):
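With `transforms` gone from `Select`, selection and transformation are decoupled: `Select` only subsets a dataset, and transforms now ride with `Embeddings` (see its expanded constructor above). A migration sketch; `TinyDataset` and the lambda are illustrative stand-ins, and the private import paths follow the file list:

```python
import numpy as np
import torch

from dataeval.utils.data._embeddings import Embeddings
from dataeval.utils.data._selection import Select


class TinyDataset:
    """Illustrative annotated dataset yielding (image, label, metadata) tuples."""

    metadata = {"id": "tiny"}

    def __init__(self) -> None:
        self._data = [(np.random.rand(3, 8, 8).astype(np.float32), 0, {}) for _ in range(4)]

    def __getitem__(self, index: int) -> tuple:
        return self._data[index]

    def __len__(self) -> int:
        return len(self._data)


# 0.83.0: Select(dataset, selections, transforms=...) applied transforms per datum.
# 0.84.1: Select only selects; transforms are passed to Embeddings instead.
selected = Select(TinyDataset())
embeddings = Embeddings(selected, batch_size=2, transforms=lambda t: t * 2.0)
print(embeddings.to_tensor().shape)  # torch.Size([4, 192])
```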