dataeval 0.83.0__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +3 -3
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +55 -203
  5. dataeval/detectors/drift/_cvm.py +19 -30
  6. dataeval/detectors/drift/_ks.py +18 -30
  7. dataeval/detectors/drift/_mmd.py +189 -53
  8. dataeval/detectors/drift/_uncertainty.py +52 -56
  9. dataeval/detectors/drift/updates.py +13 -12
  10. dataeval/detectors/linters/duplicates.py +5 -3
  11. dataeval/detectors/linters/outliers.py +2 -2
  12. dataeval/detectors/ood/ae.py +1 -1
  13. dataeval/metrics/bias/__init__.py +11 -1
  14. dataeval/metrics/bias/_completeness.py +130 -0
  15. dataeval/metrics/stats/_base.py +28 -32
  16. dataeval/metrics/stats/_dimensionstats.py +2 -2
  17. dataeval/metrics/stats/_hashstats.py +2 -2
  18. dataeval/metrics/stats/_imagestats.py +4 -4
  19. dataeval/metrics/stats/_labelstats.py +4 -45
  20. dataeval/metrics/stats/_pixelstats.py +2 -2
  21. dataeval/metrics/stats/_visualstats.py +2 -2
  22. dataeval/outputs/__init__.py +2 -1
  23. dataeval/outputs/_bias.py +31 -22
  24. dataeval/outputs/_stats.py +2 -3
  25. dataeval/typing.py +25 -22
  26. dataeval/utils/_array.py +43 -7
  27. dataeval/utils/data/_dataset.py +8 -4
  28. dataeval/utils/data/_embeddings.py +141 -24
  29. dataeval/utils/data/_images.py +38 -15
  30. dataeval/utils/data/_metadata.py +5 -4
  31. dataeval/utils/data/_selection.py +3 -15
  32. dataeval/utils/data/_split.py +76 -129
  33. dataeval/utils/data/datasets/_base.py +7 -4
  34. dataeval/utils/data/datasets/_cifar10.py +9 -9
  35. dataeval/utils/data/datasets/_milco.py +42 -14
  36. dataeval/utils/data/datasets/_mnist.py +9 -5
  37. dataeval/utils/data/datasets/_ships.py +8 -4
  38. dataeval/utils/data/datasets/_voc.py +40 -19
  39. dataeval/utils/data/selections/__init__.py +2 -0
  40. dataeval/utils/data/selections/_classbalance.py +38 -0
  41. dataeval/utils/data/selections/_classfilter.py +14 -29
  42. dataeval/utils/data/selections/_prioritize.py +1 -1
  43. dataeval/utils/data/selections/_shuffle.py +2 -2
  44. dataeval/utils/metadata.py +1 -1
  45. dataeval/utils/torch/_internal.py +12 -35
  46. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
  47. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +49 -48
  48. dataeval/detectors/drift/_torch.py +0 -222
  49. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
  50. {dataeval-0.83.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
dataeval/metrics/bias/_completeness.py ADDED
@@ -0,0 +1,130 @@
+ from __future__ import annotations
+
+ import itertools
+
+ __all__ = []
+
+
+ import numpy as np
+
+ from dataeval.config import EPSILON
+ from dataeval.outputs import CompletenessOutput
+ from dataeval.typing import ArrayLike
+ from dataeval.utils._array import ensure_embeddings
+
+
+ def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
+     """
+     Calculate the fraction of boxes in a grid defined by quantiles that
+     contain at least one data point.
+     Also returns the center coordinates of each empty box.
+
+     Parameters
+     ----------
+     embeddings : ArrayLike
+         Embedded dataset (or other low-dimensional data) (n x p)
+     quantiles : int
+         Number of quantile values to use for partitioning each dimension,
+         e.g., quantiles=1 creates a grid of 2^p boxes, quantiles=2 creates 3^p boxes, etc.
+
+     Returns
+     -------
+     CompletenessOutput
+         - fraction_filled: float - Fraction of boxes that contain at least one
+           data point
+         - empty_box_centers: List[np.ndarray] - List of coordinates for centers of empty
+           boxes
+
+     Raises
+     ------
+     ValueError
+         If embeddings are too high-dimensional (>10)
+     ValueError
+         If there are too many quantiles (>2)
+     ValueError
+         If embedding is invalid shape
+
+     Example
+     -------
+     >>> embs = np.array([[1, 0], [0, 1], [1, 1]])
+     >>> quantiles = 1
+     >>> result = completeness(embs, quantiles)
+     >>> result.fraction_filled
+     0.75
+
+     Reference
+     ---------
+     This implementation is based on https://arxiv.org/abs/2002.03147.
+
+     [1] Byun, Taejoon, and Sanjai Rayadurgam. “Manifold for Machine Learning Assurance.”
+     Proceedings of the ACM/IEEE 42nd International Conference on Software Engineering.
+     """
+     # Ensure proper data format
+     embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=False)
+
+     # Get data dimensions
+     n, p = embeddings.shape
+     if quantiles > 2 or quantiles <= 0:
+         raise ValueError(
+             f"Number of quantiles ({quantiles}) is greater than 2 or is nonpositive. \
+             The metric scales exponentially in this value. Please use 1 or 2 quantiles."
+         )
+     if p > 10:
+         raise ValueError(
+             f"Dimension of embeddings ({p}) is greater than 10. \
+             The metric scales exponentially in this value. Please reduce the embedding dimension."
+         )
+     if n == 0 or p == 0:
+         raise ValueError("Your provided embeddings do not contain any data!")
+     # quantiles + 2 edges partition each embedding dimension (e.g. [0, 0.5, 1] for quantiles = 1)
+     quantile_vec = np.linspace(0, 1, quantiles + 2)
+
+     # Calculate the bin edges for each dimension based on quantiles
+     bin_edges = []
+     for dim in range(p):
+         # Calculate the quantile values for this feature
+         edges = np.array(np.quantile(embeddings[:, dim], quantile_vec))
+         # Make sure the last bin contains all the remaining points
+         edges[-1] += EPSILON
+         bin_edges.append(edges)
+     # Convert each data point into its corresponding grid cell indices
+     grid_indices = []
+     for dim in range(p):
+         # For each dimension, find which bin each data point belongs to
+         # Digitize is 1 indexed so we subtract 1
+         indices = np.digitize(embeddings[:, dim], bin_edges[dim]) - 1
+         grid_indices.append(indices)
+
+     # Make the rows the data point and the column the grid index
+     grid_coords = np.array(grid_indices).T
+
+     # Use set to find unique tuples of grid coordinates
+     occupied_cells = set(map(tuple, grid_coords))
+
+     # For the fraction
+     num_occupied_cells = len(occupied_cells)
+
+     # Calculate total possible cells in the grid
+     num_bins_per_dim = [len(edges) - 1 for edges in bin_edges]
+     total_possible_cells = np.prod(num_bins_per_dim)
+
+     # Generate all possible grid cells
+     all_cells = set(itertools.product(*[range(bins) for bins in num_bins_per_dim]))
+
+     # Find the empty cells (cells with no data points)
+     empty_cells = all_cells - occupied_cells
+
+     # Calculate center points of empty boxes
+     empty_box_centers = []
+     for cell in empty_cells:
+         center_coords = []
+         for dim, idx in enumerate(cell):
+             # Calculate center of the bin as midpoint between edges
+             center = (bin_edges[dim][idx] + bin_edges[dim][idx + 1]) / 2
+             center_coords.append(center)
+         empty_box_centers.append(np.array(center_coords))
+
+     # Calculate the fraction
+     fraction = float(num_occupied_cells / total_possible_cells)
+     empty_box_centers = np.array(empty_box_centers)
+     return CompletenessOutput(fraction, empty_box_centers)
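For orientation, here is a minimal usage sketch of the new metric, mirroring the doctest above; it assumes completeness is re-exported from dataeval.metrics.bias, which the __init__.py change listed in this release suggests:

    import numpy as np

    from dataeval.metrics.bias import completeness  # assumed public export

    # Three 2-D embeddings; quantiles=1 splits each dimension into two bins,
    # and 3 of the 4 resulting grid cells are occupied
    embs = np.array([[1, 0], [0, 1], [1, 1]])
    result = completeness(embs, quantiles=1)
    print(result.fraction_filled)    # 0.75
    print(result.empty_box_centers)  # centers of the unoccupied cells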
dataeval/metrics/stats/_base.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
  __all__ = []
 
+ import math
  import re
  import warnings
  from collections import ChainMap
@@ -18,25 +19,12 @@ from numpy.typing import NDArray
  from dataeval.config import get_max_processes
  from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
  from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
- from dataeval.utils._array import to_numpy
+ from dataeval.utils._array import as_numpy, to_numpy
  from dataeval.utils._image import normalize_image_shape, rescale
 
  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
-
- def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
-     """
-     Normalizes the bounding box shape into (N,4).
-     """
-     ndim = bounding_box.ndim
-     if ndim == 1:
-         return np.expand_dims(bounding_box, axis=0)
-     elif ndim > 2:
-         raise ValueError("Bounding boxes must have 2 dimensions: (# of boxes in an image, [X,Y,W,H]) -> (N,4)")
-     else:
-         return bounding_box
-
-
+ BoundingBox = tuple[float, float, float, float]
  TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
 
@@ -46,11 +34,15 @@ class StatsProcessor(Generic[TStatsOutput]):
      image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
      channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
 
-     def __init__(self, image: NDArray[Any], box: NDArray[Any] | None, per_channel: bool) -> None:
+     def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
          self.raw = image
          self.width: int = image.shape[-1]
          self.height: int = image.shape[-2]
-         self.box: NDArray[np.int64] = np.array([0, 0, self.width, self.height]) if box is None else box.astype(np.int64)
+         box = BoundingBox((0, 0, self.width, self.height)) if box is None else box
+         # Clip the bounding box to the image
+         x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
+         x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
+         self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
          self._per_channel = per_channel
          self._image = None
          self._shape = None
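The rewritten __init__ accepts float (x0, y0, x1, y1) tuples and snaps them to valid integer pixel bounds instead of trusting the caller. A standalone sketch of the clipping arithmetic from the diff (not the library API):

    import math

    def clip_box(box, width, height):
        # Floor the top-left corner and clamp it inside the image
        x0 = min(width - 1, max(0, math.floor(box[0])))
        y0 = min(height - 1, max(0, math.floor(box[1])))
        # Ceil the bottom-right corner and clamp it to the image extent
        x1 = min(width, max(1, math.ceil(box[2])))
        y1 = min(height, max(1, math.ceil(box[3])))
        return x0, y0, x1, y1

    # A fractional box overhanging a 10x10 image is snapped to valid bounds
    assert clip_box((-2.5, 3.2, 12.7, 9.9), 10, 10) == (0, 3, 10, 10)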
@@ -123,18 +115,16 @@ class StatsProcessorOutput:
  def process_stats(
      i: int,
      image: ArrayLike,
-     target: Any,
-     per_box: bool,
+     boxes: list[BoundingBox] | None,
      per_channel: bool,
      stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
  ) -> StatsProcessorOutput:
      image = to_numpy(image)
-     boxes = to_numpy(target.boxes) if isinstance(target, ObjectDetectionTarget) else None
      results_list: list[dict[str, Any]] = []
      source_indices: list[SourceIndex] = []
      box_counts: list[int] = []
      warnings_list: list[str] = []
-     for i_b, box in [(None, None)] if boxes is None else enumerate(normalize_box_shape(boxes)):
+     for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
          processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
          if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
              warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
@@ -148,16 +138,28 @@ def process_stats(
 
 
  def process_stats_unpack(
-     args: tuple[int, ArrayLike, Any],
-     per_box: bool,
+     args: tuple[int, ArrayLike, list[BoundingBox] | None],
      per_channel: bool,
      stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
  ) -> StatsProcessorOutput:
-     return process_stats(*args, per_box=per_box, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+     return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+
+
+ def _enumerate(dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]], per_box: bool):
+     for i in range(len(dataset)):
+         d = dataset[i]
+         image = d[0] if isinstance(d, tuple) else d
+         if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
+             boxes = d[1].boxes if isinstance(d[1].boxes, Array) else as_numpy(d[1].boxes)
+             target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
+         else:
+             target = None
+
+         yield i, image, target
 
 
  def run_stats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      per_box: bool,
      per_channel: bool,
      stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
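With this refactor the per-box plumbing happens in the parent process: _enumerate converts each ObjectDetectionTarget into plain float 4-tuples before dispatch, so worker processes receive (index, image, boxes-or-None) rather than target objects. A simplified sketch of that contract, using hypothetical in-memory stand-ins rather than dataeval's internal types:

    # One 4x4 single-channel "image" with two boxes (hypothetical data)
    images = [[[0.0] * 4 for _ in range(4)]]
    boxes_per_image = [[(0.0, 0.0, 2.0, 2.0), (1.0, 1.0, 3.0, 3.0)]]

    def enumerate_with_boxes(images, boxes_per_image, per_box):
        for i, image in enumerate(images):
            # Boxes are already plain float 4-tuples when workers see them
            yield i, image, boxes_per_image[i] if per_box else None

    for i, image, boxes in enumerate_with_boxes(images, boxes_per_image, per_box=True):
        print(i, boxes)  # 0 [(0.0, 0.0, 2.0, 2.0), (1.0, 1.0, 3.0, 3.0)]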
@@ -171,7 +173,7 @@ def run_stats(
 
      Parameters
      ----------
-     data : Dataset[Array] | Dataset[tuple[Array, Any, Any]]
+     data : Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
          A dataset of images and targets to compute statistics on.
      per_box : bool
          A flag which determines if the statistics should be evaluated on a per-box basis or not.
@@ -202,17 +204,11 @@
      warning_list = []
      stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
 
-     def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
-         for i in range(len(dataset)):
-             d = dataset[i]
-             yield i, d[0] if isinstance(d, tuple) else d, d[1] if isinstance(d, tuple) and per_box else None
-
      with Pool(processes=get_max_processes()) as p:
          for r in tqdm.tqdm(
              p.imap(
                  partial(
                      process_stats_unpack,
-                     per_box=per_box,
                      per_channel=per_channel,
                      stats_processor_cls=stats_processor_cls,
                  ),
dataeval/metrics/stats/_dimensionstats.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
  from dataeval.outputs import DimensionStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import Array, Dataset
+ from dataeval.typing import ArrayLike, Dataset
  from dataeval.utils._image import get_bitdepth
 
 
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
 
  @set_metadata
  def dimensionstats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
  ) -> DimensionStatsOutput:
dataeval/metrics/stats/_hashstats.py CHANGED
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
  from dataeval.outputs import HashStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import Array, ArrayLike, Dataset
+ from dataeval.typing import ArrayLike, Dataset
  from dataeval.utils._array import as_numpy
  from dataeval.utils._image import normalize_image_shape, rescale
 
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
 
  @set_metadata
  def hashstats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
  ) -> HashStatsOutput:
dataeval/metrics/stats/_imagestats.py CHANGED
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
  from dataeval.metrics.stats._visualstats import VisualStatsProcessor
  from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import Array, Dataset
+ from dataeval.typing import ArrayLike, Dataset
 
 
  @overload
  def imagestats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
      per_channel: Literal[True],
@@ -24,7 +24,7 @@ def imagestats(
 
  @overload
  def imagestats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
      per_channel: Literal[False] = False,
@@ -33,7 +33,7 @@ def imagestats(
 
  @set_metadata
  def imagestats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
      per_channel: bool = False,
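With the signatures loosened from Dataset[Array] to Dataset[ArrayLike], anything indexable with a length that yields array-likes should now type-check. A hedged sketch (it assumes dataeval is installed, that imagestats is exported from dataeval.metrics.stats, and that a plain list satisfies the Dataset protocol):

    import numpy as np

    from dataeval.metrics.stats import imagestats  # assumed public export

    # Eight random (C, H, W) images standing in for a real dataset
    images = [np.random.random((3, 16, 16)) for _ in range(8)]
    stats = imagestats(images, per_channel=True)  # ChannelStatsOutput per the overloads above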
dataeval/metrics/stats/_labelstats.py CHANGED
@@ -5,54 +5,16 @@ __all__ = []
  from collections import Counter, defaultdict
  from typing import Any, Mapping, TypeVar
 
- import numpy as np
-
  from dataeval.outputs import LabelStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import AnnotatedDataset, ArrayLike
- from dataeval.utils._array import as_numpy
+ from dataeval.typing import AnnotatedDataset
  from dataeval.utils.data._metadata import Metadata
 
  TValue = TypeVar("TValue")
 
 
- def _ensure_2d(labels: ArrayLike) -> ArrayLike:
-     if isinstance(labels, np.ndarray):
-         return labels[:, None]
-     else:
-         return [[lbl] for lbl in labels]  # type: ignore
-
-
- def _get_list_depth(lst):
-     if isinstance(lst, list) and lst:
-         return 1 + max(_get_list_depth(item) for item in lst)
-     return 0
-
-
- def _check_labels_dimension(labels: ArrayLike) -> ArrayLike:
-     # Check for nested lists beyond 2 levels
-
-     if isinstance(labels, np.ndarray):
-         if labels.ndim == 1:
-             return _ensure_2d(labels)
-         elif labels.ndim == 2:
-             return labels
-         else:
-             raise ValueError("The label array must not have more than 2 dimensions.")
-     elif isinstance(labels, list):
-         depth = _get_list_depth(labels)
-         if depth == 1:
-             return _ensure_2d(labels)
-         elif depth == 2:
-             return labels
-         else:
-             raise ValueError("The label list must not be empty or have more than 2 levels of nesting.")
-     else:
-         raise TypeError("Labels must be either a NumPy array or a list.")
-
-
  def _sort_to_list(d: Mapping[int, TValue]) -> list[TValue]:
-     return [v for _, v in sorted(d.items())]
+     return [t[1] for t in sorted(d.items())]
 
 
  @set_metadata
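The _sort_to_list rewrite above is behavior-preserving: both comprehensions return the mapping's values ordered by key, the new form simply avoids tuple unpacking:

    d = {2: "b", 0: "a", 1: "c"}
    old = [v for _, v in sorted(d.items())]
    new = [t[1] for t in sorted(d.items())]
    assert old == new == ["a", "c", "b"]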
@@ -98,12 +60,9 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
      label_per_image: list[int] = []
 
      index2label = dict(enumerate(dataset.class_names))
-     labels = [target.labels.tolist() for target in dataset.targets]
-
-     labels_2d = _check_labels_dimension(labels)
 
-     for i, group in enumerate(labels_2d):
-         group = as_numpy(group).tolist()
+     for i, target in enumerate(dataset.targets):
+         group = target.labels.tolist()
 
          # Count occurrences of each label in all sublists
          label_counts.update(group)
dataeval/metrics/stats/_pixelstats.py CHANGED
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
  from dataeval.outputs import PixelStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import Array, Dataset
+ from dataeval.typing import ArrayLike, Dataset
 
 
  class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
 
  @set_metadata
  def pixelstats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
      per_channel: bool = False,
dataeval/metrics/stats/_visualstats.py CHANGED
@@ -9,7 +9,7 @@ import numpy as np
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
  from dataeval.outputs import VisualStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import Array, Dataset
+ from dataeval.typing import ArrayLike, Dataset
  from dataeval.utils._image import edge_filter
 
  QUARTILES = (0, 25, 50, 75, 100)
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
 
  @set_metadata
  def visualstats(
-     dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
+     dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
      *,
      per_box: bool = False,
      per_channel: bool = False,
dataeval/outputs/__init__.py CHANGED
@@ -4,7 +4,7 @@ as well as runtime metadata for reproducibility and logging.
  """
 
  from ._base import ExecutionMetadata
- from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+ from ._bias import BalanceOutput, CompletenessOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
  from ._drift import DriftMMDOutput, DriftOutput
  from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
  from ._linters import DuplicatesOutput, OutliersOutput
@@ -29,6 +29,7 @@ __all__ = [
      "ChannelStatsOutput",
      "ClustererOutput",
      "CoverageOutput",
+     "CompletenessOutput",
      "DimensionStatsOutput",
      "DivergenceOutput",
      "DiversityOutput",
dataeval/outputs/_bias.py CHANGED
@@ -14,9 +14,10 @@ with contextlib.suppress(ImportError):
      from matplotlib.figure import Figure
 
  from dataeval.outputs._base import Output
- from dataeval.typing import ArrayLike
- from dataeval.utils._array import to_numpy
+ from dataeval.typing import ArrayLike, Dataset
+ from dataeval.utils._array import as_numpy, channels_first_to_last
  from dataeval.utils._plot import heatmap
+ from dataeval.utils.data._images import Images
 
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -107,13 +108,13 @@ class CoverageOutput(Output):
      critical_value_radii: NDArray[np.float64]
      coverage_radius: float
 
-     def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
+     def plot(self, images: Images[Any] | Dataset[Any], top_k: int = 6) -> Figure:
          """
          Plot the top k images together for visualization.
 
          Parameters
          ----------
-         images : ArrayLike
+         images : Images or Dataset
              Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
          top_k : int, default 6
              Number of images to plot (plotting assumes groups of 3)
@@ -130,46 +131,54 @@
          import matplotlib.pyplot as plt
 
          # Determine which images to plot
-         highest_uncovered_indices = self.uncovered_indices[:top_k]
+         selected_indices = self.uncovered_indices[:top_k]
 
-         # Grab the images
-         selected_images = to_numpy(images)[highest_uncovered_indices]
+         images = Images(images) if isinstance(images, Dataset) else images
 
          # Plot the images
-         num_images = min(top_k, len(images))
-
-         ndim = selected_images.ndim
-         if ndim == 4:
-             selected_images = np.moveaxis(selected_images, 1, -1)
-         elif ndim == 3:
-             selected_images = np.repeat(selected_images[:, :, :, np.newaxis], 3, axis=-1)
-         else:
-             raise ValueError(
-                 f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {ndim}-dimensional set of images."
-             )
+         num_images = min(top_k, len(selected_indices))
 
          rows = int(np.ceil(num_images / 3))
          fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
 
          if rows == 1:
              for j in range(3):
-                 if j >= len(selected_images):
+                 if j >= len(selected_indices):
                      continue
-                 axs[j].imshow(selected_images[j])
+                 image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
+                 axs[j].imshow(image)
                  axs[j].axis("off")
          else:
              for i in range(rows):
                  for j in range(3):
                      i_j = i * 3 + j
-                     if i_j >= len(selected_images):
+                     if i_j >= len(selected_indices):
                          continue
-                     axs[i, j].imshow(selected_images[i_j])
+                     image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
+                     axs[i, j].imshow(image)
                      axs[i, j].axis("off")
 
          fig.tight_layout()
          return fig
 
 
+ @dataclass(frozen=True)
+ class CompletenessOutput(Output):
+     """
+     Output from the completeness function.
+
+     Attributes
+     ----------
+     fraction_filled : float
+         Fraction of boxes that contain at least one data point
+     empty_box_centers : List[np.ndarray]
+         List of coordinates for centers of empty boxes
+     """
+
+     fraction_filled: float
+     empty_box_centers: NDArray[np.float64]
+
+
  @dataclass(frozen=True)
  class BalanceOutput(Output):
      """
dataeval/outputs/_stats.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
  import contextlib
  from dataclasses import dataclass
- from typing import Any, Iterable, Optional, Union
+ from typing import Any, Iterable, NamedTuple, Optional, Union
 
  import numpy as np
  from numpy.typing import NDArray
@@ -22,8 +22,7 @@ SOURCE_INDEX = "source_index"
  BOX_COUNT = "box_count"
 
 
- @dataclass(frozen=True)
- class SourceIndex:
+ class SourceIndex(NamedTuple):
      """
      The indices of the source image, box and channel.
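Switching SourceIndex from a frozen dataclass to a NamedTuple keeps attribute access while adding tuple behavior (unpacking, positional indexing) at lower per-instance cost. A standalone sketch, with the fields assumed from the docstring (image, box, channel):

    from typing import NamedTuple, Optional

    class SourceIndex(NamedTuple):
        image: int              # index of the source image
        box: Optional[int]      # box index, None for whole-image statistics
        channel: Optional[int]  # channel index, None for whole-image statistics

    idx = SourceIndex(image=0, box=None, channel=2)
    image, box, channel = idx   # tuple unpacking now works
    assert idx.image == 0 and idx[2] == 2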