dataeval 0.83.0__py3-none-any.whl → 0.84.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +3 -3
- dataeval/metrics/bias/__init__.py +11 -1
- dataeval/metrics/bias/_completeness.py +130 -0
- dataeval/metrics/stats/_base.py +26 -30
- dataeval/metrics/stats/_labelstats.py +4 -45
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_bias.py +31 -22
- dataeval/outputs/_stats.py +2 -3
- dataeval/typing.py +3 -3
- dataeval/utils/_array.py +26 -1
- dataeval/utils/data/_dataset.py +2 -0
- dataeval/utils/data/_embeddings.py +99 -21
- dataeval/utils/data/_images.py +38 -15
- dataeval/utils/data/_selection.py +3 -15
- dataeval/utils/data/_split.py +76 -129
- dataeval/utils/metadata.py +1 -1
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/METADATA +1 -1
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/RECORD +21 -20
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.83.0.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
dataeval/config.py
CHANGED
@@ -45,13 +45,13 @@ def _todevice(device: DeviceLike) -> torch.device:
     return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
 
 
-def set_device(device: DeviceLike) -> None:
+def set_device(device: DeviceLike | None) -> None:
     """
     Sets the default device to use when executing against a PyTorch backend.
 
     Parameters
     ----------
-    device : DeviceLike
+    device : DeviceLike or None
         The default device to use. See documentation for more information.
 
     See Also
@@ -59,7 +59,7 @@ def set_device(device: DeviceLike) -> None:
     `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
     """
     global _device
-    _device = _todevice(device)
+    _device = None if device is None else _todevice(device)
 
 
 def get_device(override: DeviceLike | None = None) -> torch.device:
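
With this change, passing None clears the configured default. A minimal usage sketch (it assumes get_device falls back to its own default device resolution when no default is set; that behavior is not shown in this diff):

    import torch

    from dataeval.config import get_device, set_device

    set_device("cuda")  # pin DataEval's PyTorch execution to CUDA
    assert get_device() == torch.device("cuda")

    set_device(None)  # new in 0.84.0: clears the default
    device = get_device()  # falls back to the library's default selection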
dataeval/metrics/bias/__init__.py
CHANGED
@@ -6,10 +6,12 @@ representation which may impact model performance.
 __all__ = [
     "BalanceOutput",
     "CoverageOutput",
+    "CompletenessOutput",
     "DiversityOutput",
     "LabelParityOutput",
     "ParityOutput",
     "balance",
+    "completeness",
     "coverage",
     "diversity",
     "label_parity",
@@ -17,7 +19,15 @@ __all__ = [
 ]
 
 from dataeval.metrics.bias._balance import balance
+from dataeval.metrics.bias._completeness import completeness
 from dataeval.metrics.bias._coverage import coverage
 from dataeval.metrics.bias._diversity import diversity
 from dataeval.metrics.bias._parity import label_parity, parity
-from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+from dataeval.outputs._bias import (
+    BalanceOutput,
+    CompletenessOutput,
+    CoverageOutput,
+    DiversityOutput,
+    LabelParityOutput,
+    ParityOutput,
+)
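
The net effect on the public surface, sketched from the new __all__ above:

    # completeness and its output type are now importable from the bias namespace
    from dataeval.metrics.bias import CompletenessOutput, completeness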
dataeval/metrics/bias/_completeness.py
ADDED
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import itertools
+
+__all__ = []
+
+
+import numpy as np
+
+from dataeval.config import EPSILON
+from dataeval.outputs import CompletenessOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import ensure_embeddings
+
+
+def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
+    """
+    Calculate the fraction of boxes in a grid defined by quantiles that
+    contain at least one data point.
+    Also returns the center coordinates of each empty box.
+
+    Parameters
+    ----------
+    embeddings : ArrayLike
+        Embedded dataset (or other low-dimensional data) (nxp)
+    quantiles : int
+        Number of quantile values to use for partitioning each dimension,
+        e.g., 1 creates a grid of 2^p boxes, 2 creates 3^p boxes, etc.
+
+    Returns
+    -------
+    CompletenessOutput
+        - fraction_filled: float - Fraction of boxes that contain at least one
+          data point
+        - empty_box_centers: List[np.ndarray] - List of coordinates for centers of empty
+          boxes
+
+    Raises
+    ------
+    ValueError
+        If embeddings are too high-dimensional (>10)
+    ValueError
+        If there are too many quantiles (>2)
+    ValueError
+        If embedding is invalid shape
+
+    Example
+    -------
+    >>> embs = np.array([[1, 0], [0, 1], [1, 1]])
+    >>> quantiles = 1
+    >>> result = completeness(embs, quantiles)
+    >>> result.fraction_filled
+    0.75
+
+    Reference
+    ---------
+    This implementation is based on https://arxiv.org/abs/2002.03147.
+
+    [1] Byun, Taejoon, and Sanjai Rayadurgam. "Manifold for Machine Learning Assurance."
+    Proceedings of the ACM/IEEE 42nd International Conference on Software Engineering
+    """
+    # Ensure proper data format
+    embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=False)
+
+    # Get data dimensions
+    n, p = embeddings.shape
+    if quantiles > 2 or quantiles <= 0:
+        raise ValueError(
+            f"Number of quantiles ({quantiles}) is greater than 2 or is nonpositive. \
+            The metric scales exponentially in this value. Please use 1 or 2 quantiles."
+        )
+    if p > 10:
+        raise ValueError(
+            f"Dimension of embeddings ({p}) is greater than 10. \
+            The metric scales exponentially in this value. Please reduce the embedding dimension."
+        )
+    if n == 0 or p == 0:
+        raise ValueError("Your provided embeddings do not contain any data!")
+    # quantiles + 2 edges partition each embedding dimension (e.g. [0, 0.5, 1] for quantiles = 1)
+    quantile_vec = np.linspace(0, 1, quantiles + 2)
+
+    # Calculate the bin edges for each dimension based on quantiles
+    bin_edges = []
+    for dim in range(p):
+        # Calculate the quantile values for this feature
+        edges = np.array(np.quantile(embeddings[:, dim], quantile_vec))
+        # Make sure the last bin contains all the remaining points
+        edges[-1] += EPSILON
+        bin_edges.append(edges)
+    # Convert each data point into its corresponding grid cell indices
+    grid_indices = []
+    for dim in range(p):
+        # For each dimension, find which bin each data point belongs to
+        # Digitize is 1 indexed so we subtract 1
+        indices = np.digitize(embeddings[:, dim], bin_edges[dim]) - 1
+        grid_indices.append(indices)
+
+    # Make the rows the data point and the column the grid index
+    grid_coords = np.array(grid_indices).T
+
+    # Use a set to find unique tuples of grid coordinates
+    occupied_cells = set(map(tuple, grid_coords))
+
+    # For the fraction
+    num_occupied_cells = len(occupied_cells)
+
+    # Calculate total possible cells in the grid
+    num_bins_per_dim = [len(edges) - 1 for edges in bin_edges]
+    total_possible_cells = np.prod(num_bins_per_dim)
+
+    # Generate all possible grid cells
+    all_cells = set(itertools.product(*[range(bins) for bins in num_bins_per_dim]))
+
+    # Find the empty cells (cells with no data points)
+    empty_cells = all_cells - occupied_cells
+
+    # Calculate center points of empty boxes
+    empty_box_centers = []
+    for cell in empty_cells:
+        center_coords = []
+        for dim, idx in enumerate(cell):
+            # Calculate center of the bin as midpoint between edges
+            center = (bin_edges[dim][idx] + bin_edges[dim][idx + 1]) / 2
+            center_coords.append(center)
+        empty_box_centers.append(np.array(center_coords))
+
+    # Calculate the fraction
+    fraction = float(num_occupied_cells / total_possible_cells)
+    empty_box_centers = np.array(empty_box_centers)
+    return CompletenessOutput(fraction, empty_box_centers)
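
Walking through the docstring example (the arithmetic here is ours, not part of the diff): with quantiles=1 the quantile vector is [0, 0.5, 1], so each of the p=2 dimensions gets 2 bins and the grid has 2^2 = 4 cells; the three points land in three distinct cells, giving fraction_filled = 3/4:

    import numpy as np

    from dataeval.metrics.bias import completeness

    embs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 points in 2-D
    result = completeness(embs, quantiles=1)  # 2 bins per dimension -> 4 cells
    print(result.fraction_filled)  # 0.75
    print(result.empty_box_centers)  # center of the single empty cell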
dataeval/metrics/stats/_base.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 __all__ = []
 
+import math
 import re
 import warnings
 from collections import ChainMap
@@ -9,7 +10,7 @@ from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
 from multiprocessing import Pool
-from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
+from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
 
 import numpy as np
 import tqdm
@@ -23,20 +24,7 @@ from dataeval.utils._image import normalize_image_shape, rescale
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
-
-def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
-    """
-    Normalizes the bounding box shape into (N,4).
-    """
-    ndim = bounding_box.ndim
-    if ndim == 1:
-        return np.expand_dims(bounding_box, axis=0)
-    elif ndim > 2:
-        raise ValueError("Bounding boxes must have 2 dimensions: (# of boxes in an image, [X,Y,W,H]) -> (N,4)")
-    else:
-        return bounding_box
-
-
+BoundingBox = tuple[float, float, float, float]
 TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
 
@@ -46,11 +34,15 @@ class StatsProcessor(Generic[TStatsOutput]):
     image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
     channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
 
-    def __init__(self, image: NDArray[Any], box:
+    def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
         self.raw = image
         self.width: int = image.shape[-1]
         self.height: int = image.shape[-2]
-
+        box = BoundingBox((0, 0, self.width, self.height)) if box is None else box
+        # Clip the bounding box to image
+        x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
+        x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
+        self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
         self._per_channel = per_channel
         self._image = None
         self._shape = None
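
The clipping generators above are dense; this standalone sketch (a hypothetical clip_box helper, not part of the package) spells out the same arithmetic for snapping a float (x0, y0, x1, y1) box to integer pixel bounds:

    import math

    def clip_box(box: tuple[float, float, float, float], width: int, height: int) -> tuple[int, int, int, int]:
        # Upper-left corner: floor, clamped to [0, dimension - 1]
        x0 = min(width - 1, max(0, math.floor(box[0])))
        y0 = min(height - 1, max(0, math.floor(box[1])))
        # Lower-right corner: ceil, clamped to [1, dimension]
        x1 = min(width, max(1, math.ceil(box[2])))
        y1 = min(height, max(1, math.ceil(box[3])))
        return x0, y0, x1, y1

    print(clip_box((-3.2, 1.5, 99.7, 120.0), width=64, height=64))  # (0, 1, 64, 64)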
@@ -123,18 +115,16 @@ class StatsProcessorOutput:
 def process_stats(
     i: int,
     image: ArrayLike,
-    target: Any,
-    per_box: bool,
+    boxes: list[BoundingBox] | None,
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
     image = to_numpy(image)
-    boxes = to_numpy(target.boxes) if isinstance(target, ObjectDetectionTarget) else None
     results_list: list[dict[str, Any]] = []
     source_indices: list[SourceIndex] = []
     box_counts: list[int] = []
     warnings_list: list[str] = []
-    for i_b, box in [(None, None)] if boxes is None else enumerate(
+    for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
         processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
         if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
             warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
@@ -148,12 +138,24 @@ def process_stats(
 
 
 def process_stats_unpack(
-    args: tuple[int,
-    per_box: bool,
+    args: tuple[int, Array, list[BoundingBox] | None],
     per_channel: bool,
     stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
 ) -> StatsProcessorOutput:
-    return process_stats(*args,
+    return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+
+
+def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
+    for i in range(len(dataset)):
+        d = dataset[i]
+        image = d[0] if isinstance(d, tuple) else d
+        if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
+            boxes = cast(Array, d[1].boxes)
+            target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
+        else:
+            target = None
+
+        yield i, image, target
 
 
 def run_stats(
@@ -202,17 +204,11 @@ def run_stats(
     warning_list = []
     stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
 
-    def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
-        for i in range(len(dataset)):
-            d = dataset[i]
-            yield i, d[0] if isinstance(d, tuple) else d, d[1] if isinstance(d, tuple) and per_box else None
-
     with Pool(processes=get_max_processes()) as p:
         for r in tqdm.tqdm(
             p.imap(
                 partial(
                     process_stats_unpack,
-                    per_box=per_box,
                     per_channel=per_channel,
                     stats_processor_cls=stats_processor_cls,
                 ),
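
To trace the relocated _enumerate contract end to end, a self-contained mock (FakeTarget is our stand-in, and we substitute hasattr for the real code's isinstance(d[1], ObjectDetectionTarget) check so the sketch runs without dataeval types):

    from dataclasses import dataclass

    @dataclass
    class FakeTarget:
        boxes: list[list[float]]  # rows of (x0, y0, x1, y1)

    def _enumerate(dataset, per_box):
        for i in range(len(dataset)):
            d = dataset[i]
            image = d[0] if isinstance(d, tuple) else d
            if per_box and isinstance(d, tuple) and hasattr(d[1], "boxes"):
                target = [tuple(float(v) for v in box) for box in d[1].boxes]
            else:
                target = None
            yield i, image, target

    dataset = [(object(), FakeTarget([[0.0, 0.0, 4.0, 4.0]]), {})]
    for i, image, boxes in _enumerate(dataset, per_box=True):
        print(i, boxes)  # 0 [(0.0, 0.0, 4.0, 4.0)]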
dataeval/metrics/stats/_labelstats.py
CHANGED
@@ -5,54 +5,16 @@ __all__ = []
 from collections import Counter, defaultdict
 from typing import Any, Mapping, TypeVar
 
-import numpy as np
-
 from dataeval.outputs import LabelStatsOutput
 from dataeval.outputs._base import set_metadata
-from dataeval.typing import AnnotatedDataset, ArrayLike
-from dataeval.utils._array import as_numpy
+from dataeval.typing import AnnotatedDataset
 from dataeval.utils.data._metadata import Metadata
 
 TValue = TypeVar("TValue")
 
 
-def _ensure_2d(labels: ArrayLike) -> ArrayLike:
-    if isinstance(labels, np.ndarray):
-        return labels[:, None]
-    else:
-        return [[lbl] for lbl in labels]  # type: ignore
-
-
-def _get_list_depth(lst):
-    if isinstance(lst, list) and lst:
-        return 1 + max(_get_list_depth(item) for item in lst)
-    return 0
-
-
-def _check_labels_dimension(labels: ArrayLike) -> ArrayLike:
-    # Check for nested lists beyond 2 levels
-
-    if isinstance(labels, np.ndarray):
-        if labels.ndim == 1:
-            return _ensure_2d(labels)
-        elif labels.ndim == 2:
-            return labels
-        else:
-            raise ValueError("The label array must not have more than 2 dimensions.")
-    elif isinstance(labels, list):
-        depth = _get_list_depth(labels)
-        if depth == 1:
-            return _ensure_2d(labels)
-        elif depth == 2:
-            return labels
-        else:
-            raise ValueError("The label list must not be empty or have more than 2 levels of nesting.")
-    else:
-        raise TypeError("Labels must be either a NumPy array or a list.")
-
-
 def _sort_to_list(d: Mapping[int, TValue]) -> list[TValue]:
-    return [
+    return [t[1] for t in sorted(d.items())]
 
 
 @set_metadata
@@ -98,12 +60,9 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
     label_per_image: list[int] = []
 
     index2label = dict(enumerate(dataset.class_names))
-    labels = [target.labels.tolist() for target in dataset.targets]
-
-    labels_2d = _check_labels_dimension(labels)
 
-    for i,
-    group =
+    for i, target in enumerate(dataset.targets):
+        group = target.labels.tolist()
 
         # Count occurrences of each label in all sublists
         label_counts.update(group)
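
The rewritten _sort_to_list is small enough to demo inline; it returns values ordered by their integer keys:

    def _sort_to_list(d):
        return [t[1] for t in sorted(d.items())]

    print(_sort_to_list({2: "cat", 0: "dog", 1: "bird"}))  # ['dog', 'bird', 'cat']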
dataeval/outputs/__init__.py
CHANGED
@@ -4,7 +4,7 @@ as well as runtime metadata for reproducibility and logging.
 """
 
 from ._base import ExecutionMetadata
-from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+from ._bias import BalanceOutput, CompletenessOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
 from ._drift import DriftMMDOutput, DriftOutput
 from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
 from ._linters import DuplicatesOutput, OutliersOutput
@@ -29,6 +29,7 @@ __all__ = [
     "ChannelStatsOutput",
     "ClustererOutput",
     "CoverageOutput",
+    "CompletenessOutput",
     "DimensionStatsOutput",
     "DivergenceOutput",
     "DiversityOutput",
dataeval/outputs/_bias.py
CHANGED
@@ -14,9 +14,10 @@ with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
 
 from dataeval.outputs._base import Output
-from dataeval.typing import ArrayLike
-from dataeval.utils._array import
+from dataeval.typing import ArrayLike, Dataset
+from dataeval.utils._array import as_numpy, channels_first_to_last
 from dataeval.utils._plot import heatmap
+from dataeval.utils.data._images import Images
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -107,13 +108,13 @@ class CoverageOutput(Output):
     critical_value_radii: NDArray[np.float64]
     coverage_radius: float
 
-    def plot(self, images:
+    def plot(self, images: Images[Any] | Dataset[Any], top_k: int = 6) -> Figure:
         """
         Plot the top k images together for visualization.
 
         Parameters
         ----------
-        images :
+        images : Images or Dataset
             Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
         top_k : int, default 6
             Number of images to plot (plotting assumes groups of 3)
@@ -130,46 +131,54 @@ class CoverageOutput(Output):
         import matplotlib.pyplot as plt
 
         # Determine which images to plot
-        highest_uncovered_indices = self.uncovered_indices[:top_k]
+        selected_indices = self.uncovered_indices[:top_k]
 
-
-        selected_images = to_numpy(images)[highest_uncovered_indices]
+        images = Images(images) if isinstance(images, Dataset) else images
 
         # Plot the images
-        num_images = min(top_k, len(
-
-        ndim = selected_images.ndim
-        if ndim == 4:
-            selected_images = np.moveaxis(selected_images, 1, -1)
-        elif ndim == 3:
-            selected_images = np.repeat(selected_images[:, :, :, np.newaxis], 3, axis=-1)
-        else:
-            raise ValueError(
-                f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {ndim}-dimensional set of images."
-            )
+        num_images = min(top_k, len(selected_indices))
 
         rows = int(np.ceil(num_images / 3))
         fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
 
         if rows == 1:
             for j in range(3):
-                if j >= len(
+                if j >= len(selected_indices):
                     continue
-
+                image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
+                axs[j].imshow(image)
                 axs[j].axis("off")
         else:
             for i in range(rows):
                 for j in range(3):
                     i_j = i * 3 + j
-                    if i_j >= len(
+                    if i_j >= len(selected_indices):
                         continue
-
+                    image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
+                    axs[i, j].imshow(image)
                     axs[i, j].axis("off")
 
         fig.tight_layout()
         return fig
 
 
+@dataclass(frozen=True)
+class CompletenessOutput(Output):
+    """
+    Output from the completeness function.
+
+    Attributes
+    ----------
+    fraction_filled : float
+        Fraction of boxes that contain at least one data point
+    empty_box_centers : List[np.ndarray]
+        List of coordinates for centers of empty boxes
+    """
+
+    fraction_filled: float
+    empty_box_centers: NDArray[np.float64]
+
+
 @dataclass(frozen=True)
 class BalanceOutput(Output):
     """
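
A hedged sketch of the reworked plot API (embeddings and dataset are placeholders defined elsewhere; consult the DataEval docs for coverage's exact call signature, which this diff does not show):

    from dataeval.metrics.bias import coverage

    result = coverage(embeddings)  # CoverageOutput with uncovered_indices
    fig = result.plot(dataset, top_k=6)  # 0.84.0: pass the Dataset directly;
    # it is wrapped in Images and each image is fetched lazily by index
    fig.savefig("uncovered.png")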
dataeval/outputs/_stats.py
CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
 import contextlib
 from dataclasses import dataclass
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Iterable, NamedTuple, Optional, Union
 
 import numpy as np
 from numpy.typing import NDArray
@@ -22,8 +22,7 @@ SOURCE_INDEX = "source_index"
 BOX_COUNT = "box_count"
 
 
-
-class SourceIndex:
+class SourceIndex(NamedTuple):
     """
     The indices of the source image, box and channel.
 
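
As a NamedTuple, SourceIndex keeps attribute access and gains tuple behavior; a sketch assuming the fields implied by the docstring (image, box, channel), which the diff does not list:

    from dataeval.outputs._stats import SourceIndex

    idx = SourceIndex(0, None, 2)  # assumed field order: (image, box, channel)
    image, box, channel = idx  # tuple unpacking now works
    assert idx[2] == idx.channel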
dataeval/typing.py
CHANGED
@@ -23,7 +23,7 @@ __all__ = [
 import sys
 from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
 
-from typing_extensions import NotRequired, Required
+from typing_extensions import NotRequired, ReadOnly, Required
 
 if sys.version_info >= (3, 10):
     from typing import TypeAlias
@@ -91,8 +91,8 @@ class DatasetMetadata(TypedDict, total=False):
         A lookup table converting label value to class name
     """
 
-    id: Required[str]
-    index2label: NotRequired[dict[int, str]]
+    id: Required[ReadOnly[str]]
+    index2label: NotRequired[ReadOnly[dict[int, str]]]
 
 
 @runtime_checkable
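
What the ReadOnly (PEP 705) wrapper expresses, sketched (the "mnist-train" metadata values are our illustration): nothing changes at runtime, but a PEP 705-aware type checker now flags mutation of these keys:

    from dataeval.typing import DatasetMetadata

    metadata: DatasetMetadata = {"id": "mnist-train", "index2label": {0: "zero", 1: "one"}}
    metadata["id"] = "renamed"  # type checker error: "id" is read-only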
dataeval/utils/_array.py
CHANGED
@@ -13,7 +13,7 @@ import torch
 from numpy.typing import NDArray
 
 from dataeval._log import LogMessage
-from dataeval.typing import ArrayLike
+from dataeval.typing import Array, ArrayLike
 
 _logger = logging.getLogger(__name__)
 
@@ -167,3 +167,28 @@ def flatten(array: ArrayLike) -> NDArray[Any]:
     """
     nparr = as_numpy(array)
     return nparr.reshape((nparr.shape[0], -1))
+
+
+_TArray = TypeVar("_TArray", bound=Array)
+
+
+def channels_first_to_last(array: _TArray) -> _TArray:
+    """
+    Converts array from channels first to channels last format
+
+    Parameters
+    ----------
+    array : ArrayLike
+        Input array
+
+    Returns
+    -------
+    ArrayLike
+        Converted array
+    """
+    if isinstance(array, np.ndarray):
+        return np.transpose(array, (1, 2, 0))
+    elif isinstance(array, torch.Tensor):
+        return torch.permute(array, (1, 2, 0))
+    else:
+        raise TypeError(f"Unsupported array type {type(array)} for conversion.")
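
A quick check of the new helper on both supported array types:

    import numpy as np
    import torch

    from dataeval.utils._array import channels_first_to_last

    print(channels_first_to_last(np.zeros((3, 32, 32))).shape)  # (32, 32, 3)
    print(channels_first_to_last(torch.zeros(3, 32, 32)).shape)  # torch.Size([32, 32, 3])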
dataeval/utils/data/_dataset.py
CHANGED
@@ -47,6 +47,8 @@ def _validate_data(
             or not len(bboxes[0][0]) == 4
         ):
             raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+    else:
+        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
 
 
 def _find_max(arr: ArrayLike) -> Any:
|