dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- dataeval/__init__.py +7 -2
- dataeval/config.py +13 -3
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_ood.py +144 -27
- dataeval/metrics/bias/__init__.py +11 -1
- dataeval/metrics/bias/_balance.py +3 -3
- dataeval/metrics/bias/_completeness.py +130 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +31 -36
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +4 -45
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +4 -2
- dataeval/outputs/_bias.py +31 -22
- dataeval/outputs/_metadata.py +7 -0
- dataeval/outputs/_stats.py +2 -3
- dataeval/typing.py +43 -12
- dataeval/utils/_array.py +26 -1
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_dataset.py +2 -0
- dataeval/utils/data/_embeddings.py +115 -32
- dataeval/utils/data/_images.py +38 -15
- dataeval/utils/data/_selection.py +7 -8
- dataeval/utils/data/_split.py +76 -129
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +1 -1
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/metadata.py +1 -1
- dataeval/utils/torch/_gmm.py +3 -2
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
- dataeval/detectors/ood/metadata_ood_mi.py +0 -91
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
dataeval/metrics/stats/_base.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 __all__ = []
 
+import math
 import re
 import warnings
 from collections import ChainMap
@@ -17,26 +18,13 @@ from numpy.typing import NDArray
 
 from dataeval.config import get_max_processes
 from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
-from dataeval.typing import ArrayLike, Dataset, ObjectDetectionTarget
+from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
 from dataeval.utils._array import to_numpy
 from dataeval.utils._image import normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
-
-def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
-    """
-    Normalizes the bounding box shape into (N,4).
-    """
-    ndim = bounding_box.ndim
-    if ndim == 1:
-        return np.expand_dims(bounding_box, axis=0)
-    elif ndim > 2:
-        raise ValueError("Bounding boxes must have 2 dimensions: (# of boxes in an image, [X,Y,W,H]) -> (N,4)")
-    else:
-        return bounding_box
-
-
+BoundingBox = tuple[float, float, float, float]
 TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
 
@@ -46,11 +34,15 @@ class StatsProcessor(Generic[TStatsOutput]):
     image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
     channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
 
-    def __init__(self, image: NDArray[Any], box:
+    def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
         self.raw = image
         self.width: int = image.shape[-1]
         self.height: int = image.shape[-2]
-
+        box = BoundingBox((0, 0, self.width, self.height)) if box is None else box
+        # Clip the bounding box to image
+        x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
+        x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
+        self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
         self._per_channel = per_channel
         self._image = None
         self._shape = None
@@ -122,22 +114,17 @@ class StatsProcessorOutput:
 
 def process_stats(
     i: int,
-
-
+    image: ArrayLike,
+    boxes: list[BoundingBox] | None,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
-
-    image, target = (to_numpy(cast(ArrayLike, data[0])), data[1]) if isinstance(data, tuple) else (to_numpy(data), None)
-    target = None if not isinstance(target, ObjectDetectionTarget) else target
-    boxes = to_numpy(target.boxes) if target is not None else None
+    image = to_numpy(image)
     results_list: list[dict[str, Any]] = []
     source_indices: list[SourceIndex] = []
     box_counts: list[int] = []
     warnings_list: list[str] = []
-
-    for i_b, box in enumerate(nboxes):
-        i_b = None if box is None else i_b
+    for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
         processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
         if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
             warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
@@ -151,17 +138,28 @@ def process_stats(
 
 
 def process_stats_unpack(
-
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
-    per_box: bool,
+    args: tuple[int, Array, list[BoundingBox] | None],
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
-    return process_stats(
+    return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+
+
+def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
+    for i in range(len(dataset)):
+        d = dataset[i]
+        image = d[0] if isinstance(d, tuple) else d
+        if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
+            boxes = cast(Array, d[1].boxes)
+            target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
+        else:
+            target = None
+
+        yield i, image, target
 
 
 def run_stats(
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     per_box: bool,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
@@ -175,7 +173,7 @@ def run_stats(
 
     Parameters
    ----------
-    data : Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
+    data : Dataset[Array] | Dataset[tuple[Array, Any, Any]]
         A dataset of images and targets to compute statistics on.
     per_box : bool
         A flag which determines if the statistics should be evaluated on a per-box basis or not.
@@ -206,18 +204,15 @@
     warning_list = []
     stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
 
-    # TODO: Introduce global controls for CPU job parallelism and GPU configurations
     with Pool(processes=get_max_processes()) as p:
         for r in tqdm.tqdm(
             p.imap(
                 partial(
                     process_stats_unpack,
-                    dataset=dataset,
-                    per_box=per_box,
                     per_channel=per_channel,
                     stats_processor_cls=stats_processor_cls,
                 ),
-
+                _enumerate(dataset, per_box),
             ),
             total=len(dataset),
         ):
```
dataeval/metrics/stats/_dimensionstats.py
CHANGED
```diff
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import DimensionStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike, Dataset
+from dataeval.typing import Array, Dataset
 from dataeval.utils._image import get_bitdepth
 
 
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
 
 @set_metadata
 def dimensionstats(
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
 ) -> DimensionStatsOutput:
```
dataeval/metrics/stats/_hashstats.py
CHANGED
```diff
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import HashStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike, Dataset
+from dataeval.typing import Array, ArrayLike, Dataset
 from dataeval.utils._array import as_numpy
 from dataeval.utils._image import normalize_image_shape, rescale
 
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
 
 @set_metadata
 def hashstats(
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
 ) -> HashStatsOutput:
```
|
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
|
|
10
10
|
from dataeval.metrics.stats._visualstats import VisualStatsProcessor
|
11
11
|
from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
|
12
12
|
from dataeval.outputs._base import set_metadata
|
13
|
-
from dataeval.typing import
|
13
|
+
from dataeval.typing import Array, Dataset
|
14
14
|
|
15
15
|
|
16
16
|
@overload
|
17
17
|
def imagestats(
|
18
|
-
dataset: Dataset[
|
18
|
+
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
19
19
|
*,
|
20
20
|
per_box: bool = False,
|
21
21
|
per_channel: Literal[True],
|
@@ -24,7 +24,7 @@ def imagestats(
|
|
24
24
|
|
25
25
|
@overload
|
26
26
|
def imagestats(
|
27
|
-
dataset: Dataset[
|
27
|
+
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
28
28
|
*,
|
29
29
|
per_box: bool = False,
|
30
30
|
per_channel: Literal[False] = False,
|
@@ -33,7 +33,7 @@ def imagestats(
|
|
33
33
|
|
34
34
|
@set_metadata
|
35
35
|
def imagestats(
|
36
|
-
dataset: Dataset[
|
36
|
+
dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
|
37
37
|
*,
|
38
38
|
per_box: bool = False,
|
39
39
|
per_channel: bool = False,
|
dataeval/metrics/stats/_labelstats.py
CHANGED
```diff
@@ -5,54 +5,16 @@ __all__ = []
 from collections import Counter, defaultdict
 from typing import Any, Mapping, TypeVar
 
-import numpy as np
-
 from dataeval.outputs import LabelStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import AnnotatedDataset, ArrayLike
-from dataeval.utils._array import as_numpy
+from dataeval.typing import AnnotatedDataset
 from dataeval.utils.data._metadata import Metadata
 
 TValue = TypeVar("TValue")
 
 
-def _ensure_2d(labels: ArrayLike) -> ArrayLike:
-    if isinstance(labels, np.ndarray):
-        return labels[:, None]
-    else:
-        return [[lbl] for lbl in labels] # type: ignore
-
-
-def _get_list_depth(lst):
-    if isinstance(lst, list) and lst:
-        return 1 + max(_get_list_depth(item) for item in lst)
-    return 0
-
-
-def _check_labels_dimension(labels: ArrayLike) -> ArrayLike:
-    # Check for nested lists beyond 2 levels
-
-    if isinstance(labels, np.ndarray):
-        if labels.ndim == 1:
-            return _ensure_2d(labels)
-        elif labels.ndim == 2:
-            return labels
-        else:
-            raise ValueError("The label array must not have more than 2 dimensions.")
-    elif isinstance(labels, list):
-        depth = _get_list_depth(labels)
-        if depth == 1:
-            return _ensure_2d(labels)
-        elif depth == 2:
-            return labels
-        else:
-            raise ValueError("The label list must not be empty or have more than 2 levels of nesting.")
-    else:
-        raise TypeError("Labels must be either a NumPy array or a list.")
-
-
 def _sort_to_list(d: Mapping[int, TValue]) -> list[TValue]:
-    return [
+    return [t[1] for t in sorted(d.items())]
 
 
 @set_metadata
@@ -98,12 +60,9 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
     label_per_image: list[int] = []
 
     index2label = dict(enumerate(dataset.class_names))
-    labels = [target.labels.tolist() for target in dataset.targets]
-
-    labels_2d = _check_labels_dimension(labels)
 
-    for i,
-        group =
+    for i, target in enumerate(dataset.targets):
+        group = target.labels.tolist()
 
         # Count occurrences of each label in all sublists
         label_counts.update(group)
```
dataeval/metrics/stats/_pixelstats.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import PixelStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike, Dataset
+from dataeval.typing import Array, Dataset
 
 
 class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
 
 @set_metadata
 def pixelstats(
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
```
dataeval/metrics/stats/_visualstats.py
CHANGED
```diff
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.metrics.stats._base import StatsProcessor, run_stats
 from dataeval.outputs import VisualStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import ArrayLike, Dataset
+from dataeval.typing import Array, Dataset
 from dataeval.utils._image import edge_filter
 
 QUARTILES = (0, 25, 50, 75, 100)
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
 
 @set_metadata
 def visualstats(
-    dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
+    dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
     *,
     per_box: bool = False,
     per_channel: bool = False,
```
dataeval/outputs/__init__.py
CHANGED
```diff
@@ -4,11 +4,11 @@ as well as runtime metadata for reproducibility and logging.
 """
 
 from ._base import ExecutionMetadata
-from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+from ._bias import BalanceOutput, CompletenessOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
 from ._drift import DriftMMDOutput, DriftOutput
 from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
 from ._linters import DuplicatesOutput, OutliersOutput
-from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput
+from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput, OODPredictorOutput
 from ._ood import OODOutput, OODScoreOutput
 from ._stats import (
     ChannelStatsOutput,
@@ -29,6 +29,7 @@ __all__ = [
     "ChannelStatsOutput",
     "ClustererOutput",
     "CoverageOutput",
+    "CompletenessOutput",
     "DimensionStatsOutput",
     "DivergenceOutput",
     "DiversityOutput",
@@ -44,6 +45,7 @@ __all__ = [
     "MetadataDistanceValues",
     "MostDeviatedFactorsOutput",
     "OODOutput",
+    "OODPredictorOutput",
     "OODScoreOutput",
     "OutliersOutput",
     "ParityOutput",
```
dataeval/outputs/_bias.py
CHANGED
```diff
@@ -14,9 +14,10 @@ with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
 
 from dataeval.outputs._base import Output
-from dataeval.typing import ArrayLike
-from dataeval.utils._array import
+from dataeval.typing import ArrayLike, Dataset
+from dataeval.utils._array import as_numpy, channels_first_to_last
 from dataeval.utils._plot import heatmap
+from dataeval.utils.data._images import Images
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -107,13 +108,13 @@ class CoverageOutput(Output):
     critical_value_radii: NDArray[np.float64]
     coverage_radius: float
 
-    def plot(self, images:
+    def plot(self, images: Images[Any] | Dataset[Any], top_k: int = 6) -> Figure:
         """
         Plot the top k images together for visualization.
 
         Parameters
         ----------
-        images :
+        images : Images or Dataset
             Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
         top_k : int, default 6
             Number of images to plot (plotting assumes groups of 3)
@@ -130,46 +131,54 @@ class CoverageOutput(Output):
         import matplotlib.pyplot as plt
 
         # Determine which images to plot
-
+        selected_indices = self.uncovered_indices[:top_k]
 
-
-        selected_images = to_numpy(images)[highest_uncovered_indices]
+        images = Images(images) if isinstance(images, Dataset) else images
 
         # Plot the images
-        num_images = min(top_k, len(
-
-        ndim = selected_images.ndim
-        if ndim == 4:
-            selected_images = np.moveaxis(selected_images, 1, -1)
-        elif ndim == 3:
-            selected_images = np.repeat(selected_images[:, :, :, np.newaxis], 3, axis=-1)
-        else:
-            raise ValueError(
-                f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {ndim}-dimensional set of images."
-            )
+        num_images = min(top_k, len(selected_indices))
 
         rows = int(np.ceil(num_images / 3))
         fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
 
         if rows == 1:
             for j in range(3):
-                if j >= len(
+                if j >= len(selected_indices):
                     continue
-
+                image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
+                axs[j].imshow(image)
                 axs[j].axis("off")
         else:
             for i in range(rows):
                 for j in range(3):
                     i_j = i * 3 + j
-                    if i_j >= len(
+                    if i_j >= len(selected_indices):
                         continue
-
+                    image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
+                    axs[i, j].imshow(image)
                     axs[i, j].axis("off")
 
         fig.tight_layout()
         return fig
 
 
+@dataclass(frozen=True)
+class CompletenessOutput(Output):
+    """
+    Output from the completeness function.
+
+    Attributes
+    ----------
+    fraction_filled : float
+        Fraction of boxes that contain at least one data point
+    empty_box_centers : List[np.ndarray]
+        List of coordinates for centers of empty boxes
+    """
+
+    fraction_filled: float
+    empty_box_centers: NDArray[np.float64]
+
+
 @dataclass(frozen=True)
 class BalanceOutput(Output):
     """
```
dataeval/outputs/_metadata.py
CHANGED
```diff
@@ -52,3 +52,10 @@ class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
     value : :class:`.MetadataDistanceValues`
         Output per feature name containing the statistic, statistic location, distance, and pvalue.
     """
+
+
+class OODPredictorOutput(MappingOutput[str, float]):
+    """
+    Output class for results of :func:`find_ood_predictors` for the
+    mutual information between factors and being out of distribution
+    """
```
dataeval/outputs/_stats.py
CHANGED
```diff
@@ -4,7 +4,7 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Iterable, NamedTuple, Optional, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -22,8 +22,7 @@ SOURCE_INDEX = "source_index"
 BOX_COUNT = "box_count"
 
 
-
-class SourceIndex:
+class SourceIndex(NamedTuple):
     """
     The indices of the source image, box and channel.
 
```
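
Switching `SourceIndex` to a `NamedTuple` keeps named field access while making instances lightweight, immutable, and unpackable. A sketch of what the new base class provides (field names assumed from the docstring's "image, box and channel", not shown in this hunk):

```python
from typing import NamedTuple, Optional

class SourceIndex(NamedTuple):
    image: int
    box: Optional[int]
    channel: Optional[int]

si = SourceIndex(image=3, box=None, channel=0)
image, box, channel = si             # tuple unpacking comes for free
print(si.image, si == (3, None, 0))  # prints: 3 True
```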
dataeval/typing.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 """
-Common type
+Common type protocols used for interoperability with DataEval.
 """
 
 __all__ = [
@@ -16,13 +16,14 @@ __all__ = [
     "SegmentationTarget",
     "SegmentationDatum",
     "SegmentationDataset",
+    "Transform",
 ]
 
 
 import sys
 from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
 
-from typing_extensions import NotRequired, Required
+from typing_extensions import NotRequired, ReadOnly, Required
 
 if sys.version_info >= (3, 10):
     from typing import TypeAlias
@@ -66,6 +67,7 @@ class Array(Protocol):
     def __len__(self) -> int: ...
 
 
+T = TypeVar("T")
 _T_co = TypeVar("_T_co", covariant=True)
 _ScalarType = Union[int, float, bool, str]
 ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
@@ -89,8 +91,8 @@ class DatasetMetadata(TypedDict, total=False):
         A lookup table converting label value to class name
     """
 
-    id: Required[str]
-    index2label: NotRequired[dict[int, str]]
+    id: Required[ReadOnly[str]]
+    index2label: NotRequired[ReadOnly[dict[int, str]]]
 
 
 @runtime_checkable
@@ -140,7 +142,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
 
 ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
 """
-A type definition for an image classification datum tuple.
+Type alias for an image classification datum tuple.
 
 - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
@@ -150,7 +152,7 @@ A type definition for an image classification datum tuple.
 
 ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
 """
-A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
+Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
 """
 
 # ========== OBJECT DETECTION DATASETS ==========
@@ -159,7 +161,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
 @runtime_checkable
 class ObjectDetectionTarget(Protocol):
     """
-
+    Protocol for targets in an Object Detection dataset.
 
     Attributes
     ----------
@@ -180,7 +182,7 @@ class ObjectDetectionTarget(Protocol):
 
 ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
 """
-A type definition for an object detection datum tuple.
+Type alias for an object detection datum tuple.
 
 - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`ObjectDetectionTarget` - Object detection target information for the image.
@@ -190,7 +192,7 @@ A type definition for an object detection datum tuple.
 
 ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
 """
-A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
+Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
 """
 
 
@@ -200,7 +202,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
 @runtime_checkable
 class SegmentationTarget(Protocol):
     """
-
+    Protocol for targets in a Segmentation dataset.
 
     Attributes
     ----------
@@ -221,7 +223,7 @@ class SegmentationTarget(Protocol):
 
 SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
 """
-A type definition for an image classification datum tuple.
+Type alias for an image classification datum tuple.
 
 - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
 - :class:`SegmentationTarget` - Segmentation target information for the image.
@@ -230,5 +232,34 @@ A type definition for an image classification datum tuple.
 
 SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
 """
-
+Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
 """
+
+
+@runtime_checkable
+class Transform(Generic[T], Protocol):
+    """
+    Protocol defining a transform function.
+
+    Requires a `__call__` method that returns transformed data.
+
+    Example
+    -------
+    >>> from typing import Any
+    >>> from numpy.typing import NDArray
+
+    >>> class MyTransform:
+    ...     def __init__(self, divisor: float) -> None:
+    ...         self.divisor = divisor
+    ...
+    ...     def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
+    ...         return data / self.divisor
+
+    >>> my_transform = MyTransform(divisor=255.0)
+    >>> isinstance(my_transform, Transform)
+    True
+    >>> my_transform(np.array([1, 2, 3]))
+    array([0.004, 0.008, 0.012])
+    """
+
+    def __call__(self, data: T, /) -> T: ...
```
dataeval/utils/_array.py
CHANGED
```diff
@@ -13,7 +13,7 @@ import torch
 from numpy.typing import NDArray
 
 from dataeval._log import LogMessage
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array, ArrayLike
 
 _logger = logging.getLogger(__name__)
 
@@ -167,3 +167,28 @@ def flatten(array: ArrayLike) -> NDArray[Any]:
     """
     nparr = as_numpy(array)
     return nparr.reshape((nparr.shape[0], -1))
+
+
+_TArray = TypeVar("_TArray", bound=Array)
+
+
+def channels_first_to_last(array: _TArray) -> _TArray:
+    """
+    Converts array from channels first to channels last format
+
+    Parameters
+    ----------
+    array : ArrayLike
+        Input array
+
+    Returns
+    -------
+    ArrayLike
+        Converted array
+    """
+    if isinstance(array, np.ndarray):
+        return np.transpose(array, (1, 2, 0))
+    elif isinstance(array, torch.Tensor):
+        return torch.permute(array, (1, 2, 0))
+    else:
+        raise TypeError(f"Unsupported array type {type(array)} for conversion.")
```
dataeval/utils/_mst.py
CHANGED
```diff
@@ -10,10 +10,9 @@ from scipy.sparse.csgraph import minimum_spanning_tree as mst
 from scipy.spatial.distance import pdist, squareform
 from sklearn.neighbors import NearestNeighbors
 
+from dataeval.config import EPSILON
 from dataeval.utils._array import flatten
 
-EPSILON = 1e-5
-
 
 def minimum_spanning_tree(X: NDArray[Any]) -> Any:
     """
```