dataeval 0.84.0__py3-none-any.whl → 0.84.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +55 -203
- dataeval/detectors/drift/_cvm.py +19 -30
- dataeval/detectors/drift/_ks.py +18 -30
- dataeval/detectors/drift/_mmd.py +189 -53
- dataeval/detectors/drift/_uncertainty.py +52 -56
- dataeval/detectors/drift/updates.py +13 -12
- dataeval/detectors/linters/duplicates.py +5 -3
- dataeval/detectors/linters/outliers.py +2 -2
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/metrics/stats/_base.py +7 -7
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/typing.py +22 -19
- dataeval/utils/_array.py +18 -7
- dataeval/utils/data/_dataset.py +6 -4
- dataeval/utils/data/_embeddings.py +46 -7
- dataeval/utils/data/_images.py +2 -2
- dataeval/utils/data/_metadata.py +5 -4
- dataeval/utils/data/datasets/_base.py +7 -4
- dataeval/utils/data/datasets/_cifar10.py +9 -9
- dataeval/utils/data/datasets/_milco.py +42 -14
- dataeval/utils/data/datasets/_mnist.py +9 -5
- dataeval/utils/data/datasets/_ships.py +8 -4
- dataeval/utils/data/datasets/_voc.py +40 -19
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classbalance.py +38 -0
- dataeval/utils/data/selections/_classfilter.py +14 -29
- dataeval/utils/data/selections/_prioritize.py +1 -1
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/torch/_internal.py +12 -35
- {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
- {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +39 -39
- dataeval/detectors/drift/_torch.py +0 -222
- {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/_base.py
CHANGED
@@ -10,7 +10,7 @@ from copy import deepcopy
|
|
10
10
|
from dataclasses import dataclass
|
11
11
|
from functools import partial
|
12
12
|
from multiprocessing import Pool
|
13
|
-
from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
|
13
|
+
from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
|
14
14
|
|
15
15
|
import numpy as np
|
16
16
|
import tqdm
|
@@ -19,7 +19,7 @@ from numpy.typing import NDArray
|
|
19
19
|
from dataeval.config import get_max_processes
|
20
20
|
from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
|
21
21
|
from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
|
22
|
-
from dataeval.utils._array import to_numpy
|
22
|
+
from dataeval.utils._array import as_numpy, to_numpy
|
23
23
|
from dataeval.utils._image import normalize_image_shape, rescale
|
24
24
|
|
25
25
|
DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
|
@@ -138,19 +138,19 @@ def process_stats(
|
|
138
138
|
|
139
139
|
|
140
140
|
def process_stats_unpack(
|
141
|
-
args: tuple[int,
|
141
|
+
args: tuple[int, ArrayLike, list[BoundingBox] | None],
|
142
142
|
per_channel: bool,
|
143
143
|
stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
|
144
144
|
) -> StatsProcessorOutput:
|
145
145
|
return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
|
146
146
|
|
147
147
|
|
148
|
-
def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
|
148
|
+
def _enumerate(dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool):
|
149
149
|
for i in range(len(dataset)):
|
150
150
|
d = dataset[i]
|
151
151
|
image = d[0] if isinstance(d, tuple) else d
|
152
152
|
if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
|
153
|
-
boxes =
|
153
|
+
boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
|
154
154
|
target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
|
155
155
|
else:
|
156
156
|
target = None
|
@@ -159,7 +159,7 @@ def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_bo
|
|
159
159
|
|
160
160
|
|
161
161
|
def run_stats(
|
162
|
-
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
162
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
163
163
|
per_box: bool,
|
164
164
|
per_channel: bool,
|
165
165
|
stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
|
@@ -173,7 +173,7 @@ def run_stats(
|
|
173
173
|
|
174
174
|
Parameters
|
175
175
|
----------
|
176
|
-
data : Dataset[
|
176
|
+
data : Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
|
177
177
|
A dataset of images and targets to compute statistics on.
|
178
178
|
per_box : bool
|
179
179
|
A flag which determines if the statistics should be evaluated on a per-box basis or not.
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
9
9
|
from dataeval.metrics.stats._base import StatsProcessor, run_stats
|
10
10
|
from dataeval.outputs import DimensionStatsOutput
|
11
11
|
from dataeval.outputs._base import set_metadata
|
12
|
-
from dataeval.typing import
|
12
|
+
from dataeval.typing import ArrayLike, Dataset
|
13
13
|
from dataeval.utils._image import get_bitdepth
|
14
14
|
|
15
15
|
|
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
|
|
34
34
|
|
35
35
|
@set_metadata
|
36
36
|
def dimensionstats(
|
37
|
-
dataset: Dataset[
|
37
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
38
38
|
*,
|
39
39
|
per_box: bool = False,
|
40
40
|
) -> DimensionStatsOutput:
|
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
|
|
14
14
|
from dataeval.metrics.stats._base import StatsProcessor, run_stats
|
15
15
|
from dataeval.outputs import HashStatsOutput
|
16
16
|
from dataeval.outputs._base import set_metadata
|
17
|
-
from dataeval.typing import
|
17
|
+
from dataeval.typing import ArrayLike, Dataset
|
18
18
|
from dataeval.utils._array import as_numpy
|
19
19
|
from dataeval.utils._image import normalize_image_shape, rescale
|
20
20
|
|
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
|
|
105
105
|
|
106
106
|
@set_metadata
|
107
107
|
def hashstats(
|
108
|
-
dataset: Dataset[
|
108
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
109
109
|
*,
|
110
110
|
per_box: bool = False,
|
111
111
|
) -> HashStatsOutput:
|
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
|
|
10
10
|
from dataeval.metrics.stats._visualstats import VisualStatsProcessor
|
11
11
|
from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
|
12
12
|
from dataeval.outputs._base import set_metadata
|
13
|
-
from dataeval.typing import
|
13
|
+
from dataeval.typing import ArrayLike, Dataset
|
14
14
|
|
15
15
|
|
16
16
|
@overload
|
17
17
|
def imagestats(
|
18
|
-
dataset: Dataset[
|
18
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
19
19
|
*,
|
20
20
|
per_box: bool = False,
|
21
21
|
per_channel: Literal[True],
|
@@ -24,7 +24,7 @@ def imagestats(
|
|
24
24
|
|
25
25
|
@overload
|
26
26
|
def imagestats(
|
27
|
-
dataset: Dataset[
|
27
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
28
28
|
*,
|
29
29
|
per_box: bool = False,
|
30
30
|
per_channel: Literal[False] = False,
|
@@ -33,7 +33,7 @@ def imagestats(
|
|
33
33
|
|
34
34
|
@set_metadata
|
35
35
|
def imagestats(
|
36
|
-
dataset: Dataset[
|
36
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
37
37
|
*,
|
38
38
|
per_box: bool = False,
|
39
39
|
per_channel: bool = False,
|
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
|
|
10
10
|
from dataeval.metrics.stats._base import StatsProcessor, run_stats
|
11
11
|
from dataeval.outputs import PixelStatsOutput
|
12
12
|
from dataeval.outputs._base import set_metadata
|
13
|
-
from dataeval.typing import
|
13
|
+
from dataeval.typing import ArrayLike, Dataset
|
14
14
|
|
15
15
|
|
16
16
|
class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
|
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
|
|
37
37
|
|
38
38
|
@set_metadata
|
39
39
|
def pixelstats(
|
40
|
-
dataset: Dataset[
|
40
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
41
41
|
*,
|
42
42
|
per_box: bool = False,
|
43
43
|
per_channel: bool = False,
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
9
9
|
from dataeval.metrics.stats._base import StatsProcessor, run_stats
|
10
10
|
from dataeval.outputs import VisualStatsOutput
|
11
11
|
from dataeval.outputs._base import set_metadata
|
12
|
-
from dataeval.typing import
|
12
|
+
from dataeval.typing import ArrayLike, Dataset
|
13
13
|
from dataeval.utils._image import edge_filter
|
14
14
|
|
15
15
|
QUARTILES = (0, 25, 50, 75, 100)
|
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
|
|
44
44
|
|
45
45
|
@set_metadata
|
46
46
|
def visualstats(
|
47
|
-
dataset: Dataset[
|
47
|
+
dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
|
48
48
|
*,
|
49
49
|
per_box: bool = False,
|
50
50
|
per_channel: bool = False,
|
dataeval/typing.py
CHANGED
@@ -21,8 +21,9 @@ __all__ = [
|
|
21
21
|
|
22
22
|
|
23
23
|
import sys
|
24
|
-
from typing import Any, Generic, Iterator, Protocol,
|
24
|
+
from typing import Any, Generic, Iterator, Protocol, TypedDict, TypeVar, runtime_checkable
|
25
25
|
|
26
|
+
import numpy.typing
|
26
27
|
from typing_extensions import NotRequired, ReadOnly, Required
|
27
28
|
|
28
29
|
if sys.version_info >= (3, 10):
|
@@ -31,6 +32,16 @@ else:
|
|
31
32
|
from typing_extensions import TypeAlias
|
32
33
|
|
33
34
|
|
35
|
+
ArrayLike: TypeAlias = numpy.typing.ArrayLike
|
36
|
+
"""
|
37
|
+
Type alias for a `Union` representing objects that can be coerced into an array.
|
38
|
+
|
39
|
+
See Also
|
40
|
+
--------
|
41
|
+
`NumPy ArrayLike <https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike>`_
|
42
|
+
"""
|
43
|
+
|
44
|
+
|
34
45
|
@runtime_checkable
|
35
46
|
class Array(Protocol):
|
36
47
|
"""
|
@@ -67,16 +78,8 @@ class Array(Protocol):
|
|
67
78
|
def __len__(self) -> int: ...
|
68
79
|
|
69
80
|
|
70
|
-
|
81
|
+
_T = TypeVar("_T")
|
71
82
|
_T_co = TypeVar("_T_co", covariant=True)
|
72
|
-
_ScalarType = Union[int, float, bool, str]
|
73
|
-
ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
|
74
|
-
"""
|
75
|
-
Type alias for array-like objects used for interoperability with DataEval.
|
76
|
-
|
77
|
-
This includes native Python sequences, as well as objects that conform to
|
78
|
-
the :class:`Array` protocol.
|
79
|
-
"""
|
80
83
|
|
81
84
|
|
82
85
|
class DatasetMetadata(TypedDict, total=False):
|
@@ -140,12 +143,12 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
|
|
140
143
|
# ========== IMAGE CLASSIFICATION DATASETS ==========
|
141
144
|
|
142
145
|
|
143
|
-
ImageClassificationDatum: TypeAlias = tuple[
|
146
|
+
ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, dict[str, Any]]
|
144
147
|
"""
|
145
148
|
Type alias for an image classification datum tuple.
|
146
149
|
|
147
|
-
- :class:`
|
148
|
-
- :class:`
|
150
|
+
- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
|
151
|
+
- :class:`ArrayLike` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
|
149
152
|
- dict[str, Any] - Datum level metadata.
|
150
153
|
"""
|
151
154
|
|
@@ -180,11 +183,11 @@ class ObjectDetectionTarget(Protocol):
|
|
180
183
|
def scores(self) -> ArrayLike: ...
|
181
184
|
|
182
185
|
|
183
|
-
ObjectDetectionDatum: TypeAlias = tuple[
|
186
|
+
ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, dict[str, Any]]
|
184
187
|
"""
|
185
188
|
Type alias for an object detection datum tuple.
|
186
189
|
|
187
|
-
- :class:`
|
190
|
+
- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
|
188
191
|
- :class:`ObjectDetectionTarget` - Object detection target information for the image.
|
189
192
|
- dict[str, Any] - Datum level metadata.
|
190
193
|
"""
|
@@ -221,11 +224,11 @@ class SegmentationTarget(Protocol):
|
|
221
224
|
def scores(self) -> ArrayLike: ...
|
222
225
|
|
223
226
|
|
224
|
-
SegmentationDatum: TypeAlias = tuple[
|
227
|
+
SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, dict[str, Any]]
|
225
228
|
"""
|
226
229
|
Type alias for an image classification datum tuple.
|
227
230
|
|
228
|
-
- :class:`
|
231
|
+
- :class:`ArrayLike` of shape (C, H, W) - Image data in channel, height, width format.
|
229
232
|
- :class:`SegmentationTarget` - Segmentation target information for the image.
|
230
233
|
- dict[str, Any] - Datum level metadata.
|
231
234
|
"""
|
@@ -237,7 +240,7 @@ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elemen
|
|
237
240
|
|
238
241
|
|
239
242
|
@runtime_checkable
|
240
|
-
class Transform(Generic[T], Protocol):
|
243
|
+
class Transform(Generic[_T], Protocol):
|
241
244
|
"""
|
242
245
|
Protocol defining a transform function.
|
243
246
|
|
@@ -262,4 +265,4 @@ class Transform(Generic[T], Protocol):
|
|
262
265
|
array([0.004, 0.008, 0.012])
|
263
266
|
"""
|
264
267
|
|
265
|
-
def __call__(self, data: T, /) -> T: ...
|
268
|
+
def __call__(self, data: _T, /) -> _T: ...
|
dataeval/utils/_array.py
CHANGED
@@ -92,7 +92,7 @@ def ensure_embeddings(
|
|
92
92
|
@overload
|
93
93
|
def ensure_embeddings(
|
94
94
|
embeddings: T,
|
95
|
-
dtype: None,
|
95
|
+
dtype: None = None,
|
96
96
|
unit_interval: Literal[True, False, "force"] = False,
|
97
97
|
) -> T: ...
|
98
98
|
|
@@ -152,21 +152,32 @@ def ensure_embeddings(
|
|
152
152
|
return arr
|
153
153
|
|
154
154
|
|
155
|
-
|
155
|
+
@overload
|
156
|
+
def flatten(array: torch.Tensor) -> torch.Tensor: ...
|
157
|
+
@overload
|
158
|
+
def flatten(array: ArrayLike) -> NDArray[Any]: ...
|
159
|
+
|
160
|
+
|
161
|
+
def flatten(array: ArrayLike) -> NDArray[Any] | torch.Tensor:
|
156
162
|
"""
|
157
163
|
Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
|
158
164
|
|
159
165
|
Parameters
|
160
166
|
----------
|
161
|
-
|
167
|
+
array : ArrayLike
|
162
168
|
Input array
|
163
169
|
|
164
170
|
Returns
|
165
171
|
-------
|
166
|
-
|
172
|
+
np.ndarray or torch.Tensor, shape: (N, -1)
|
167
173
|
"""
|
168
|
-
|
169
|
-
|
174
|
+
if isinstance(array, np.ndarray):
|
175
|
+
nparr = as_numpy(array)
|
176
|
+
return nparr.reshape((nparr.shape[0], -1))
|
177
|
+
elif isinstance(array, torch.Tensor):
|
178
|
+
return torch.flatten(array, start_dim=1)
|
179
|
+
else:
|
180
|
+
raise TypeError(f"Unsupported array type {type(array)}.")
|
170
181
|
|
171
182
|
|
172
183
|
_TArray = TypeVar("_TArray", bound=Array)
|
@@ -191,4 +202,4 @@ def channels_first_to_last(array: _TArray) -> _TArray:
|
|
191
202
|
elif isinstance(array, torch.Tensor):
|
192
203
|
return torch.permute(array, (1, 2, 0))
|
193
204
|
else:
|
194
|
-
raise TypeError(f"Unsupported array type {type(array)}
|
205
|
+
raise TypeError(f"Unsupported array type {type(array)}.")
|
dataeval/utils/data/_dataset.py
CHANGED
@@ -52,10 +52,12 @@ def _validate_data(
|
|
52
52
|
|
53
53
|
|
54
54
|
def _find_max(arr: ArrayLike) -> Any:
|
55
|
-
if isinstance(arr
|
56
|
-
|
57
|
-
|
58
|
-
|
55
|
+
if isinstance(arr, (Iterable, Sequence, Array)):
|
56
|
+
if isinstance(arr[0], (Iterable, Sequence, Array)):
|
57
|
+
return max([_find_max(x) for x in arr]) # type: ignore
|
58
|
+
else:
|
59
|
+
return max(arr)
|
60
|
+
return arr
|
59
61
|
|
60
62
|
|
61
63
|
_TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])
|
@@ -6,11 +6,13 @@ import math
|
|
6
6
|
from typing import Any, Iterator, Sequence, cast
|
7
7
|
|
8
8
|
import torch
|
9
|
+
from numpy.typing import NDArray
|
9
10
|
from torch.utils.data import DataLoader, Subset
|
10
11
|
from tqdm import tqdm
|
11
12
|
|
12
13
|
from dataeval.config import DeviceLike, get_device
|
13
|
-
from dataeval.typing import Array, Dataset, Transform
|
14
|
+
from dataeval.typing import Array, ArrayLike, Dataset, Transform
|
15
|
+
from dataeval.utils._array import as_numpy
|
14
16
|
from dataeval.utils.torch.models import SupportsEncode
|
15
17
|
|
16
18
|
|
@@ -45,7 +47,7 @@ class Embeddings:
|
|
45
47
|
|
46
48
|
def __init__(
|
47
49
|
self,
|
48
|
-
dataset: Dataset[tuple[
|
50
|
+
dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike],
|
49
51
|
batch_size: int,
|
50
52
|
transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
|
51
53
|
model: torch.nn.Module | None = None,
|
@@ -62,9 +64,9 @@ class Embeddings:
|
|
62
64
|
self._length = len(dataset)
|
63
65
|
model = torch.nn.Flatten() if model is None else model
|
64
66
|
self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
|
65
|
-
self._model = model.to(self.device).eval()
|
67
|
+
self._model = model.to(self.device).eval() if isinstance(model, torch.nn.Module) else model
|
66
68
|
self._encoder = model.encode if isinstance(model, SupportsEncode) else model
|
67
|
-
self._collate_fn = lambda datum: [torch.as_tensor(
|
69
|
+
self._collate_fn = lambda datum: [torch.as_tensor(d[0] if isinstance(d, tuple) else d) for d in datum]
|
68
70
|
self._cached_idx = set()
|
69
71
|
self._embeddings: torch.Tensor = torch.empty(())
|
70
72
|
self._shallow: bool = False
|
@@ -91,14 +93,50 @@ class Embeddings:
|
|
91
93
|
else:
|
92
94
|
return self[:]
|
93
95
|
|
96
|
+
def to_numpy(self, indices: Sequence[int] | None = None) -> NDArray[Any]:
|
97
|
+
"""
|
98
|
+
Converts dataset to embeddings as numpy array.
|
99
|
+
|
100
|
+
Parameters
|
101
|
+
----------
|
102
|
+
indices : Sequence[int] or None, default None
|
103
|
+
The indices to convert to embeddings
|
104
|
+
|
105
|
+
Returns
|
106
|
+
-------
|
107
|
+
NDArray[Any]
|
108
|
+
|
109
|
+
Warning
|
110
|
+
-------
|
111
|
+
Processing large quantities of data can be resource intensive.
|
112
|
+
"""
|
113
|
+
return self.to_tensor(indices).cpu().numpy()
|
114
|
+
|
115
|
+
def new(self, dataset: Dataset[tuple[ArrayLike, Any, Any]] | Dataset[ArrayLike]) -> Embeddings:
|
116
|
+
"""
|
117
|
+
Creates a new Embeddings object with the same parameters but a different dataset.
|
118
|
+
|
119
|
+
Parameters
|
120
|
+
----------
|
121
|
+
dataset : ImageClassificationDataset or ObjectDetectionDataset
|
122
|
+
Dataset to access original images from.
|
123
|
+
|
124
|
+
Returns
|
125
|
+
-------
|
126
|
+
Embeddings
|
127
|
+
"""
|
128
|
+
return Embeddings(
|
129
|
+
dataset, self.batch_size, self._transforms, self._model, self.device, self.cache, self.verbose
|
130
|
+
)
|
131
|
+
|
94
132
|
@classmethod
|
95
|
-
def from_array(cls, array:
|
133
|
+
def from_array(cls, array: ArrayLike, device: DeviceLike | None = None) -> Embeddings:
|
96
134
|
"""
|
97
135
|
Instantiates a shallow Embeddings object using an array.
|
98
136
|
|
99
137
|
Parameters
|
100
138
|
----------
|
101
|
-
array :
|
139
|
+
array : ArrayLike
|
102
140
|
The array to convert to embeddings.
|
103
141
|
device : DeviceLike or None, default None
|
104
142
|
The hardware device to use if specified, otherwise uses the DataEval
|
@@ -118,6 +156,7 @@ class Embeddings:
|
|
118
156
|
torch.Size([100, 3, 224, 224])
|
119
157
|
"""
|
120
158
|
embeddings = Embeddings([], 0, None, None, device, True, False)
|
159
|
+
array = array if isinstance(array, Array) else as_numpy(array)
|
121
160
|
embeddings._length = len(array)
|
122
161
|
embeddings._cached_idx = set(range(len(array)))
|
123
162
|
embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
|
@@ -131,7 +170,7 @@ class Embeddings:
|
|
131
170
|
|
132
171
|
@torch.no_grad() # Reduce overhead cost by not tracking tensor gradients
|
133
172
|
def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
|
134
|
-
dataset = cast(torch.utils.data.Dataset
|
173
|
+
dataset = cast(torch.utils.data.Dataset, self._dataset)
|
135
174
|
total_batches = math.ceil(len(indices) / self.batch_size)
|
136
175
|
|
137
176
|
# If not caching, process all indices normally
|
dataeval/utils/data/_images.py
CHANGED
@@ -4,13 +4,13 @@ __all__ = []
|
|
4
4
|
|
5
5
|
from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
|
6
6
|
|
7
|
-
from dataeval.typing import Array, Dataset
|
7
|
+
from dataeval.typing import Array, ArrayLike, Dataset
|
8
8
|
from dataeval.utils._array import as_numpy, channels_first_to_last
|
9
9
|
|
10
10
|
if TYPE_CHECKING:
|
11
11
|
from matplotlib.figure import Figure
|
12
12
|
|
13
|
-
T = TypeVar("T",
|
13
|
+
T = TypeVar("T", Array, ArrayLike)
|
14
14
|
|
15
15
|
|
16
16
|
class Images(Generic[T]):
|
dataeval/utils/data/_metadata.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import warnings
|
6
|
-
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
|
6
|
+
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, Sized, cast
|
7
7
|
|
8
8
|
import numpy as np
|
9
9
|
from numpy.typing import NDArray
|
@@ -208,8 +208,9 @@ class Metadata:
|
|
208
208
|
raw.append(metadata)
|
209
209
|
|
210
210
|
if is_od_target := isinstance(target, ObjectDetectionTarget):
|
211
|
-
|
212
|
-
|
211
|
+
target_labels = as_numpy(target.labels)
|
212
|
+
target_len = len(target_labels)
|
213
|
+
labels.extend(target_labels.tolist())
|
213
214
|
bboxes.extend(as_numpy(target.boxes).tolist())
|
214
215
|
scores.extend(as_numpy(target.scores).tolist())
|
215
216
|
srcidx.extend([i] * target_len)
|
@@ -360,7 +361,7 @@ class Metadata:
|
|
360
361
|
self._merge()
|
361
362
|
self._processed = False
|
362
363
|
target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
|
363
|
-
if any(len(v) != target_len for v in factors.values()):
|
364
|
+
if any(len(v if isinstance(v, Sized) else as_numpy(v)) != target_len for v in factors.values()):
|
364
365
|
raise ValueError(
|
365
366
|
"The lists/arrays in the provided factors have a different length than the current metadata factors."
|
366
367
|
)
|
@@ -19,9 +19,12 @@ from dataeval.utils.data.datasets._types import (
|
|
19
19
|
)
|
20
20
|
|
21
21
|
if TYPE_CHECKING:
|
22
|
-
from dataeval.typing import Transform
|
22
|
+
from dataeval.typing import Array, Transform
|
23
|
+
|
24
|
+
_TArray = TypeVar("_TArray", bound=Array)
|
25
|
+
else:
|
26
|
+
_TArray = TypeVar("_TArray")
|
23
27
|
|
24
|
-
_TArray = TypeVar("_TArray")
|
25
28
|
_TTarget = TypeVar("_TTarget")
|
26
29
|
_TRawTarget = TypeVar("_TRawTarget", list[int], list[str])
|
27
30
|
|
@@ -51,9 +54,9 @@ class BaseDataset(AnnotatedDataset[tuple[_TArray, _TTarget, dict[str, Any]]], Ge
|
|
51
54
|
def __init__(
|
52
55
|
self,
|
53
56
|
root: str | Path,
|
54
|
-
|
55
|
-
image_set: Literal["train", "val", "test", "base"] = "train",
|
57
|
+
image_set: Literal["train", "val", "test", "operational", "base"] = "train",
|
56
58
|
transforms: Transform[_TArray] | Sequence[Transform[_TArray]] | None = None,
|
59
|
+
download: bool = False,
|
57
60
|
verbose: bool = False,
|
58
61
|
) -> None:
|
59
62
|
self._root: Path = root.absolute() if isinstance(root, Path) else Path(root).absolute()
|
@@ -27,13 +27,13 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
27
27
|
----------
|
28
28
|
root : str or pathlib.Path
|
29
29
|
Root directory of dataset where the ``mnist`` folder exists.
|
30
|
-
download : bool, default False
|
31
|
-
If True, downloads the dataset from the internet and puts it in root directory.
|
32
|
-
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
33
30
|
image_set : "train", "test" or "base", default "train"
|
34
31
|
If "base", returns all of the data to allow the user to create their own splits.
|
35
32
|
transforms : Transform, Sequence[Transform] or None, default None
|
36
33
|
Transform(s) to apply to the data.
|
34
|
+
download : bool, default False
|
35
|
+
If True, downloads the dataset from the internet and puts it in root directory.
|
36
|
+
Class checks to see if data is already downloaded to ensure it does not create a duplicate download.
|
37
37
|
verbose : bool, default False
|
38
38
|
If True, outputs print statements.
|
39
39
|
|
@@ -43,16 +43,16 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
43
43
|
Location of the folder containing the data.
|
44
44
|
image_set : "train", "test" or "base"
|
45
45
|
The selected image set from the dataset.
|
46
|
+
transforms : Sequence[Transform]
|
47
|
+
The transforms to be applied to the data.
|
48
|
+
size : int
|
49
|
+
The size of the dataset.
|
46
50
|
index2label : dict[int, str]
|
47
51
|
Dictionary which translates from class integers to the associated class strings.
|
48
52
|
label2index : dict[str, int]
|
49
53
|
Dictionary which translates from class strings to the associated class integers.
|
50
54
|
metadata : DatasetMetadata
|
51
55
|
Typed dictionary containing dataset metadata, such as `id` which returns the dataset class name.
|
52
|
-
transforms : Sequence[Transform]
|
53
|
-
The transforms to be applied to the data.
|
54
|
-
size : int
|
55
|
-
The size of the dataset.
|
56
56
|
"""
|
57
57
|
|
58
58
|
_resources = [
|
@@ -80,16 +80,16 @@ class CIFAR10(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
|
|
80
80
|
def __init__(
|
81
81
|
self,
|
82
82
|
root: str | Path,
|
83
|
-
download: bool = False,
|
84
83
|
image_set: Literal["train", "test", "base"] = "train",
|
85
84
|
transforms: Transform[NDArray[Any]] | Sequence[Transform[NDArray[Any]]] | None = None,
|
85
|
+
download: bool = False,
|
86
86
|
verbose: bool = False,
|
87
87
|
) -> None:
|
88
88
|
super().__init__(
|
89
89
|
root,
|
90
|
-
download,
|
91
90
|
image_set,
|
92
91
|
transforms,
|
92
|
+
download,
|
93
93
|
verbose,
|
94
94
|
)
|
95
95
|
|