dataeval 0.83.0__py3-none-any.whl → 0.84.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +3 -3
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +55 -203
- dataeval/detectors/drift/_cvm.py +19 -30
- dataeval/detectors/drift/_ks.py +18 -30
- dataeval/detectors/drift/_mmd.py +189 -53
- dataeval/detectors/drift/_uncertainty.py +52 -56
- dataeval/detectors/drift/updates.py +13 -12
- dataeval/detectors/linters/duplicates.py +5 -3
- dataeval/detectors/linters/outliers.py +2 -2
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/metrics/bias/__init__.py +11 -1
- dataeval/metrics/bias/_completeness.py +130 -0
- dataeval/metrics/stats/_base.py +28 -32
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +4 -45
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_bias.py +31 -22
- dataeval/outputs/_stats.py +2 -3
- dataeval/typing.py +25 -22
- dataeval/utils/_array.py +43 -7
- dataeval/utils/data/_dataset.py +8 -4
- dataeval/utils/data/_embeddings.py +141 -24
- dataeval/utils/data/_images.py +38 -15
- dataeval/utils/data/_metadata.py +5 -4
- dataeval/utils/data/_selection.py +3 -15
- dataeval/utils/data/_split.py +76 -129
- dataeval/utils/data/datasets/_base.py +7 -4
- dataeval/utils/data/datasets/_cifar10.py +9 -9
- dataeval/utils/data/datasets/_milco.py +42 -14
- dataeval/utils/data/datasets/_mnist.py +9 -5
- dataeval/utils/data/datasets/_ships.py +8 -4
- dataeval/utils/data/datasets/_voc.py +40 -19
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classbalance.py +38 -0
- dataeval/utils/data/selections/_classfilter.py +14 -29
- dataeval/utils/data/selections/_prioritize.py +1 -1
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/metadata.py +1 -1
- dataeval/utils/torch/_internal.py +12 -35
- {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
- {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +49 -48
- dataeval/detectors/drift/_torch.py +0 -222
- {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
dataeval/typing.py
CHANGED
@@ -21,9 +21,10 @@ __all__ = [


 import sys
-from typing import Any, Generic, Iterator, Protocol,
+from typing import Any, Generic, Iterator, Protocol, TypedDict, TypeVar, runtime_checkable

-
+import numpy.typing
+from typing_extensions import NotRequired, ReadOnly, Required

 if sys.version_info >= (3, 10):
     from typing import TypeAlias
@@ -31,6 +32,16 @@ else:
     from typing_extensions import TypeAlias


+ArrayLike: TypeAlias = numpy.typing.ArrayLike
+"""
+Type alias for a `Union` representing objects that can be coerced into an array.
+
+See Also
+--------
+`NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
+"""
+
+
 @runtime_checkable
 class Array(Protocol):
     """
@@ -67,16 +78,8 @@ class Array(Protocol):
     def __len__(self) -> int: ...


-
+_T = TypeVar("_T")
 _T_co = TypeVar("_T_co", covariant=True)
-_ScalarType = Union[int, float, bool, str]
-ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
-"""
-Type alias for array-like objects used for interoperability with DataEval.
-
-This includes native Python sequences, as well as objects that conform to
-the :class:`Array` protocol.
-"""


 class DatasetMetadata(TypedDict, total=False):
@@ -91,8 +94,8 @@ class DatasetMetadata(TypedDict, total=False):
         A lookup table converting label value to class name
     """

-    id: Required[str]
-    index2label: NotRequired[dict[int, str]]
+    id: Required[ReadOnly[str]]
+    index2label: NotRequired[ReadOnly[dict[int, str]]]


 @runtime_checkable
@@ -140,12 +143,12 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
 # ========== IMAGE CLASSIFICATION DATASETS ==========


-ImageClassificationDatum: TypeAlias = tuple[
+ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, dict[str, Any]]
 """
 Type alias for an image classification datum tuple.

-- :class:`
-- :class:`
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
+- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
 - dict[str, Any] - Datum level metadata.
 """

@@ -180,11 +183,11 @@ class ObjectDetectionTarget(Protocol):
     def scores(self) -> ArrayLike: ...


-ObjectDetectionDatum: TypeAlias = tuple[
+ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, dict[str, Any]]
 """
 Type alias for an object detection datum tuple.

-- :class:`
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`ObjectDetectionTarget` - Object detection target information for the image.
 - dict[str, Any] - Datum level metadata.
 """
@@ -221,11 +224,11 @@ class SegmentationTarget(Protocol):
     def scores(self) -> ArrayLike: ...


-SegmentationDatum: TypeAlias = tuple[
+SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, dict[str, Any]]
 """
 Type alias for an image classification datum tuple.

-- :class:`
+- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`SegmentationTarget` - Segmentation target information for the image.
 - dict[str, Any] - Datum level metadata.
 """
@@ -237,7 +240,7 @@ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.


 @runtime_checkable
-class Transform(Generic[
+class Transform(Generic[_T], Protocol):
     """
     Protocol defining a transform function.

@@ -262,4 +265,4 @@ class Transform(Generic[T], Protocol):
     array([0.004, 0.008, 0.012])
     """

-    def __call__(self, data:
+    def __call__(self, data: _T, /) -> _T: ...
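Note: with this release `ArrayLike` is simply `numpy.typing.ArrayLike`, and `Transform` is a runtime-checkable protocol whose only member is `__call__(self, data: _T, /) -> _T`. Below is a minimal sketch (not taken from the package) of a transform satisfying the updated protocol; the `DivideTransform` name is illustrative, and the expected output mirrors the `array([0.004, 0.008, 0.012])` example visible in the docstring context above.

```python
import numpy as np
from numpy.typing import NDArray

from dataeval.typing import Transform  # protocol: __call__(self, data: _T, /) -> _T


class DivideTransform:
    """Illustrative transform: scales pixel values into the unit interval."""

    def __call__(self, data: NDArray[np.float64], /) -> NDArray[np.float64]:
        return data / 255.0


transform: Transform[NDArray[np.float64]] = DivideTransform()
print(isinstance(transform, Transform))      # True (runtime_checkable protocol)
print(transform(np.array([1.0, 2.0, 3.0])))  # approx. [0.0039 0.0078 0.0118]
```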
dataeval/utils/_array.py
CHANGED
@@ -13,7 +13,7 @@ import torch
 from numpy.typing import NDArray

 from dataeval._log import LogMessage
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array, ArrayLike

 _logger = logging.getLogger(__name__)

@@ -92,7 +92,7 @@ def ensure_embeddings(
 @overload
 def ensure_embeddings(
     embeddings: T,
-    dtype: None,
+    dtype: None = None,
     unit_interval: Literal[True, False, "force"] = False,
 ) -> T: ...

@@ -152,18 +152,54 @@ def ensure_embeddings(
     return arr


-
+@overload
+def flatten(array: torch.Tensor) -> torch.Tensor: ...
+@overload
+def flatten(array: ArrayLike) -> NDArray[Any]: ...
+
+
+def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
     """
     Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension

     Parameters
     ----------
-
+    array : ArrayLike
         Input array

     Returns
     -------
-
+    np.ndarray or torch.Tensor, shape: (N, -1)
+    """
+    if isinstance(array, np.ndarray):
+        nparr = as_numpy(array)
+        return nparr.reshape((nparr.shape[0], -1))
+    elif isinstance(array, torch.Tensor):
+        return torch.flatten(array, start_dim=1)
+    else:
+        raise TypeError(f"Unsupported array type {type(array)}.")
+
+
+_TArray = TypeVar("_TArray", bound=Array)
+
+
+def channels_first_to_last(array: _TArray) -> _TArray:
     """
-
-
+    Converts array from channels first to channels last format
+
+    Parameters
+    ----------
+    array : ArrayLike
+        Input array
+
+    Returns
+    -------
+    ArrayLike
+        Converted array
+    """
+    if isinstance(array, np.ndarray):
+        return np.transpose(array, (1, 2, 0))
+    elif isinstance(array, torch.Tensor):
+        return torch.permute(array, (1, 2, 0))
+    else:
+        raise TypeError(f"Unsupported array type {type(array)}.")
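Note: the two helpers added above live in an internal module (`dataeval.utils._array`), so the import path below may not be a stable public surface. A small usage sketch of the behavior shown in the diff — `flatten` reshapes `(N, ...)` to `(N, -1)` and `channels_first_to_last` reorders `(C, H, W)` to `(H, W, C)`, preserving numpy vs. torch types:

```python
import numpy as np
import torch

from dataeval.utils._array import channels_first_to_last, flatten  # internal module

images_np = np.zeros((8, 3, 32, 32))   # (N, C, H, W) numpy batch
images_pt = torch.zeros(8, 3, 32, 32)  # (N, C, H, W) torch batch

print(flatten(images_np).shape)                    # (8, 3072)
print(flatten(images_pt).shape)                    # torch.Size([8, 3072])
print(channels_first_to_last(images_np[0]).shape)  # (32, 32, 3)
print(channels_first_to_last(images_pt[0]).shape)  # torch.Size([32, 32, 3])
```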
dataeval/utils/data/_dataset.py
CHANGED
@@ -47,13 +47,17 @@ def _validate_data(
             or not len(bboxes[0][0]) == 4
         ):
             raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+    else:
+        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")


 def _find_max(arr: ArrayLike) -> Any:
-    if isinstance(arr
-
-
-
+    if isinstance(arr, (Iterable, Sequence, Array)):
+        if isinstance(arr[0], (Iterable, Sequence, Array)):
+            return max([_find_max(x) for x in arr])  # type: ignore
+        else:
+            return max(arr)
+    return arr


 _TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
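Note: a standalone sketch (simplified, hypothetical names) of the recursive maximum that the rewritten `_find_max` helper implements — it walks flat or nested label sequences and returns the largest scalar, which lets the same validation path handle image-classification and object-detection labels:

```python
from collections.abc import Sequence
from typing import Any


def find_max(arr: Any) -> Any:
    """Simplified stand-in for _find_max: recursive max over possibly nested sequences."""
    if isinstance(arr, Sequence) and not isinstance(arr, str):
        if arr and isinstance(arr[0], Sequence) and not isinstance(arr[0], str):
            return max(find_max(x) for x in arr)  # nested, e.g. per-image OD labels
        return max(arr)                           # flat, e.g. IC labels
    return arr                                    # already a scalar


print(find_max([3, 1, 2]))              # 3
print(find_max([[0, 2], [5, 1], [4]]))  # 5
```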
dataeval/utils/data/_embeddings.py
CHANGED
@@ -3,14 +3,16 @@ from __future__ import annotations
 __all__ = []

 import math
-from typing import Any, Iterator, Sequence
+from typing import Any, Iterator, Sequence, cast

 import torch
+from numpy.typing import NDArray
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm

 from dataeval.config import DeviceLike, get_device
-from dataeval.typing import Array, Dataset
+from dataeval.typing import Array, ArrayLike, Dataset, Transform
+from dataeval.utils._array import as_numpy
 from dataeval.utils.torch.models import SupportsEncode


@@ -26,11 +28,15 @@ class Embeddings:
         Dataset to access original images from.
     batch_size : int
         Batch size to use when encoding images.
+    transforms : Transform or Sequence[Transform] or None, default None
+        Transforms to apply to images before encoding.
     model : torch.nn.Module or None, default None
         Model to use for encoding images.
     device : DeviceLike or None, default None
         The hardware device to use if specified, otherwise uses the DataEval
         default or torch default.
+    cache : bool, default False
+        Whether to cache the embeddings in memory.
     verbose : bool, default False
         Whether to print progress bar when encoding images.
     """
@@ -41,21 +47,29 @@ class Embeddings:

     def __init__(
         self,
-        dataset: Dataset[tuple[
+        dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike],
         batch_size: int,
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
         model: torch.nn.Module | None = None,
         device: DeviceLike | None = None,
+        cache: bool = False,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
-        self.
+        self.cache = cache
+        self.batch_size = batch_size if batch_size > 0 else 1
         self.verbose = verbose

         self._dataset = dataset
+        self._length = len(dataset)
         model = torch.nn.Flatten() if model is None else model
-        self.
+        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
+        self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
-        self._collate_fn = lambda datum: [torch.as_tensor(
+        self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
+        self._cached_idx = set()
+        self._embeddings: torch.Tensor = torch.empty(())
+        self._shallow: bool = False

     def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
         """
@@ -79,30 +93,133 @@ class Embeddings:
         else:
             return self[:]

-
-
+    def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
+        """
+        Converts dataset to embeddings as numpy array.
+
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
+
+        Returns
+        -------
+        NDArray[Any]
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        return self.to_tensor(indices).cpu().numpy()
+
+    def new(self, dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike]) -> Embeddings:
+        """
+        Creates a new Embeddings object with the same parameters but a different dataset.
+
+        Parameters
+        ----------
+        dataset : ImageClassificationDataset or ObjectDetectionDataset
+            Dataset to access original images from.
+
+        Returns
+        -------
+        Embeddings
+        """
+        return Embeddings(
+            dataset, self.batch_size, self._transforms, self._model, self.device, self.cache, self.verbose
+        )
+
+    @classmethod
+    def from_array(cls, array: ArrayLike, device: DeviceLike | None = None) -> Embeddings:
+        """
+        Instantiates a shallow Embeddings object using an array.
+
+        Parameters
+        ----------
+        array : ArrayLike
+            The array to convert to embeddings.
+        device : DeviceLike or None, default None
+            The hardware device to use if specified, otherwise uses the DataEval
+            default or torch default.
+
+        Returns
+        -------
+        Embeddings
+
+        Example
+        -------
+        >>> import numpy as np
+        >>> from dataeval.utils.data._embeddings import Embeddings
+        >>> array = np.random.randn(100, 3, 224, 224)
+        >>> embeddings = Embeddings.from_array(array)
+        >>> print(embeddings.to_tensor().shape)
+        torch.Size([100, 3, 224, 224])
+        """
+        embeddings = Embeddings([], 0, None, None, device, True, False)
+        array = array if isinstance(array, Array) else as_numpy(array)
+        embeddings._length = len(array)
+        embeddings._cached_idx = set(range(len(array)))
+        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+        embeddings._shallow = True
+        return embeddings
+
+    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+        if self._transforms:
+            images = [transform(image) for transform in self._transforms for image in images]
+        return self._encoder(torch.stack(images).to(self.device))
+
+    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
     def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
-
-
-
-
-
-
-
-
-
+        dataset = cast(torch.utils.data.Dataset, self._dataset)
+        total_batches = math.ceil(len(indices) / self.batch_size)
+
+        # If not caching, process all indices normally
+        if not self.cache:
+            for images in tqdm(
+                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+                total=total_batches,
+                desc="Batch embedding",
+                disable=not self.verbose,
+            ):
+                yield self._encode(images)
+            return
+
+        # If caching, process each batch of indices at a time, preserving original order
+        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+            batch = indices[i : i + self.batch_size]
+            uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+            if uncached:
+                # Process uncached indices as as single batch
+                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+                    embeddings = self._encode(images)
+
+                    if not self._embeddings.shape:
+                        full_shape = (len(self._dataset), *embeddings.shape[1:])
+                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+                    self._embeddings[uncached] = embeddings
+                    self._cached_idx.update(uncached)
+
+            yield self._embeddings[batch]

     def __getitem__(self, key: int | slice, /) -> torch.Tensor:
-        if isinstance(key, slice):
-
-
-
-
+        if not isinstance(key, slice) and not hasattr(key, "__int__"):
+            raise TypeError("Invalid argument type.")
+
+        if self._shallow:
+            if not self._embeddings.shape:
+                raise ValueError("Embeddings not initialized.")
+            return self._embeddings[key]
+
+        indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
+        result = torch.vstack(list(self._batch(indices))).to(self.device)
+        return result.squeeze(0) if len(indices) == 1 else result

     def __iter__(self) -> Iterator[torch.Tensor]:
         # process in batches while yielding individual embeddings
-        for batch in self._batch(range(
+        for batch in self._batch(range(self._length)):
            yield from batch

     def __len__(self) -> int:
-        return
+        return self._length
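Note: a short usage sketch based on the constructor and `from_array` classmethod shown above. With `cache=True` encoded batches are stored per index and reused; `from_array` skips encoding entirely and wraps an existing array as a shallow `Embeddings`. The import path mirrors the docstring example in the diff; the array contents are placeholders.

```python
import numpy as np

from dataeval.utils.data._embeddings import Embeddings

# Wrap precomputed embeddings without a dataset or model.
precomputed = np.random.randn(100, 16).astype(np.float32)
shallow = Embeddings.from_array(precomputed)

print(shallow.to_tensor().shape)  # torch.Size([100, 16])
print(shallow.to_numpy().shape)   # (100, 16)
print(shallow[0].shape)           # torch.Size([16])
```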
dataeval/utils/data/_images.py
CHANGED
@@ -2,11 +2,15 @@ from __future__ import annotations

 __all__ = []

-from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload

-from dataeval.typing import Dataset
+from dataeval.typing import Array, ArrayLike, Dataset
+from dataeval.utils._array import as_numpy, channels_first_to_last

-
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+
+T = TypeVar("T", Array, ArrayLike)


 class Images(Generic[T]):
@@ -21,7 +25,10 @@ class Images(Generic[T]):
         Dataset to access images from.
     """

-    def __init__(
+    def __init__(
+        self,
+        dataset: Dataset[tuple[T, Any, Any] | T],
+    ) -> None:
         self._is_tuple_datum = isinstance(dataset[0], tuple)
         self._dataset = dataset

@@ -40,25 +47,41 @@ class Images(Generic[T]):
         """
         return self[:]

+    def plot(
+        self,
+        indices: Sequence[int],
+        images_per_row: int = 3,
+        figsize: tuple[int, int] = (10, 10),
+    ) -> Figure:
+        import matplotlib.pyplot as plt
+
+        num_images = len(indices)
+        num_rows = (num_images + images_per_row - 1) // images_per_row
+        fig, axes = plt.subplots(num_rows, images_per_row, figsize=figsize)
+        for i, ax in enumerate(axes.flatten()):
+            image = channels_first_to_last(as_numpy(self[i]))
+            ax.imshow(image)
+            ax.axis("off")
+        plt.tight_layout()
+        return fig
+
     @overload
     def __getitem__(self, key: int, /) -> T: ...
     @overload
     def __getitem__(self, key: slice, /) -> Sequence[T]: ...

     def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
+        if isinstance(key, slice):
+            return [self._get_image(k) for k in range(len(self._dataset))[key]]
+        elif hasattr(key, "__int__"):
+            return self._get_image(int(key))
+        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+
+    def _get_image(self, index: int) -> T:
         if self._is_tuple_datum:
-
-            if isinstance(key, slice):
-                return [dataset[k][0] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key][0]
+            return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
         else:
-
-            if isinstance(key, slice):
-                return [dataset[k] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key]
-        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+            return cast(Dataset[T], self._dataset)[index]

     def __iter__(self) -> Iterator[T]:
         for i in range(len(self._dataset)):
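Note: a sketch of the new `Images.plot` helper. It assumes channels-first images, converts each one with `channels_first_to_last`, lays the images out on a matplotlib grid, and returns the `Figure`. The toy list-of-tuples dataset below is illustrative; any dataset yielding `(image, target, metadata)` tuples (or bare images) should work, and matplotlib must be installed.

```python
import numpy as np

from dataeval.utils.data._images import Images  # internal module

# Illustrative stand-in dataset: nine random CHW images with dummy targets/metadata.
dataset = [(np.random.rand(3, 16, 16), 0, {}) for _ in range(9)]

images = Images(dataset)
fig = images.plot(indices=list(range(9)), images_per_row=3, figsize=(6, 6))
fig.savefig("grid.png")
```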
dataeval/utils/data/_metadata.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []

 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast

 import numpy as np
 from numpy.typing import NDArray
@@ -208,8 +208,9 @@ class Metadata:
             raw.append(metadata)

             if is_od_target := isinstance(target, ObjectDetectionTarget):
-
-
+                target_labels = as_numpy(target.labels)
+                target_len = len(target_labels)
+                labels.extend(target_labels.tolist())
                 bboxes.extend(as_numpy(target.boxes).tolist())
                 scores.extend(as_numpy(target.scores).tolist())
                 srcidx.extend([i] * target_len)
@@ -360,7 +361,7 @@ class Metadata:
         self._merge()
         self._processed = False
         target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
-        if any(len(v) != target_len for v in factors.values()):
+        if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
             raise ValueError(
                 "The lists/arrays in the provided factors have a different length than the current metadata factors."
             )
dataeval/utils/data/_selection.py
CHANGED
@@ -5,7 +5,7 @@ __all__ = []
 from enum import IntEnum
 from typing import Generic, Iterator, Sequence, TypeVar

-from dataeval.typing import AnnotatedDataset, DatasetMetadata
+from dataeval.typing import AnnotatedDataset, DatasetMetadata


 _TDatum = TypeVar("_TDatum")
@@ -35,8 +35,6 @@ class Select(AnnotatedDataset[_TDatum]):
         The dataset to wrap.
     selections : Selection or list[Selection], optional
         The selection criteria to apply to the dataset.
-    transforms : Transform or list[Transform], optional
-        The transforms to apply to the dataset.

     Examples
     --------
@@ -70,16 +68,12 @@ class Select(AnnotatedDataset[_TDatum]):
         self,
         dataset: AnnotatedDataset[_TDatum],
         selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
-        transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
     ) -> None:
         self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
         self._selections = self._sort(selections)
-        self._transforms = (
-            [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
-        )

         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -98,8 +92,7 @@ class Select(AnnotatedDataset[_TDatum]):
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
         selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
-
-        return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
+        return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"

     def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
@@ -117,13 +110,8 @@ class Select(AnnotatedDataset[_TDatum]):
             selection(self)
         self._selection = self._selection[: self._size_limit]

-    def _transform(self, datum: _TDatum) -> _TDatum:
-        for t in self._transforms:
-            datum = t(datum)
-        return datum
-
     def __getitem__(self, index: int) -> _TDatum:
-        return self.
+        return self._dataset[self._selection[index]]

     def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):