dataeval 0.84.0__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. dataeval/__init__.py +1 -1
  2. dataeval/detectors/drift/__init__.py +2 -2
  3. dataeval/detectors/drift/_base.py +55 -203
  4. dataeval/detectors/drift/_cvm.py +19 -30
  5. dataeval/detectors/drift/_ks.py +18 -30
  6. dataeval/detectors/drift/_mmd.py +189 -53
  7. dataeval/detectors/drift/_uncertainty.py +52 -56
  8. dataeval/detectors/drift/updates.py +13 -12
  9. dataeval/detectors/linters/duplicates.py +5 -3
  10. dataeval/detectors/linters/outliers.py +2 -2
  11. dataeval/detectors/ood/ae.py +1 -1
  12. dataeval/metrics/stats/_base.py +7 -7
  13. dataeval/metrics/stats/_dimensionstats.py +2 -2
  14. dataeval/metrics/stats/_hashstats.py +2 -2
  15. dataeval/metrics/stats/_imagestats.py +4 -4
  16. dataeval/metrics/stats/_pixelstats.py +2 -2
  17. dataeval/metrics/stats/_visualstats.py +2 -2
  18. dataeval/typing.py +22 -19
  19. dataeval/utils/_array.py +18 -7
  20. dataeval/utils/data/_dataset.py +6 -4
  21. dataeval/utils/data/_embeddings.py +46 -7
  22. dataeval/utils/data/_images.py +2 -2
  23. dataeval/utils/data/_metadata.py +5 -4
  24. dataeval/utils/data/datasets/_base.py +7 -4
  25. dataeval/utils/data/datasets/_cifar10.py +9 -9
  26. dataeval/utils/data/datasets/_milco.py +42 -14
  27. dataeval/utils/data/datasets/_mnist.py +9 -5
  28. dataeval/utils/data/datasets/_ships.py +8 -4
  29. dataeval/utils/data/datasets/_voc.py +40 -19
  30. dataeval/utils/data/selections/__init__.py +2 -0
  31. dataeval/utils/data/selections/_classbalance.py +38 -0
  32. dataeval/utils/data/selections/_classfilter.py +14 -29
  33. dataeval/utils/data/selections/_prioritize.py +1 -1
  34. dataeval/utils/data/selections/_shuffle.py +2 -2
  35. dataeval/utils/torch/_internal.py +12 -35
  36. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
  37. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +39 -39
  38. dataeval/detectors/drift/_torch.py +0 -222
  39. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
  40. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
@@ -10,7 +10,7 @@ from copy import deepcopy
10
10
  from dataclasses import dataclass
11
11
  from functools import partial
12
12
  from multiprocessing import Pool
13
- from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
13
+ from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
14
14
 
15
15
  import numpy as np
16
16
  import tqdm
@@ -19,7 +19,7 @@ from numpy.typing import NDArray
19
19
  from dataeval.config import get_max_processes
20
20
  from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
21
21
  from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
22
- from dataeval.utils._array import to_numpy
22
+ from dataeval.utils._array import as_numpy, to_numpy
23
23
  from dataeval.utils._image import normalize_image_shape, rescale
24
24
 
25
25
  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
@@ -138,19 +138,19 @@ def process_stats(
138
138
 
139
139
 
140
140
  def process_stats_unpack(
141
- args: tuple[int, Array, list[BoundingBox] | None],
141
+ args: tuple[int, ArrayLike, list[BoundingBox] | None],
142
142
  per_channel: bool,
143
143
  stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
144
144
  ) -> StatsProcessorOutput:
145
145
  return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
146
146
 
147
147
 
148
- def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
148
+ def _enumerate(dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool):
149
149
  for i in range(len(dataset)):
150
150
  d = dataset[i]
151
151
  image = d[0] if isinstance(d, tuple) else d
152
152
  if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
153
- boxes = cast(Array, d[1].boxes)
153
+ boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
154
154
  target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
155
155
  else:
156
156
  target = None
@@ -159,7 +159,7 @@ def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_bo
159
159
 
160
160
 
161
161
  def run_stats(
162
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
162
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
163
163
  per_box: bool,
164
164
  per_channel: bool,
165
165
  stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
@@ -173,7 +173,7 @@ def run_stats(
173
173
 
174
174
  Parameters
175
175
  ----------
176
- data : Dataset[Array] | Dataset[tuple[Array, Any, Any]]
176
+ data : Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
177
177
  A dataset of images and targets to compute statistics on.
178
178
  per_box : bool
179
179
  A flag which determines if the statistics should be evaluated on a per-box basis or not.
@@ -9,7 +9,7 @@ import numpy as np
9
9
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
10
10
  from dataeval.outputs import DimensionStatsOutput
11
11
  from dataeval.outputs._base import set_metadata
12
- from dataeval.typing import Array, Dataset
12
+ from dataeval.typing import ArrayLike, Dataset
13
13
  from dataeval.utils._image import get_bitdepth
14
14
 
15
15
 
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
34
34
 
35
35
  @set_metadata
36
36
  def dimensionstats(
37
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
37
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
38
38
  *,
39
39
  per_box: bool = False,
40
40
  ) -> DimensionStatsOutput:
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
14
14
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
15
15
  from dataeval.outputs import HashStatsOutput
16
16
  from dataeval.outputs._base import set_metadata
17
- from dataeval.typing import Array, ArrayLike, Dataset
17
+ from dataeval.typing import ArrayLike, Dataset
18
18
  from dataeval.utils._array import as_numpy
19
19
  from dataeval.utils._image import normalize_image_shape, rescale
20
20
 
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
105
105
 
106
106
  @set_metadata
107
107
  def hashstats(
108
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
108
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
109
109
  *,
110
110
  per_box: bool = False,
111
111
  ) -> HashStatsOutput:
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
10
10
  from dataeval.metrics.stats._visualstats import VisualStatsProcessor
11
11
  from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
12
12
  from dataeval.outputs._base import set_metadata
13
- from dataeval.typing import Array, Dataset
13
+ from dataeval.typing import ArrayLike, Dataset
14
14
 
15
15
 
16
16
  @overload
17
17
  def imagestats(
18
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
18
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
19
19
  *,
20
20
  per_box: bool = False,
21
21
  per_channel: Literal[True],
@@ -24,7 +24,7 @@ def imagestats(
24
24
 
25
25
  @overload
26
26
  def imagestats(
27
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
27
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
28
28
  *,
29
29
  per_box: bool = False,
30
30
  per_channel: Literal[False] = False,
@@ -33,7 +33,7 @@ def imagestats(
33
33
 
34
34
  @set_metadata
35
35
  def imagestats(
36
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
36
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
37
37
  *,
38
38
  per_box: bool = False,
39
39
  per_channel: bool = False,
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
10
10
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
11
11
  from dataeval.outputs import PixelStatsOutput
12
12
  from dataeval.outputs._base import set_metadata
13
- from dataeval.typing import Array, Dataset
13
+ from dataeval.typing import ArrayLike, Dataset
14
14
 
15
15
 
16
16
  class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
37
37
 
38
38
  @set_metadata
39
39
  def pixelstats(
40
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
40
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
41
41
  *,
42
42
  per_box: bool = False,
43
43
  per_channel: bool = False,
@@ -9,7 +9,7 @@ import numpy as np
9
9
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
10
10
  from dataeval.outputs import VisualStatsOutput
11
11
  from dataeval.outputs._base import set_metadata
12
- from dataeval.typing import Array, Dataset
12
+ from dataeval.typing import ArrayLike, Dataset
13
13
  from dataeval.utils._image import edge_filter
14
14
 
15
15
  QUARTILES = (0, 25, 50, 75, 100)
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
44
44
 
45
45
  @set_metadata
46
46
  def visualstats(
47
- dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
47
+ dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
48
48
  *,
49
49
  per_box: bool = False,
50
50
  per_channel: bool = False,
dataeval/typing.py CHANGED
@@ -21,8 +21,9 @@ __all__ = [
21
21
 
22
22
 
23
23
  import sys
24
- from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
24
+ from typing import Any, Generic, Iterator, Protocol, TypedDict, TypeVar, runtime_checkable
25
25
 
26
+ import numpy.typing
26
27
  from typing_extensions import NotRequired, ReadOnly, Required
27
28
 
28
29
  if sys.version_info >= (3, 10):
@@ -31,6 +32,16 @@ else:
31
32
  from typing_extensions import TypeAlias
32
33
 
33
34
 
35
+ ArrayLike: TypeAlias = numpy.typing.ArrayLike
36
+ """
37
+ Type alias for a `Union` representing objects that can be coerced into an array.
38
+
39
+ See Also
40
+ --------
41
+ `NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
42
+ """
43
+
44
+
34
45
  @runtime_checkable
35
46
  class Array(Protocol):
36
47
  """
@@ -67,16 +78,8 @@ class Array(Protocol):
67
78
  def __len__(self) -> int: ...
68
79
 
69
80
 
70
- T = TypeVar("T")
81
+ _T = TypeVar("_T")
71
82
  _T_co = TypeVar("_T_co", covariant=True)
72
- _ScalarType = Union[int, float, bool, str]
73
- ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
74
- """
75
- Type alias for array-like objects used for interoperability with DataEval.
76
-
77
- This includes native Python sequences, as well as objects that conform to
78
- the :class:`Array` protocol.
79
- """
80
83
 
81
84
 
82
85
  class DatasetMetadata(TypedDict, total=False):
@@ -140,12 +143,12 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
140
143
  # ========== IMAGE CLASSIFICATION DATASETS ==========
141
144
 
142
145
 
143
- ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
146
+ ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, dict[str, Any]]
144
147
  """
145
148
  Type alias for an image classification datum tuple.
146
149
 
147
- - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
148
- - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
150
+ - :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
151
+ - :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
149
152
  - dict[str, Any] - Datum level metadata.
150
153
  """
151
154
 
@@ -180,11 +183,11 @@ class ObjectDetectionTarget(Protocol):
180
183
  def scores(self) -> ArrayLike: ...
181
184
 
182
185
 
183
- ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
186
+ ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, dict[str, Any]]
184
187
  """
185
188
  Type alias for an object detection datum tuple.
186
189
 
187
- - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
190
+ - :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
188
191
  - :class:`ObjectDetectionTarget` - Object detection target information for the image.
189
192
  - dict[str, Any] - Datum level metadata.
190
193
  """
@@ -221,11 +224,11 @@ class SegmentationTarget(Protocol):
221
224
  def scores(self) -> ArrayLike: ...
222
225
 
223
226
 
224
- SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
227
+ SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, dict[str, Any]]
225
228
  """
226
229
  Type alias for an image classification datum tuple.
227
230
 
228
- - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
231
+ - :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
229
232
  - :class:`SegmentationTarget` - Segmentation target information for the image.
230
233
  - dict[str, Any] - Datum level metadata.
231
234
  """
@@ -237,7 +240,7 @@ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elemen
237
240
 
238
241
 
239
242
  @runtime_checkable
240
- class Transform(Generic[T], Protocol):
243
+ class Transform(Generic[_T], Protocol):
241
244
  """
242
245
  Protocol defining a transform function.
243
246
 
@@ -262,4 +265,4 @@ class Transform(Generic[T], Protocol):
262
265
  array([0.004, 0.008, 0.012])
263
266
  """
264
267
 
265
- def __call__(self, data: T, /) -> T: ...
268
+ def __call__(self, data: _T, /) -> _T: ...
dataeval/utils/_array.py CHANGED
@@ -92,7 +92,7 @@ def ensure_embeddings(
92
92
  @overload
93
93
  def ensure_embeddings(
94
94
  embeddings: T,
95
- dtype: None,
95
+ dtype: None = None,
96
96
  unit_interval: Literal[True, False, "force"] = False,
97
97
  ) -> T: ...
98
98
 
@@ -152,21 +152,32 @@ def ensure_embeddings(
152
152
  return arr
153
153
 
154
154
 
155
- def flatten(array: ArrayLike) -> NDArray[Any]:
155
+ @overload
156
+ def flatten(array: torch.Tensor) -> torch.Tensor: ...
157
+ @overload
158
+ def flatten(array: ArrayLike) -> NDArray[Any]: ...
159
+
160
+
161
+ def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
156
162
  """
157
163
  Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
158
164
 
159
165
  Parameters
160
166
  ----------
161
- X : NDArray, shape - (N, ... )
167
+ array : ArrayLike
162
168
  Input array
163
169
 
164
170
  Returns
165
171
  -------
166
- NDArray, shape - (N, -1)
172
+ np.ndarray or torch.Tensor, shape: (N, -1)
167
173
  """
168
- nparr = as_numpy(array)
169
- return nparr.reshape((nparr.shape[0], -1))
174
+ if isinstance(array, np.ndarray):
175
+ nparr = as_numpy(array)
176
+ return nparr.reshape((nparr.shape[0], -1))
177
+ elif isinstance(array, torch.Tensor):
178
+ return torch.flatten(array, start_dim=1)
179
+ else:
180
+ raise TypeError(f"Unsupported array type {type(array)}.")
170
181
 
171
182
 
172
183
  _TArray = TypeVar("_TArray", bound=Array)
@@ -191,4 +202,4 @@ def channels_first_to_last(array: _TArray) -> _TArray:
191
202
  elif isinstance(array, torch.Tensor):
192
203
  return torch.permute(array, (1, 2, 0))
193
204
  else:
194
- raise TypeError(f"Unsupported array type {type(array)} for conversion.")
205
+ raise TypeError(f"Unsupported array type {type(array)}.")
@@ -52,10 +52,12 @@ def _validate_data(
52
52
 
53
53
 
54
54
  def _find_max(arr: ArrayLike) -> Any:
55
- if isinstance(arr[0], (Iterable, Sequence, Array)):
56
- return max([_find_max(x) for x in arr]) # type: ignore
57
- else:
58
- return max(arr)
55
+ if isinstance(arr, (Iterable, Sequence, Array)):
56
+ if isinstance(arr[0], (Iterable, Sequence, Array)):
57
+ return max([_find_max(x) for x in arr]) # type: ignore
58
+ else:
59
+ return max(arr)
60
+ return arr
59
61
 
60
62
 
61
63
  _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
@@ -6,11 +6,13 @@ import math
6
6
  from typing import Any, Iterator, Sequence, cast
7
7
 
8
8
  import torch
9
+ from numpy.typing import NDArray
9
10
  from torch.utils.data import DataLoader, Subset
10
11
  from tqdm import tqdm
11
12
 
12
13
  from dataeval.config import DeviceLike, get_device
13
- from dataeval.typing import Array, Dataset, Transform
14
+ from dataeval.typing import Array, ArrayLike, Dataset, Transform
15
+ from dataeval.utils._array import as_numpy
14
16
  from dataeval.utils.torch.models import SupportsEncode
15
17
 
16
18
 
@@ -45,7 +47,7 @@ class Embeddings:
45
47
 
46
48
  def __init__(
47
49
  self,
48
- dataset: Dataset[tuple[Array, Any, Any]],
50
+ dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike],
49
51
  batch_size: int,
50
52
  transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
51
53
  model: torch.nn.Module | None = None,
@@ -62,9 +64,9 @@ class Embeddings:
62
64
  self._length = len(dataset)
63
65
  model = torch.nn.Flatten() if model is None else model
64
66
  self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
65
- self._model = model.to(self.device).eval()
67
+ self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
66
68
  self._encoder = model.encode if isinstance(model, SupportsEncode) else model
67
- self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
69
+ self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
68
70
  self._cached_idx = set()
69
71
  self._embeddings: torch.Tensor = torch.empty(())
70
72
  self._shallow: bool = False
@@ -91,14 +93,50 @@ class Embeddings:
91
93
  else:
92
94
  return self[:]
93
95
 
96
+ def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
97
+ """
98
+ Converts dataset to embeddings as numpy array.
99
+
100
+ Parameters
101
+ ----------
102
+ indices : Sequence[int] or None, default None
103
+ The indices to convert to embeddings
104
+
105
+ Returns
106
+ -------
107
+ NDArray[Any]
108
+
109
+ Warning
110
+ -------
111
+ Processing large quantities of data can be resource intensive.
112
+ """
113
+ return self.to_tensor(indices).cpu().numpy()
114
+
115
+ def new(self, dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike]) -> Embeddings:
116
+ """
117
+ Creates a new Embeddings object with the same parameters but a different dataset.
118
+
119
+ Parameters
120
+ ----------
121
+ dataset : ImageClassificationDataset or ObjectDetectionDataset
122
+ Dataset to access original images from.
123
+
124
+ Returns
125
+ -------
126
+ Embeddings
127
+ """
128
+ return Embeddings(
129
+ dataset, self.batch_size, self._transforms, self._model, self.device, self.cache, self.verbose
130
+ )
131
+
94
132
  @classmethod
95
- def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
133
+ def from_array(cls, array: ArrayLike, device: DeviceLike | None = None) -> Embeddings:
96
134
  """
97
135
  Instantiates a shallow Embeddings object using an array.
98
136
 
99
137
  Parameters
100
138
  ----------
101
- array : Array
139
+ array : ArrayLike
102
140
  The array to convert to embeddings.
103
141
  device : DeviceLike or None, default None
104
142
  The hardware device to use if specified, otherwise uses the DataEval
@@ -118,6 +156,7 @@ class Embeddings:
118
156
  torch.Size([100, 3, 224, 224])
119
157
  """
120
158
  embeddings = Embeddings([], 0, None, None, device, True, False)
159
+ array = array if isinstance(array, Array) else as_numpy(array)
121
160
  embeddings._length = len(array)
122
161
  embeddings._cached_idx = set(range(len(array)))
123
162
  embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
@@ -131,7 +170,7 @@ class Embeddings:
131
170
 
132
171
  @torch.no_grad() # Reduce overhead cost by not tracking tensor gradients
133
172
  def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
134
- dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
173
+ dataset = cast(torch.utils.data.Dataset, self._dataset)
135
174
  total_batches = math.ceil(len(indices) / self.batch_size)
136
175
 
137
176
  # If not caching, process all indices normally
@@ -4,13 +4,13 @@ __all__ = []
4
4
 
5
5
  from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
6
6
 
7
- from dataeval.typing import Array, Dataset
7
+ from dataeval.typing import Array, ArrayLike, Dataset
8
8
  from dataeval.utils._array import as_numpy, channels_first_to_last
9
9
 
10
10
  if TYPE_CHECKING:
11
11
  from matplotlib.figure import Figure
12
12
 
13
- T = TypeVar("T", bound=Array)
13
+ T = TypeVar("T", Array, ArrayLike)
14
14
 
15
15
 
16
16
  class Images(Generic[T]):
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import warnings
6
- from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
6
+ from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
@@ -208,8 +208,9 @@ class Metadata:
208
208
  raw.append(metadata)
209
209
 
210
210
  if is_od_target := isinstance(target, ObjectDetectionTarget):
211
- target_len = len(target.labels)
212
- labels.extend(as_numpy(target.labels).tolist())
211
+ target_labels = as_numpy(target.labels)
212
+ target_len = len(target_labels)
213
+ labels.extend(target_labels.tolist())
213
214
  bboxes.extend(as_numpy(target.boxes).tolist())
214
215
  scores.extend(as_numpy(target.scores).tolist())
215
216
  srcidx.extend([i] * target_len)
@@ -360,7 +361,7 @@ class Metadata:
360
361
  self._merge()
361
362
  self._processed = False
362
363
  target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
363
- if any(len(v) != target_len for v in factors.values()):
364
+ if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
364
365
  raise ValueError(
365
366
  "The lists/arrays in the provided factors have a different length than the current metadata factors."
366
367
  )
@@ -19,9 +19,12 @@ from dataeval.utils.data.datasets._types import (
19
19
  )
20
20
 
21
21
  if TYPE_CHECKING:
22
- from dataeval.typing import Transform
22
+ from dataeval.typing import Array, Transform
23
+
24
+ _TArray = TypeVar("_TArray", bound=Array)
25
+ else:
26
+ _TArray = TypeVar("_TArray")
23
27
 
24
- _TArray = TypeVar("_TArray")
25
28
  _TTarget = TypeVar("_TTarget")
26
29
  _TRawTarget = TypeVar("_TRawTarget", list[int], list[str])
27
30
 
@@ -51,9 +54,9 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Ge
51
54
  def __init__(
52
55
  self,
53
56
  root: str | Path,
54
- download: bool = False,
55
- image_set: Literal["train", "val", "test", "base"] = "train",
57
+ image_set: Literal["train", "val", "test", "operational", "base"] = "train",
56
58
  transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
59
+ download: bool = False,
57
60
  verbose: bool = False,
58
61
  ) -> None:
59
62
  self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
@@ -27,13 +27,13 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
27
27
  ----------
28
28
  root : str or pathlib.Path
29
29
  Root directory of dataset where the ``mnist`` folder exists.
30
- download : bool, default False
31
- If True, downloads the dataset from the internet and puts it in root directory.
32
- Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
33
30
  image_set : "train", "test" or "base", default "train"
34
31
  If "base", returns all of the data to allow the user to create their own splits.
35
32
  transforms : Transform, Sequence[Transform] or None, default None
36
33
  Transform(s) to apply to the data.
34
+ download : bool, default False
35
+ If True, downloads the dataset from the internet and puts it in root directory.
36
+ Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
37
37
  verbose : bool, default False
38
38
  If True, outputs print statements.
39
39
 
@@ -43,16 +43,16 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
43
43
  Location of the folder containing the data.
44
44
  image_set : "train", "test" or "base"
45
45
  The selected image set from the dataset.
46
+ transforms : Sequence[Transform]
47
+ The transforms to be applied to the data.
48
+ size : int
49
+ The size of the dataset.
46
50
  index2label : dict[int, str]
47
51
  Dictionary which translates from class integers to the associated class strings.
48
52
  label2index : dict[str, int]
49
53
  Dictionary which translates from class strings to the associated class integers.
50
54
  metadata : DatasetMetadata
51
55
  Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
52
- transforms : Sequence[Transform]
53
- The transforms to be applied to the data.
54
- size : int
55
- The size of the dataset.
56
56
  """
57
57
 
58
58
  _resources = [
@@ -80,16 +80,16 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
80
80
  def __init__(
81
81
  self,
82
82
  root: str | Path,
83
- download: bool = False,
84
83
  image_set: Literal["train", "test", "base"] = "train",
85
84
  transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
85
+ download: bool = False,
86
86
  verbose: bool = False,
87
87
  ) -> None:
88
88
  super().__init__(
89
89
  root,
90
- download,
91
90
  image_set,
92
91
  transforms,
92
+ download,
93
93
  verbose,
94
94
  )
95
95