dataeval 0.82.0__py3-none-any.whl → 0.83.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. The information is provided for informational purposes only and reflects the packages as they appear in their public registry.
- dataeval/__init__.py +7 -2
- dataeval/config.py +78 -11
- dataeval/detectors/drift/_mmd.py +9 -9
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +4 -4
- dataeval/detectors/linters/duplicates.py +3 -3
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +5 -4
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/mixin.py +1 -1
- dataeval/detectors/ood/vae.py +2 -1
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_distance.py +11 -44
- dataeval/metadata/_ood.py +152 -33
- dataeval/metrics/bias/_balance.py +9 -5
- dataeval/metrics/bias/_diversity.py +3 -0
- dataeval/metrics/bias/_parity.py +2 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +20 -21
- dataeval/metrics/stats/_boxratiostats.py +1 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +8 -8
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +5 -0
- dataeval/outputs/_base.py +50 -21
- dataeval/outputs/_bias.py +1 -1
- dataeval/outputs/_linters.py +4 -2
- dataeval/outputs/_metadata.py +61 -0
- dataeval/outputs/_stats.py +12 -6
- dataeval/typing.py +40 -9
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_embeddings.py +23 -19
- dataeval/utils/data/_metadata.py +16 -7
- dataeval/utils/data/_selection.py +22 -15
- dataeval/utils/data/_split.py +3 -2
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +5 -3
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/torch/_gmm.py +3 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
- dataeval-0.83.0.dist-info/RECORD +105 -0
- dataeval/detectors/ood/metadata_ood_mi.py +0 -93
- dataeval-0.82.0.dist-info/RECORD +0 -104
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/outputs/_base.py
CHANGED
@@ -4,11 +4,11 @@ __all__ = []
 
 import inspect
 import logging
-from collections.abc import Mapping
+from collections.abc import Collection, Mapping, Sequence
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial, wraps
-from typing import Any, Callable, Iterator, TypeVar
+from typing import Any, Callable, Generic, Iterator, TypeVar, overload
 
 import numpy as np
 from typing_extensions import ParamSpec
@@ -56,16 +56,13 @@ class ExecutionMetadata:
         )
 
 
-
-    _meta: ExecutionMetadata | None = None
+T = TypeVar("T", covariant=True)
 
-    def __str__(self) -> str:
-        return f"{self.__class__.__name__}: {str(self.dict())}"
 
-
-
+class GenericOutput(Generic[T]):
+    _meta: ExecutionMetadata | None = None
 
-
+    def data(self) -> T: ...
     def meta(self) -> ExecutionMetadata:
         """
         Metadata about the execution of the function or method for the Output class.
@@ -73,34 +70,66 @@ class Output:
         return self._meta or ExecutionMetadata.empty()
 
 
-
-
+class Output(GenericOutput[dict[str, Any]]):
+    def data(self) -> dict[str, Any]:
+        return {k: v for k, v in self.__dict__.items() if k != "_meta"}
 
+    def __repr__(self) -> str:
+        return str(self)
 
-
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.data().items()])})"
+
+
+class BaseCollectionMixin(Collection[Any]):
     __slots__ = ["_data"]
 
+    def data(self) -> Any:
+        return self._data
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({repr(self._data)})"
+
+    def __str__(self) -> str:
+        return str(self._data)
+
+
+TKey = TypeVar("TKey", str, int, float, set)
+TValue = TypeVar("TValue")
+
+
+class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Mapping[TKey, TValue]]):
     def __init__(self, data: Mapping[TKey, TValue]):
         self._data = data
 
     def __getitem__(self, key: TKey) -> TValue:
-        return self._data
+        return self._data[key]
 
     def __iter__(self) -> Iterator[TKey]:
-        return self._data
+        return iter(self._data)
 
-    def __len__(self) -> int:
-        return self._data.__len__()
 
-
-
+class SequenceOutput(Sequence[TValue], BaseCollectionMixin, GenericOutput[Sequence[TValue]]):
+    def __init__(self, data: Sequence[TValue]):
+        self._data = data
+
+    @overload
+    def __getitem__(self, index: int) -> TValue: ...
+    @overload
+    def __getitem__(self, index: slice) -> Sequence[TValue]: ...
 
-    def …
-        return …
+    def __getitem__(self, index: int | slice) -> TValue | Sequence[TValue]:
+        return self._data[index]
+
+    def __iter__(self) -> Iterator[TValue]:
+        return iter(self._data)
 
 
 P = ParamSpec("P")
-R = TypeVar("R", bound=…
+R = TypeVar("R", bound=GenericOutput)
 
 
 def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
dataeval/outputs/_bias.py
CHANGED
@@ -364,7 +364,7 @@ class DiversityOutput(Output):
                 col_labels,
                 xlabel="Factors",
                 ylabel="Class",
-                cbarlabel=f"Normalized {asdict(self.meta)['arguments']['method'].title()} Index",
+                cbarlabel=f"Normalized {asdict(self.meta())['arguments']['method'].title()} Index",
             )
 
         else:
dataeval/outputs/_linters.py
CHANGED
@@ -24,7 +24,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
 @dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], Output):
+class DuplicatesOutput(Output, Generic[TIndexCollection]):
     """
     Output class for :class:`.Duplicates` lint detector.
 
@@ -35,6 +35,8 @@ class DuplicatesOutput(Generic[TIndexCollection], Output):
     near: list[list[int] | dict[int, list[int]]]
         Indices of images that are near matches
 
+    Notes
+    -----
     - For a single dataset, indices are returned as a list of index groups.
     - For multiple datasets, indices are returned as dictionaries where the key is the
       index of the dataset, and the value is the list index groups from that dataset.
@@ -99,7 +101,7 @@ def _create_pandas_dataframe(class_wise):
 
 
 @dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap]…
+class OutliersOutput(Output, Generic[TIndexIssueMap]):
     """
     Output class for :class:`.Outliers` lint detector.
 
dataeval/outputs/_metadata.py
ADDED
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import NamedTuple
+
+from dataeval.outputs._base import MappingOutput, SequenceOutput
+
+
+class MostDeviatedFactorsOutput(SequenceOutput[tuple[str, float]]):
+    """
+    Output class for results of :func:`.most_deviated_factors` for OOD samples with metadata.
+
+    Attributes
+    ----------
+    value : tuple[str, float]
+        A tuple of the factor name and deviation of the highest metadata deviation
+    """
+
+
+class MetadataDistanceValues(NamedTuple):
+    """
+    Statistics comparing metadata distance.
+
+    Attributes
+    ----------
+    statistic : float
+        the KS statistic
+    location : float
+        The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
+        to the median of the reference distribution.
+    dist : float
+        The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
+    pvalue : float
+        The p-value from the KS two-sample test
+    """
+
+    statistic: float
+    location: float
+    dist: float
+    pvalue: float
+
+
+class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
+    """
+    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+
+    Attributes
+    ----------
+    key : str
+        Metadata feature names
+    value : :class:`.MetadataDistanceValues`
+        Output per feature name containing the statistic, statistic location, distance, and pvalue.
+    """
+
+
+class OODPredictorOutput(MappingOutput[str, float]):
+    """
+    Output class for results of :func:`find_ood_predictors` for the
+    mutual information between factors and being out of distribution
+    """
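
Since `MetadataDistanceOutput` is a `MappingOutput` keyed by feature name, its results read like a plain mapping of named tuples. A hedged usage sketch follows; the feature names and numbers are invented for illustration, and the import uses the private module path shown in this diff.

```python
from dataeval.outputs._metadata import MetadataDistanceValues

# Invented per-feature results; real values come from the metadata distance calculation.
results = {
    "altitude": MetadataDistanceValues(statistic=0.34, location=1.2, dist=0.80, pvalue=0.001),
    "blur": MetadataDistanceValues(statistic=0.05, location=0.1, dist=0.02, pvalue=0.72),
}

# Flag features whose KS test suggests a meaningful shift from the reference metadata.
shifted = [name for name, value in results.items() if value.pvalue < 0.05]
print(shifted)  # ['altitude']
```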
dataeval/outputs/_stats.py
CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union
+from typing import Any, Iterable, Optional, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -63,7 +63,7 @@ class BaseStatsOutput(Output):
 
     def __post_init__(self) -> None:
         length = len(self.source_index)
-        bad = {k: len(v) for k, v in self.…
+        bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
         if bad:
             raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
 
@@ -105,7 +105,7 @@ class BaseStatsOutput(Output):
     def _get_channels(
         self, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
     ) -> tuple[int, list[bool] | None]:
-        source_index = self.…
+        source_index = self.data()[SOURCE_INDEX]
         raw_channels = int(max([si.channel or 0 for si in source_index])) + 1
         if isinstance(channel_index, int):
            max_channels = 1 if channel_index < raw_channels else raw_channels
@@ -127,15 +127,21 @@ class BaseStatsOutput(Output):
 
         return max_channels, ch_mask
 
+    def factors(self) -> dict[str, NDArray[Any]]:
+        return {
+            k: v
+            for k, v in self.data().items()
+            if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
+        }
+
     def plot(
         self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
     ) -> None:
         max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
-        d = {k: v for k, v in self.dict().items() if isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1}
         if max_channels == 1:
-            histogram_plot(…
+            histogram_plot(self.factors(), log)
         else:
-            channel_histogram_plot(…
+            channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
 
 
 @dataclass(frozen=True)
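
The new `factors()` helper centralizes the filtering that `plot()` previously did inline: keep only one-dimensional arrays with at least one nonzero value and skip bookkeeping entries. A standalone sketch of that filter, with illustrative key names in place of the `SOURCE_INDEX`/`BOX_COUNT` constants:

```python
import numpy as np

EXCLUDE = {"source_index", "box_count"}  # stand-ins for the module-level constants

stats = {
    "source_index": np.arange(4),
    "mean": np.array([0.1, 0.4, 0.3, 0.2]),
    "zeros": np.zeros(4),            # dropped: no nonzero values
    "per_channel": np.ones((4, 3)),  # dropped: not one-dimensional
}

factors = {
    k: v
    for k, v in stats.items()
    if k not in EXCLUDE and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
}
print(sorted(factors))  # ['mean']
```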
dataeval/typing.py
CHANGED
@@ -1,5 +1,5 @@
 """
-Common type …
+Common type protocols used for interoperability with DataEval.
 """
 
 __all__ = [
@@ -16,6 +16,7 @@ __all__ = [
     "SegmentationTarget",
     "SegmentationDatum",
     "SegmentationDataset",
+    "Transform",
 ]
 
 
@@ -66,6 +67,7 @@ class Array(Protocol):
     def __len__(self) -> int: ...
 
 
+T = TypeVar("T")
 _T_co = TypeVar("_T_co", covariant=True)
 _ScalarType = Union[int, float, bool, str]
 ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
@@ -140,7 +142,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
 
 ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
 """
-A type definition for an image classification datum tuple.
+Type alias for an image classification datum tuple.
 
 - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
@@ -150,7 +152,7 @@ A type definition for an image classification datum tuple.
 
 ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
 """
-A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificatio…
+Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
 """
 
 # ========== OBJECT DETECTION DATASETS ==========
@@ -159,7 +161,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificatio
 @runtime_checkable
 class ObjectDetectionTarget(Protocol):
     """
-
+    Protocol for targets in an Object Detection dataset.
 
     Attributes
     ----------
@@ -180,7 +182,7 @@ class ObjectDetectionTarget(Protocol):
 
 ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
 """
-A type definition for an object detection datum tuple.
+Type alias for an object detection datum tuple.
 
 - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`ObjectDetectionTarget` - Object detection target information for the image.
@@ -190,7 +192,7 @@ A type definition for an object detection datum tuple.
 
 ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
 """
-A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDat…
+Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
 """
 
 
@@ -200,7 +202,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDat
 @runtime_checkable
 class SegmentationTarget(Protocol):
     """
-
+    Protocol for targets in a Segmentation dataset.
 
     Attributes
     ----------
@@ -221,7 +223,7 @@ class SegmentationTarget(Protocol):
 
 SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
 """
-A type definition for an image classification datum tuple.
+Type alias for an image classification datum tuple.
 
 - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`SegmentationTarget` - Segmentation target information for the image.
@@ -230,5 +232,34 @@ A type definition for an image classification datum tuple.
 
 SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
 """
-
+Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
 """
+
+
+@runtime_checkable
+class Transform(Generic[T], Protocol):
+    """
+    Protocol defining a transform function.
+
+    Requires a `__call__` method that returns transformed data.
+
+    Example
+    -------
+    >>> from typing import Any
+    >>> from numpy.typing import NDArray
+
+    >>> class MyTransform:
+    ...     def __init__(self, divisor: float) -> None:
+    ...         self.divisor = divisor
+    ...
+    ...     def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
+    ...         return data / self.divisor
+
+    >>> my_transform = MyTransform(divisor=255.0)
+    >>> isinstance(my_transform, Transform)
+    True
+    >>> my_transform(np.array([1, 2, 3]))
+    array([0.004, 0.008, 0.012])
+    """
+
+    def __call__(self, data: T, /) -> T: ...
dataeval/utils/_mst.py
CHANGED
@@ -10,10 +10,9 @@ from scipy.sparse.csgraph import minimum_spanning_tree as mst
 from scipy.spatial.distance import pdist, squareform
 from sklearn.neighbors import NearestNeighbors
 
+from dataeval.config import EPSILON
 from dataeval.utils._array import flatten
 
-EPSILON = 1e-5
-
 
 def minimum_spanning_tree(X: NDArray[Any]) -> Any:
     """
dataeval/utils/data/_embeddings.py
CHANGED
@@ -9,7 +9,7 @@ import torch
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.typing import Array, Dataset
 from dataeval.utils.torch.models import SupportsEncode
 
@@ -24,13 +24,14 @@ class Embeddings:
     ----------
     dataset : ImageClassificationDataset or ObjectDetectionDataset
         Dataset to access original images from.
-    batch_size : int
+    batch_size : int
         Batch size to use when encoding images.
-    model : torch.nn.Module, …
+    model : torch.nn.Module or None, default None
         Model to use for encoding images.
-    device : …
-
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool, default False
         Whether to print progress bar when encoding images.
     """
 
@@ -42,9 +43,8 @@ class Embeddings:
         self,
         dataset: Dataset[tuple[Array, Any, Any]],
         batch_size: int,
-        indices: Sequence[int] | None = None,
         model: torch.nn.Module | None = None,
-        device: …
+        device: DeviceLike | None = None,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
@@ -52,26 +52,32 @@ class Embeddings:
         self.verbose = verbose
 
         self._dataset = dataset
-        self._indices = indices if indices is not None else range(len(dataset))
         model = torch.nn.Flatten() if model is None else model
         self._model = model.to(self.device).eval()
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
         self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
 
-    def to_tensor(self) -> torch.Tensor:
+    def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
         """
-        Converts …
+        Converts dataset to embeddings.
 
-
-
-
-
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
 
         Returns
         -------
         torch.Tensor
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
         """
-
+        if indices is not None:
+            return torch.vstack(list(self._batch(indices))).to(self.device)
+        else:
+            return self[:]
 
     # Reduce overhead cost by not tracking tensor gradients
     @torch.no_grad
@@ -86,9 +92,7 @@ class Embeddings:
             embeddings = self._encoder(torch.stack(images).to(self.device))
             yield embeddings
 
-    def __getitem__(self, key: int | slice …
-        if isinstance(key, list):
-            return torch.vstack(list(self._batch(key))).to(self.device)
+    def __getitem__(self, key: int | slice, /) -> torch.Tensor:
         if isinstance(key, slice):
             return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
         elif isinstance(key, int):
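
With the `indices` argument moved from the constructor to `to_tensor`, one `Embeddings` instance can embed the full dataset or any subset on demand. A hedged sketch follows; it assumes `Embeddings` is exported from `dataeval.utils.data` (the class itself lives in `_embeddings.py`) and uses an invented toy in-memory dataset.

```python
import torch

from dataeval.utils.data import Embeddings  # assumed public export for the class shown above


class ToyDataset:
    """Minimal map-style dataset yielding (image, target, metadata) tuples."""

    def __init__(self, n: int = 8) -> None:
        self._images = [torch.rand(3, 16, 16) for _ in range(n)]

    def __len__(self) -> int:
        return len(self._images)

    def __getitem__(self, index: int):
        return self._images[index], 0, {}


embeddings = Embeddings(ToyDataset(), batch_size=4)  # default model flattens each image

all_embeddings = embeddings.to_tensor()           # embeds every image
subset = embeddings.to_tensor(indices=[0, 2, 5])  # new in 0.83.0: embed selected indices only
first_two = embeddings[:2]                        # slicing also returns a tensor batch
```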
dataeval/utils/data/_metadata.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -11,6 +11,7 @@ from numpy.typing import NDArray
 from dataeval.typing import (
     AnnotatedDataset,
     Array,
+    ArrayLike,
     ObjectDetectionTarget,
 )
 from dataeval.utils._array import as_numpy, to_numpy
@@ -276,16 +277,12 @@ class Metadata:
         if self._processed and not force:
             return
 
-        # …
-        self.…
-        self._merge()
+        # Create image indices from targets
+        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
 
         # Validate the metadata dimensions
         self._validate()
 
-        # Create image indices from targets
-        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
-
         # Include specified metadata keys
         if self.include:
             metadata = {i: self.merged[i] for i in self.include if i in self.merged}
@@ -358,3 +355,15 @@ class Metadata:
         )
         self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
         self._processed = True
+
+    def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
+        self._merge()
+        self._processed = False
+        target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
+        if any(len(v) != target_len for v in factors.values()):
+            raise ValueError(
+                "The lists/arrays in the provided factors have a different length than the current metadata factors."
+            )
+        merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
+        for k, v in factors.items():
+            merged[k] = v
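
`add_factors` refuses factors whose length does not match the number of targets before merging them in. The standalone sketch below mirrors that length check; the helper name is illustrative, while the error message is taken from the diff.

```python
from typing import Mapping, Sequence

import numpy as np


def check_factor_lengths(factors: Mapping[str, Sequence], target_len: int) -> None:
    """Mirrors the validation Metadata.add_factors performs before merging new factors."""
    if any(len(v) != target_len for v in factors.values()):
        raise ValueError(
            "The lists/arrays in the provided factors have a different length than the current metadata factors."
        )


check_factor_lengths({"sensor_gain": np.ones(10)}, target_len=10)  # passes silently
try:
    check_factor_lengths({"sensor_gain": np.ones(9)}, target_len=10)
except ValueError as err:
    print(err)
```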
dataeval/utils/data/_selection.py
CHANGED
@@ -3,9 +3,9 @@ from __future__ import annotations
 __all__ = []
 
 from enum import IntEnum
-from typing import …
+from typing import Generic, Iterator, Sequence, TypeVar
 
-from dataeval.typing import AnnotatedDataset, DatasetMetadata
+from dataeval.typing import AnnotatedDataset, DatasetMetadata, Transform
 
 _TDatum = TypeVar("_TDatum")
 
@@ -35,6 +35,8 @@ class Select(AnnotatedDataset[_TDatum]):
         The dataset to wrap.
     selections : Selection or list[Selection], optional
         The selection criteria to apply to the dataset.
+    transforms : Transform or list[Transform], optional
+        The transforms to apply to the dataset.
 
     Examples
     --------
@@ -67,13 +69,17 @@ class Select(AnnotatedDataset[_TDatum]):
     def __init__(
         self,
         dataset: AnnotatedDataset[_TDatum],
-        selections: Selection[_TDatum] | …
+        selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
+        transforms: Transform[_TDatum] | Sequence[Transform[_TDatum]] | None = None,
     ) -> None:
+        self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
-        self._selections = self.…
-        self.…
+        self._selections = self._sort(selections)
+        self._transforms = (
+            [] if transforms is None else [transforms] if isinstance(transforms, Transform) else transforms
+        )
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -81,8 +87,7 @@ class Select(AnnotatedDataset[_TDatum]):
             _metadata["id"] = dataset.__class__.__name__
         self._metadata = DatasetMetadata(**_metadata)
 
-
-        self._apply_selections()
+        self._select()
 
     @property
     def metadata(self) -> DatasetMetadata:
@@ -92,10 +97,11 @@ class Select(AnnotatedDataset[_TDatum]):
         nt = "\n "
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
-        selections = f"Selections: [{', '.join([str(s) for s in self.…
-
+        selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
+        transforms = f"Transforms: [{', '.join([str(t) for t in self._transforms])}]"
+        return f"{title}\n{sep}{nt}{selections}{nt}{transforms}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
-    def …
+    def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
             return []
 
@@ -106,17 +112,18 @@ class Select(AnnotatedDataset[_TDatum]):
         selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
         return selection_list
 
-    def …
+    def _select(self) -> None:
         for selection in self._selections:
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
-    def …
-
-
+    def _transform(self, datum: _TDatum) -> _TDatum:
+        for t in self._transforms:
+            datum = t(datum)
+        return datum
 
     def __getitem__(self, index: int) -> _TDatum:
-        return self._dataset[self._selection[index]]
+        return self._transform(self._dataset[self._selection[index]])
 
     def __iter__(self) -> Iterator[_TDatum]:
         for i in range(len(self)):
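
Transforms passed to `Select` are applied lazily to the whole datum on access, after selection. The sketch below is hedged: it assumes `Select` is exported from `dataeval.utils.data` (the class shown above lives in `_selection.py`) and that a bare map-style dataset with an `id` metadata entry is acceptable; the dataset and transform are invented for illustration.

```python
import torch

from dataeval.utils.data import Select  # assumed public export for the class shown above


class ToyDataset:
    """Minimal dataset yielding (image, target, metadata) tuples."""

    metadata = {"id": "toy"}

    def __init__(self, n: int = 4) -> None:
        self._data = [(torch.full((3, 8, 8), 255.0), 0, {"index": i}) for i in range(n)]

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int):
        return self._data[index]


class ScaleImage:
    """A Transform over the full datum: rescales the image, passes target/metadata through."""

    def __init__(self, divisor: float) -> None:
        self.divisor = divisor

    def __call__(self, datum, /):
        image, target, metadata = datum
        return image / self.divisor, target, metadata


selected = Select(ToyDataset(), transforms=ScaleImage(255.0))
image, target, metadata = selected[0]  # transform applied on access
print(float(image.max()))  # 1.0
```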
dataeval/utils/data/_split.py
CHANGED
@@ -12,6 +12,7 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
+from dataeval.config import get_seed
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 
@@ -212,9 +213,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
     best_score = 0.50
     bin_index = np.zeros(len(array), dtype=np.intp)
     for k in range(2, 20):
-        clusterer = KMeans(n_clusters=k)
+        clusterer = KMeans(n_clusters=k, random_state=get_seed())
         cluster_labels = clusterer.fit_predict(array)
-        score = silhouette_score(array, cluster_labels, sample_size=25_000)
+        score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
         if score > best_score:
             best_score = score
             bin_index = cluster_labels.astype(np.intp)
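
Passing the same seed to both `KMeans` and `silhouette_score` makes the bin count chosen by `bin_kmeans` reproducible from run to run. The standalone sketch below uses a fixed integer as a stand-in for `dataeval.config.get_seed()`:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

SEED = 42  # stand-in for dataeval.config.get_seed()

rng = np.random.default_rng(0)
array = np.vstack([rng.normal(0, 1, (100, 2)), rng.normal(5, 1, (100, 2))])

best_k, best_score = 0, 0.50
for k in range(2, 6):
    labels = KMeans(n_clusters=k, random_state=SEED).fit_predict(array)
    score = silhouette_score(array, labels, random_state=SEED)
    if score > best_score:
        best_k, best_score = k, score
print(best_k)  # 2 for this well-separated toy data
```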
dataeval/utils/data/datasets/_base.py
CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 from abc import abstractmethod
 from pathlib import Path
-from typing import Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, NamedTuple, Sequence, TypeVar
 
 from dataeval.utils.data.datasets._fileio import _ensure_exists
 from dataeval.utils.data.datasets._mixin import BaseDatasetMixin
@@ -16,9 +16,11 @@ from dataeval.utils.data.datasets._types import (
     ObjectDetectionTarget,
     SegmentationDataset,
     SegmentationTarget,
-    Transform,
 )
 
+if TYPE_CHECKING:
+    from dataeval.typing import Transform
+
 _TArray = TypeVar("_TArray")
 _TTarget = TypeVar("_TTarget")
 _TRawTarget = TypeVar("_TRawTarget", list[int], list[str])