dataeval 0.82.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/_mmd.py +9 -9
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +4 -4
- dataeval/detectors/linters/duplicates.py +3 -3
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +5 -4
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +1 -1
- dataeval/detectors/ood/vae.py +2 -1
- dataeval/metadata/_distance.py +11 -44
- dataeval/metadata/_ood.py +9 -7
- dataeval/metrics/bias/_balance.py +7 -3
- dataeval/metrics/bias/_diversity.py +3 -0
- dataeval/metrics/bias/_parity.py +2 -0
- dataeval/metrics/stats/_base.py +3 -3
- dataeval/metrics/stats/_boxratiostats.py +1 -1
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/outputs/__init__.py +4 -0
- dataeval/outputs/_base.py +50 -21
- dataeval/outputs/_bias.py +1 -1
- dataeval/outputs/_linters.py +4 -2
- dataeval/outputs/_metadata.py +54 -0
- dataeval/outputs/_stats.py +12 -6
- dataeval/utils/data/_embeddings.py +8 -9
- dataeval/utils/data/_metadata.py +16 -7
- dataeval/utils/data/_selection.py +4 -8
- dataeval/utils/data/_split.py +3 -2
- dataeval/utils/data/selections/_classfilter.py +5 -3
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +1 -1
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/RECORD +37 -36
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/_base.py
CHANGED
@@ -248,13 +248,13 @@ def add_stats(a: TStatsOutput, b: TStatsOutput) -> TStatsOutput:
     if type(a) is not type(b):
         raise TypeError(f"Types {type(a)} and {type(b)} cannot be added.")
 
-    sum_dict = deepcopy(a.dict())
+    sum_dict = deepcopy(a.data())
 
     for k in sum_dict:
         if isinstance(sum_dict[k], list):
-            sum_dict[k].extend(b.dict()[k])
+            sum_dict[k].extend(b.data()[k])
         else:
-            sum_dict[k] = np.concatenate((sum_dict[k], b.dict()[k]))
+            sum_dict[k] = np.concatenate((sum_dict[k], b.data()[k]))
 
     return type(a)(**sum_dict)
 

dataeval/metrics/stats/_boxratiostats.py
CHANGED
@@ -153,7 +153,7 @@ def boxratiostats(
         raise ValueError("Input for boxstats and imgstats must have matching channel information.")
 
     output_dict = {}
-    for key in boxstats.dict():
+    for key in boxstats.data():
        output_dict[key] = calculate_ratios(key, boxstats, imgstats)
 
     return output_cls(**output_dict)

dataeval/metrics/stats/_imagestats.py
CHANGED
@@ -42,8 +42,8 @@ def imagestats(
     Calculates various :term:`statistics<Statistics>` for each image.
 
     This function computes dimension, pixel and visual metrics
-    on the images or individual bounding boxes for each image
-
+    on the images or individual bounding boxes for each image. If
+    performing calculations per channel dimension stats are excluded.
 
     Parameters
     ----------
@@ -61,7 +61,7 @@ def imagestats(
 
     See Also
     --------
-    dimensionstats,
+    dimensionstats, pixelstats, visualstats
 
     Examples
     --------
@@ -91,4 +91,4 @@ def imagestats(
     output_cls = ImageStatsOutput
 
     outputs = run_stats(dataset, per_box, per_channel, processors)
-    return output_cls(**{k: v for d in outputs for k, v in d.dict().items()})
+    return output_cls(**{k: v for d in outputs for k, v in d.data().items()})
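The common thread in these three stats modules is the rename of the output accessor from dict() to data(), matching the new GenericOutput interface in dataeval/outputs/_base.py below. A minimal caller-side sketch; `images` is an assumed stand-in for any dataset imagestats accepts:

    from dataeval.metrics.stats import imagestats

    stats = imagestats(images)                   # `images`: assumed image dataset
    for name, values in stats.data().items():    # was stats.dict() in 0.82.0
        print(name, len(values))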
dataeval/outputs/__init__.py
CHANGED
@@ -8,6 +8,7 @@ from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOu
 from ._drift import DriftMMDOutput, DriftOutput
 from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
 from ._linters import DuplicatesOutput, OutliersOutput
+from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput
 from ._ood import OODOutput, OODScoreOutput
 from ._stats import (
     ChannelStatsOutput,
@@ -39,6 +40,9 @@ __all__ = [
     "ImageStatsOutput",
     "LabelParityOutput",
     "LabelStatsOutput",
+    "MetadataDistanceOutput",
+    "MetadataDistanceValues",
+    "MostDeviatedFactorsOutput",
     "OODOutput",
     "OODScoreOutput",
     "OutliersOutput",
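For reference, the three new classes re-exported here come from the new dataeval/outputs/_metadata.py (shown later in this diff) and become importable from the public namespace:

    from dataeval.outputs import (
        MetadataDistanceOutput,
        MetadataDistanceValues,
        MostDeviatedFactorsOutput,
    )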
dataeval/outputs/_base.py
CHANGED
@@ -4,11 +4,11 @@ __all__ = []
 
 import inspect
 import logging
-from collections.abc import Mapping
+from collections.abc import Collection, Mapping, Sequence
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial, wraps
-from typing import Any, Callable, Iterator, TypeVar
+from typing import Any, Callable, Generic, Iterator, TypeVar, overload
 
 import numpy as np
 from typing_extensions import ParamSpec
@@ -56,16 +56,13 @@ class ExecutionMetadata:
         )
 
 
-
-    _meta: ExecutionMetadata | None = None
+T = TypeVar("T", covariant=True)
 
-    def __str__(self) -> str:
-        return f"{self.__class__.__name__}: {str(self.dict())}"
 
-
-
+class GenericOutput(Generic[T]):
+    _meta: ExecutionMetadata | None = None
 
-
+    def data(self) -> T: ...
     def meta(self) -> ExecutionMetadata:
         """
         Metadata about the execution of the function or method for the Output class.
@@ -73,34 +70,66 @@ class Output:
         return self._meta or ExecutionMetadata.empty()
 
 
-
-
+class Output(GenericOutput[dict[str, Any]]):
+    def data(self) -> dict[str, Any]:
+        return {k: v for k, v in self.__dict__.items() if k != "_meta"}
 
+    def __repr__(self) -> str:
+        return str(self)
 
-
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.data().items()])})"
+
+
+class BaseCollectionMixin(Collection[Any]):
     __slots__ = ["_data"]
 
+    def data(self) -> Any:
+        return self._data
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({repr(self._data)})"
+
+    def __str__(self) -> str:
+        return str(self._data)
+
+
+TKey = TypeVar("TKey", str, int, float, set)
+TValue = TypeVar("TValue")
+
+
+class MappingOutput(Mapping[TKey, TValue], BaseCollectionMixin, GenericOutput[Mapping[TKey, TValue]]):
     def __init__(self, data: Mapping[TKey, TValue]):
         self._data = data
 
     def __getitem__(self, key: TKey) -> TValue:
-        return self._data
+        return self._data[key]
 
     def __iter__(self) -> Iterator[TKey]:
-        return self._data
+        return iter(self._data)
 
-    def __len__(self) -> int:
-        return self._data.__len__()
 
-
-
+class SequenceOutput(Sequence[TValue], BaseCollectionMixin, GenericOutput[Sequence[TValue]]):
+    def __init__(self, data: Sequence[TValue]):
+        self._data = data
+
+    @overload
+    def __getitem__(self, index: int) -> TValue: ...
+    @overload
+    def __getitem__(self, index: slice) -> Sequence[TValue]: ...
 
-    def
-        return
+    def __getitem__(self, index: int | slice) -> TValue | Sequence[TValue]:
+        return self._data[index]
+
+    def __iter__(self) -> Iterator[TValue]:
+        return iter(self._data)
 
 
 P = ParamSpec("P")
-R = TypeVar("R", bound=Output)
+R = TypeVar("R", bound=GenericOutput)
 
 
 def set_metadata(fn: Callable[P, R] | None = None, *, state: list[str] | None = None) -> Callable[P, R]:
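A short sketch of the corrected container semantics (these classes live in the private module dataeval.outputs._base, so import paths may shift): in 0.82.0, MappingOutput.__getitem__ and __iter__ returned the raw _data object instead of indexing or iterating it, so both asserts below would have failed:

    from dataeval.outputs._base import MappingOutput, SequenceOutput

    m = MappingOutput({"a": 1.0, "b": 2.0})
    assert m["a"] == 1.0          # 0.82.0 returned the whole mapping here
    assert list(m) == ["a", "b"]  # 0.82.0 returned _data, not iter(_data)

    s = SequenceOutput([("mean", 0.4), ("std", 0.1)])
    assert s[0] == ("mean", 0.4)  # int/slice indexing via the new overloads
    assert len(s) == 2            # __len__ now comes from BaseCollectionMixin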
dataeval/outputs/_bias.py
CHANGED
@@ -364,7 +364,7 @@ class DiversityOutput(Output):
                col_labels,
                xlabel="Factors",
                ylabel="Class",
-               cbarlabel=f"Normalized {asdict(self.meta)['arguments']['method'].title()} Index",
+               cbarlabel=f"Normalized {asdict(self.meta())['arguments']['method'].title()} Index",
            )
 
        else:
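The one-character fix matters because meta is a method returning the ExecutionMetadata dataclass, not an attribute, so asdict(self.meta) raised at plot time. A sketch, with `result` an assumed DiversityOutput instance:

    from dataclasses import asdict

    method = asdict(result.meta())["arguments"]["method"]  # e.g. a diversity method name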
dataeval/outputs/_linters.py
CHANGED
@@ -24,7 +24,7 @@ TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, list[IndexIssueMap])
 
 
 @dataclass(frozen=True)
-class DuplicatesOutput(Generic[TIndexCollection], Output):
+class DuplicatesOutput(Output, Generic[TIndexCollection]):
     """
     Output class for :class:`.Duplicates` lint detector.
 
@@ -35,6 +35,8 @@ class DuplicatesOutput(Generic[TIndexCollection], Output):
     near: list[list[int] | dict[int, list[int]]]
         Indices of images that are near matches
 
+    Notes
+    -----
     - For a single dataset, indices are returned as a list of index groups.
     - For multiple datasets, indices are returned as dictionaries where the key is the
       index of the dataset, and the value is the list index groups from that dataset.
@@ -99,7 +101,7 @@ def _create_pandas_dataframe(class_wise):
 
 
 @dataclass(frozen=True)
-class OutliersOutput(Generic[TIndexIssueMap], Output):
+class OutliersOutput(Output, Generic[TIndexIssueMap]):
     """
     Output class for :class:`.Outliers` lint detector.
 

dataeval/outputs/_metadata.py
ADDED
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import NamedTuple
+
+from dataeval.outputs._base import MappingOutput, SequenceOutput
+
+
+class MostDeviatedFactorsOutput(SequenceOutput[tuple[str, float]]):
+    """
+    Output class for results of :func:`.most_deviated_factors` for OOD samples with metadata.
+
+    Attributes
+    ----------
+    value : tuple[str, float]
+        A tuple of the factor name and deviation of the highest metadata deviation
+    """
+
+
+class MetadataDistanceValues(NamedTuple):
+    """
+    Statistics comparing metadata distance.
+
+    Attributes
+    ----------
+    statistic : float
+        the KS statistic
+    location : float
+        The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
+        to the median of the reference distribution.
+    dist : float
+        The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
+    pvalue : float
+        The p-value from the KS two-sample test
+    """
+
+    statistic: float
+    location: float
+    dist: float
+    pvalue: float
+
+
+class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
+    """
+    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+
+    Attributes
+    ----------
+    key : str
+        Metadata feature names
+    value : :class:`.MetadataDistanceValues`
+        Output per feature name containing the statistic, statistic location, distance, and pvalue.
+    """
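A hedged sketch of consuming these new types; metadata_distance is assumed to be the producing function in dataeval/metadata/_distance.py (changed in this release), and ref_md/test_md are assumed metadata inputs:

    from dataeval.metadata import metadata_distance  # assumed import path

    distances = metadata_distance(ref_md, test_md)   # MetadataDistanceOutput
    for feature, vals in distances.items():          # Mapping semantics
        if vals.pvalue < 0.05:
            print(feature, vals.statistic, vals.dist)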
dataeval/outputs/_stats.py
CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union
+from typing import Any, Iterable, Optional, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -63,7 +63,7 @@ class BaseStatsOutput(Output):
 
     def __post_init__(self) -> None:
         length = len(self.source_index)
-        bad = {k: len(v) for k, v in self.dict().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
+        bad = {k: len(v) for k, v in self.data().items() if k not in [SOURCE_INDEX, BOX_COUNT] and len(v) != length}
        if bad:
            raise ValueError(f"All values must have the same length as source_index. Bad values: {str(bad)}.")
 
@@ -105,7 +105,7 @@ class BaseStatsOutput(Output):
     def _get_channels(
         self, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
     ) -> tuple[int, list[bool] | None]:
-        source_index = self.dict()[SOURCE_INDEX]
+        source_index = self.data()[SOURCE_INDEX]
        raw_channels = int(max([si.channel or 0 for si in source_index])) + 1
        if isinstance(channel_index, int):
            max_channels = 1 if channel_index < raw_channels else raw_channels
@@ -127,15 +127,21 @@ class BaseStatsOutput(Output):
 
         return max_channels, ch_mask
 
+    def factors(self) -> dict[str, NDArray[Any]]:
+        return {
+            k: v
+            for k, v in self.data().items()
+            if k not in (SOURCE_INDEX, BOX_COUNT) and isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1
+        }
+
     def plot(
         self, log: bool, channel_limit: int | None = None, channel_index: int | Iterable[int] | None = None
     ) -> None:
         max_channels, ch_mask = self._get_channels(channel_limit, channel_index)
-        d = {k: v for k, v in self.dict().items() if isinstance(v, np.ndarray) and v[v != 0].size > 0 and v.ndim == 1}
         if max_channels == 1:
-            histogram_plot(d, log)
+            histogram_plot(self.factors(), log)
        else:
-            channel_histogram_plot(d, log, max_channels, ch_mask)
+            channel_histogram_plot(self.factors(), log, max_channels, ch_mask)
 
 
 @dataclass(frozen=True)
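The new factors() helper centralizes the filtering that plot() previously inlined in the local `d`: only one-dimensional, non-empty ndarray fields survive. A sketch, reusing the `stats` output from the earlier example:

    plottable = stats.factors()   # dict[str, NDArray[Any]] of 1-D factors
    stats.plot(log=True)          # now calls histogram_plot(self.factors(), log)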
dataeval/utils/data/_embeddings.py
CHANGED
@@ -9,7 +9,7 @@ import torch
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.typing import Array, Dataset
 from dataeval.utils.torch.models import SupportsEncode
 
@@ -24,13 +24,14 @@ class Embeddings:
     ----------
     dataset : ImageClassificationDataset or ObjectDetectionDataset
         Dataset to access original images from.
-    batch_size : int
+    batch_size : int
         Batch size to use when encoding images.
-    model : torch.nn.Module,
+    model : torch.nn.Module or None, default None
         Model to use for encoding images.
-    device :
-
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    verbose : bool, default False
         Whether to print progress bar when encoding images.
     """
 
@@ -42,9 +43,8 @@ class Embeddings:
         self,
         dataset: Dataset[tuple[Array, Any, Any]],
         batch_size: int,
-        indices: Sequence[int] | None = None,
         model: torch.nn.Module | None = None,
-        device:
+        device: DeviceLike | None = None,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
@@ -52,7 +52,6 @@ class Embeddings:
         self.verbose = verbose
 
         self._dataset = dataset
-        self._indices = indices if indices is not None else range(len(dataset))
         model = torch.nn.Flatten() if model is None else model
         self._model = model.to(self.device).eval()
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
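Net effect on the constructor: the `indices` parameter is removed and `device` adopts the DeviceLike alias from dataeval.config. A sketch; `ds` and `encoder` are stand-ins, and the export path is assumed:

    from dataeval.utils.data import Embeddings  # assumed export path

    emb = Embeddings(ds, batch_size=64, model=encoder, device=None, verbose=True)
    # device=None resolves through get_device() to the DataEval or torch default.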
dataeval/utils/data/_metadata.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -11,6 +11,7 @@ from numpy.typing import NDArray
 from dataeval.typing import (
     AnnotatedDataset,
     Array,
+    ArrayLike,
     ObjectDetectionTarget,
 )
 from dataeval.utils._array import as_numpy, to_numpy
@@ -276,16 +277,12 @@ class Metadata:
         if self._processed and not force:
             return
 
-        #
-        self.
-        self._merge()
+        # Create image indices from targets
+        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
 
         # Validate the metadata dimensions
         self._validate()
 
-        # Create image indices from targets
-        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
-
         # Include specified metadata keys
         if self.include:
             metadata = {i: self.merged[i] for i in self.include if i in self.merged}
@@ -358,3 +355,15 @@ class Metadata:
         )
         self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
         self._processed = True
+
+    def add_factors(self, factors: Mapping[str, ArrayLike]) -> None:
+        self._merge()
+        self._processed = False
+        target_len = len(self.targets.source) if self.targets.source is not None else len(self.targets)
+        if any(len(v) != target_len for v in factors.values()):
+            raise ValueError(
+                "The lists/arrays in the provided factors have a different length than the current metadata factors."
+            )
+        merged = cast(tuple[dict[str, ArrayLike], dict[str, list[str]]], self._merged)[0]
+        for k, v in factors.items():
+            merged[k] = v
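A sketch of the new add_factors API defined above; `metadata` is an assumed Metadata instance over a dataset with three targets, so every added array must have length 3:

    metadata.add_factors({
        "altitude": [1200, 880, 1430],   # hypothetical factor
        "sensor": ["A", "B", "A"],       # hypothetical factor
    })
    # Mismatched lengths raise ValueError; _processed is reset so the new
    # factors are picked up on the next processing pass.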
dataeval/utils/data/_selection.py
CHANGED
@@ -3,11 +3,11 @@ from __future__ import annotations
 __all__ = []
 
 from enum import IntEnum
-from typing import Any, Generic, Iterator, Sequence, TypeVar
+from typing import Generic, Iterator, Sequence, TypeVar
 
 from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
-_TDatum = TypeVar("_TDatum")
+_TDatum = TypeVar("_TDatum", covariant=True)
 
 
 class SelectionStage(IntEnum):
@@ -69,11 +69,11 @@ class Select(AnnotatedDataset[_TDatum]):
         dataset: AnnotatedDataset[_TDatum],
         selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
     ) -> None:
+        self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
         self._selections = self._sort_selections(selections)
-        self.__dict__.update(dataset.__dict__)
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -93,7 +93,7 @@ class Select(AnnotatedDataset[_TDatum]):
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
         selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
-        return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
+        return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
     def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
@@ -111,10 +111,6 @@ class Select(AnnotatedDataset[_TDatum]):
             selection(self)
         self._selection = self._selection[: self._size_limit]
 
-    def __getattr__(self, name: str, /) -> Any:
-        selfattr = getattr(self._dataset, name, None)
-        return selfattr if selfattr is not None else getattr(self._dataset, name)
-
     def __getitem__(self, index: int) -> _TDatum:
         return self._dataset[self._selection[index]]
 
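The removed __getattr__ proxy (note its quirk: a failed lookup fell through to a second getattr on the same dataset) is replaced by copying the wrapped dataset's __dict__ up front, before Select assigns its own attributes so they are not clobbered. A sketch of the updated repr; the Select and ClassFilter export paths and the ClassFilter signature are assumptions:

    from dataeval.utils.data import Select                  # assumed path
    from dataeval.utils.data.selections import ClassFilter  # assumed path

    subset = Select(ds, selections=[ClassFilter(classes=[0, 1])])
    print(subset)   # now includes a "Selected Size: N" line after Selections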
dataeval/utils/data/_split.py
CHANGED
@@ -12,6 +12,7 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
+from dataeval.config import get_seed
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 
@@ -212,9 +213,9 @@ def bin_kmeans(array: NDArray[Any]) -> NDArray[np.intp]:
     best_score = 0.50
     bin_index = np.zeros(len(array), dtype=np.intp)
     for k in range(2, 20):
-        clusterer = KMeans(n_clusters=k)
+        clusterer = KMeans(n_clusters=k, random_state=get_seed())
        cluster_labels = clusterer.fit_predict(array)
-        score = silhouette_score(array, cluster_labels, sample_size=25_000)
+        score = silhouette_score(array, cluster_labels, sample_size=25_000, random_state=get_seed())
        if score > best_score:
            best_score = score
            bin_index = cluster_labels.astype(np.intp)
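get_seed() threads DataEval's global seed (part of the expanded dataeval/config.py in this release) into scikit-learn's random_state, making bin_kmeans reproducible. A sketch, assuming set_seed is the companion setter added alongside get_seed:

    from dataeval.config import get_seed, set_seed  # set_seed assumed

    set_seed(42)
    assert get_seed() == 42   # the value forwarded to KMeans/silhouette_score above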
dataeval/utils/data/selections/_classfilter.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Sequence
+from typing import Sequence, TypeVar
 
 import numpy as np
 
@@ -10,8 +10,10 @@ from dataeval.typing import Array, ImageClassificationDatum
 from dataeval.utils._array import as_numpy
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
+TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum, covariant=True)
 
-class ClassFilter(Selection[ImageClassificationDatum]):
+
+class ClassFilter(Selection[TImageClassificationDatum]):
     """
     Filter and balance the dataset by class.
 
@@ -34,7 +36,7 @@ class ClassFilter(Selection[ImageClassificationDatum]):
         self.classes = classes
         self.balance = balance
 
-    def __call__(self, dataset: Select[ImageClassificationDatum]) -> None:
+    def __call__(self, dataset: Select[TImageClassificationDatum]) -> None:
        if self.classes is None and not self.balance:
            return
 
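The covariant TImageClassificationDatum keeps the wrapped dataset's concrete datum type flowing through the filter instead of erasing it to ImageClassificationDatum. A type-level sketch; MyDatum is hypothetical:

    # MyDatum is any type satisfying ImageClassificationDatum; with the
    # covariant TypeVar, ClassFilter now type-checks as Selection[MyDatum].
    flt: Selection[MyDatum] = ClassFilter(classes=[0, 1])  # hypothetical usage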
dataeval/utils/torch/_internal.py
CHANGED
@@ -11,13 +11,13 @@ from numpy.typing import NDArray
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 
 
 def predict_batch(
     x: NDArray[Any] | torch.Tensor,
     model: Callable | torch.nn.Module | torch.nn.Sequential,
-    device:
+    device: DeviceLike | None = None,
     batch_size: int = int(1e10),
     preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     dtype: type[np.generic] | torch.dtype = np.float32,
@@ -31,9 +31,9 @@ def predict_batch(
         Batch of instances.
     model : Callable | nn.Module | nn.Sequential
         PyTorch model.
-    device :
-
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
     batch_size : int, default 1e10
         Batch size used during prediction.
     preprocess_fn : Callable | None, default None
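Sketch of calling the updated predict_batch (an internal helper in dataeval.utils.torch._internal); the model is a stand-in, and device=None now resolves through get_device:

    import numpy as np
    import torch

    from dataeval.utils.torch._internal import predict_batch

    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(4, 2))
    x = np.random.rand(8, 4).astype(np.float32)
    preds = predict_batch(x, model, device=None, batch_size=4)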
dataeval/utils/torch/trainer.py
CHANGED
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+from dataeval.config import DeviceLike, get_device
+
 __all__ = ["AETrainer"]
 
 from typing import Any
@@ -25,9 +27,9 @@ class AETrainer:
     ----------
     model : nn.Module
         The model to be trained.
-    device :
-        The hardware device to use
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
     batch_size : int, default 8
         The number of images to process in a batch.
     """
@@ -35,13 +37,11 @@ class AETrainer:
     def __init__(
         self,
         model: nn.Module,
-        device:
+        device: DeviceLike | None = None,
        batch_size: int = 8,
    ):
-
-
-        self.device: torch.device = torch.device(device)
-        self.model: nn.Module = model.to(device)
+        self.device: torch.device = get_device(device)
+        self.model: nn.Module = model.to(self.device)
        self.batch_size = batch_size
 
    def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
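With this change AETrainer follows the same DeviceLike-or-None convention as the rest of the library: 0.82.0 passed `device` straight into torch.device() and model.to(device), while 0.82.1 resolves it via get_device first, so None falls back to the DataEval or torch default. Sketch; `ae` and `train_ds` are stand-ins:

    from dataeval.utils.torch.trainer import AETrainer

    trainer = AETrainer(ae, device=None, batch_size=8)
    losses = trainer.train(train_ds, epochs=25)   # list[float], per the signature above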
{dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.82.0
+Version: 0.82.1
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT