dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. dataeval/__init__.py +7 -2
  2. dataeval/config.py +13 -3
  3. dataeval/metadata/__init__.py +2 -2
  4. dataeval/metadata/_ood.py +144 -27
  5. dataeval/metrics/bias/__init__.py +11 -1
  6. dataeval/metrics/bias/_balance.py +3 -3
  7. dataeval/metrics/bias/_completeness.py +130 -0
  8. dataeval/metrics/estimators/_ber.py +2 -1
  9. dataeval/metrics/stats/_base.py +31 -36
  10. dataeval/metrics/stats/_dimensionstats.py +2 -2
  11. dataeval/metrics/stats/_hashstats.py +2 -2
  12. dataeval/metrics/stats/_imagestats.py +4 -4
  13. dataeval/metrics/stats/_labelstats.py +4 -45
  14. dataeval/metrics/stats/_pixelstats.py +2 -2
  15. dataeval/metrics/stats/_visualstats.py +2 -2
  16. dataeval/outputs/__init__.py +4 -2
  17. dataeval/outputs/_bias.py +31 -22
  18. dataeval/outputs/_metadata.py +7 -0
  19. dataeval/outputs/_stats.py +2 -3
  20. dataeval/typing.py +43 -12
  21. dataeval/utils/_array.py +26 -1
  22. dataeval/utils/_mst.py +1 -2
  23. dataeval/utils/data/_dataset.py +2 -0
  24. dataeval/utils/data/_embeddings.py +115 -32
  25. dataeval/utils/data/_images.py +38 -15
  26. dataeval/utils/data/_selection.py +7 -8
  27. dataeval/utils/data/_split.py +76 -129
  28. dataeval/utils/data/datasets/_base.py +4 -2
  29. dataeval/utils/data/datasets/_cifar10.py +17 -9
  30. dataeval/utils/data/datasets/_milco.py +18 -12
  31. dataeval/utils/data/datasets/_mnist.py +24 -8
  32. dataeval/utils/data/datasets/_ships.py +18 -8
  33. dataeval/utils/data/datasets/_types.py +1 -5
  34. dataeval/utils/data/datasets/_voc.py +47 -24
  35. dataeval/utils/data/selections/__init__.py +2 -0
  36. dataeval/utils/data/selections/_classfilter.py +1 -1
  37. dataeval/utils/data/selections/_prioritize.py +296 -0
  38. dataeval/utils/data/selections/_shuffle.py +13 -4
  39. dataeval/utils/metadata.py +1 -1
  40. dataeval/utils/torch/_gmm.py +3 -2
  41. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
  42. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
  43. dataeval/detectors/ood/metadata_ood_mi.py +0 -91
  44. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
  45. {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ import math
5
6
  import re
6
7
  import warnings
7
8
  from collections import ChainMap
@@ -17,26 +18,13 @@ from numpy.typing import NDArray
17
18
 
18
19
  from dataeval.config import get_max_processes
19
20
  from dataeval.outputs._stats import BaseStatsOutput, SourceIndex
20
- from dataeval.typing import ArrayLike, Dataset, ObjectDetectionTarget
21
+ from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
21
22
  from dataeval.utils._array import to_numpy
22
23
  from dataeval.utils._image import normalize_image_shape, rescale
23
24
 
24
25
  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
25
26
 
26
-
27
- def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
28
- """
29
- Normalizes the bounding box shape into (N,4).
30
- """
31
- ndim = bounding_box.ndim
32
- if ndim == 1:
33
- return np.expand_dims(bounding_box, axis=0)
34
- elif ndim > 2:
35
- raise ValueError("Bounding boxes must have 2 dimensions: (# of boxes in an image, [X,Y,W,H]) -> (N,4)")
36
- else:
37
- return bounding_box
38
-
39
-
27
+ BoundingBox = tuple[float, float, float, float]
40
28
  TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
41
29
 
42
30
 
@@ -46,11 +34,15 @@ class StatsProcessor(Generic[TStatsOutput]):
46
34
  image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
47
35
  channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
48
36
 
49
- def __init__(self, image: NDArray[Any], box: NDArray[Any] | None, per_channel: bool) -> None:
37
+ def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
50
38
  self.raw = image
51
39
  self.width: int = image.shape[-1]
52
40
  self.height: int = image.shape[-2]
53
- self.box: NDArray[np.int64] = np.array([0, 0, self.width, self.height]) if box is None else box.astype(np.int64)
41
+ box = BoundingBox((0, 0, self.width, self.height)) if box is None else box
42
+ # Clip the bounding box to image
43
+ x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
44
+ x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
45
+ self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
54
46
  self._per_channel = per_channel
55
47
  self._image = None
56
48
  self._shape = None
@@ -122,22 +114,17 @@ class StatsProcessorOutput:
122
114
 
123
115
  def process_stats(
124
116
  i: int,
125
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
126
- per_box: bool,
117
+ image: ArrayLike,
118
+ boxes: list[BoundingBox] | None,
127
119
  per_channel: bool,
128
120
  stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
129
121
  ) -> StatsProcessorOutput:
130
- data = dataset[i]
131
- image, target = (to_numpy(cast(ArrayLike, data[0])), data[1]) if isinstance(data, tuple) else (to_numpy(data), None)
132
- target = None if not isinstance(target, ObjectDetectionTarget) else target
133
- boxes = to_numpy(target.boxes) if target is not None else None
122
+ image = to_numpy(image)
134
123
  results_list: list[dict[str, Any]] = []
135
124
  source_indices: list[SourceIndex] = []
136
125
  box_counts: list[int] = []
137
126
  warnings_list: list[str] = []
138
- nboxes = [None] if boxes is None or not per_box else normalize_box_shape(boxes)
139
- for i_b, box in enumerate(nboxes):
140
- i_b = None if box is None else i_b
127
+ for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
141
128
  processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
142
129
  if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
143
130
  warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
@@ -151,17 +138,28 @@ def process_stats(
151
138
 
152
139
 
153
140
  def process_stats_unpack(
154
- i: int,
155
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
156
- per_box: bool,
141
+ args: tuple[int, Array, list[BoundingBox] | None],
157
142
  per_channel: bool,
158
143
  stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
159
144
  ) -> StatsProcessorOutput:
160
- return process_stats(i, dataset, per_box=per_box, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
145
+ return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
146
+
147
+
148
+ def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
149
+ for i in range(len(dataset)):
150
+ d = dataset[i]
151
+ image = d[0] if isinstance(d, tuple) else d
152
+ if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
153
+ boxes = cast(Array, d[1].boxes)
154
+ target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
155
+ else:
156
+ target = None
157
+
158
+ yield i, image, target
161
159
 
162
160
 
163
161
  def run_stats(
164
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
162
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
165
163
  per_box: bool,
166
164
  per_channel: bool,
167
165
  stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
@@ -175,7 +173,7 @@ def run_stats(
175
173
 
176
174
  Parameters
177
175
  ----------
178
- data : Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]]
176
+ dataset : Dataset[Array] | Dataset[tuple[Array, Any, Any]]
179
177
  A dataset of images and targets to compute statistics on.
180
178
  per_box : bool
181
179
  A flag which determines if the statistics should be evaluated on a per-box basis or not.
@@ -206,18 +204,15 @@ def run_stats(
206
204
  warning_list = []
207
205
  stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
208
206
 
209
- # TODO: Introduce global controls for CPU job parallelism and GPU configurations
210
207
  with Pool(processes=get_max_processes()) as p:
211
208
  for r in tqdm.tqdm(
212
209
  p.imap(
213
210
  partial(
214
211
  process_stats_unpack,
215
- dataset=dataset,
216
- per_box=per_box,
217
212
  per_channel=per_channel,
218
213
  stats_processor_cls=stats_processor_cls,
219
214
  ),
220
- range(len(dataset)),
215
+ _enumerate(dataset, per_box),
221
216
  ),
222
217
  total=len(dataset),
223
218
  ):
@@ -9,7 +9,7 @@ import numpy as np
9
9
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
10
10
  from dataeval.outputs import DimensionStatsOutput
11
11
  from dataeval.outputs._base import set_metadata
12
- from dataeval.typing import ArrayLike, Dataset
12
+ from dataeval.typing import Array, Dataset
13
13
  from dataeval.utils._image import get_bitdepth
14
14
 
15
15
 
@@ -34,7 +34,7 @@ class DimensionStatsProcessor(StatsProcessor[DimensionStatsOutput]):
34
34
 
35
35
  @set_metadata
36
36
  def dimensionstats(
37
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
37
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
38
38
  *,
39
39
  per_box: bool = False,
40
40
  ) -> DimensionStatsOutput:
@@ -14,7 +14,7 @@ from scipy.fftpack import dct
14
14
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
15
15
  from dataeval.outputs import HashStatsOutput
16
16
  from dataeval.outputs._base import set_metadata
17
- from dataeval.typing import ArrayLike, Dataset
17
+ from dataeval.typing import Array, ArrayLike, Dataset
18
18
  from dataeval.utils._array import as_numpy
19
19
  from dataeval.utils._image import normalize_image_shape, rescale
20
20
 
@@ -105,7 +105,7 @@ class HashStatsProcessor(StatsProcessor[HashStatsOutput]):
105
105
 
106
106
  @set_metadata
107
107
  def hashstats(
108
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
108
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
109
109
  *,
110
110
  per_box: bool = False,
111
111
  ) -> HashStatsOutput:
@@ -10,12 +10,12 @@ from dataeval.metrics.stats._pixelstats import PixelStatsProcessor
10
10
  from dataeval.metrics.stats._visualstats import VisualStatsProcessor
11
11
  from dataeval.outputs import ChannelStatsOutput, ImageStatsOutput
12
12
  from dataeval.outputs._base import set_metadata
13
- from dataeval.typing import ArrayLike, Dataset
13
+ from dataeval.typing import Array, Dataset
14
14
 
15
15
 
16
16
  @overload
17
17
  def imagestats(
18
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
18
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
19
19
  *,
20
20
  per_box: bool = False,
21
21
  per_channel: Literal[True],
@@ -24,7 +24,7 @@ def imagestats(
24
24
 
25
25
  @overload
26
26
  def imagestats(
27
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
27
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
28
28
  *,
29
29
  per_box: bool = False,
30
30
  per_channel: Literal[False] = False,
@@ -33,7 +33,7 @@ def imagestats(
33
33
 
34
34
  @set_metadata
35
35
  def imagestats(
36
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
36
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
37
37
  *,
38
38
  per_box: bool = False,
39
39
  per_channel: bool = False,
@@ -5,54 +5,16 @@ __all__ = []
5
5
  from collections import Counter, defaultdict
6
6
  from typing import Any, Mapping, TypeVar
7
7
 
8
- import numpy as np
9
-
10
8
  from dataeval.outputs import LabelStatsOutput
11
9
  from dataeval.outputs._base import set_metadata
12
- from dataeval.typing import AnnotatedDataset, ArrayLike
13
- from dataeval.utils._array import as_numpy
10
+ from dataeval.typing import AnnotatedDataset
14
11
  from dataeval.utils.data._metadata import Metadata
15
12
 
16
13
  TValue = TypeVar("TValue")
17
14
 
18
15
 
19
- def _ensure_2d(labels: ArrayLike) -> ArrayLike:
20
- if isinstance(labels, np.ndarray):
21
- return labels[:, None]
22
- else:
23
- return [[lbl] for lbl in labels] # type: ignore
24
-
25
-
26
- def _get_list_depth(lst):
27
- if isinstance(lst, list) and lst:
28
- return 1 + max(_get_list_depth(item) for item in lst)
29
- return 0
30
-
31
-
32
- def _check_labels_dimension(labels: ArrayLike) -> ArrayLike:
33
- # Check for nested lists beyond 2 levels
34
-
35
- if isinstance(labels, np.ndarray):
36
- if labels.ndim == 1:
37
- return _ensure_2d(labels)
38
- elif labels.ndim == 2:
39
- return labels
40
- else:
41
- raise ValueError("The label array must not have more than 2 dimensions.")
42
- elif isinstance(labels, list):
43
- depth = _get_list_depth(labels)
44
- if depth == 1:
45
- return _ensure_2d(labels)
46
- elif depth == 2:
47
- return labels
48
- else:
49
- raise ValueError("The label list must not be empty or have more than 2 levels of nesting.")
50
- else:
51
- raise TypeError("Labels must be either a NumPy array or a list.")
52
-
53
-
54
16
  def _sort_to_list(d: Mapping[int, TValue]) -> list[TValue]:
55
- return [v for _, v in sorted(d.items())]
17
+ return [t[1] for t in sorted(d.items())]
56
18
 
57
19
 
58
20
  @set_metadata
@@ -98,12 +60,9 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
98
60
  label_per_image: list[int] = []
99
61
 
100
62
  index2label = dict(enumerate(dataset.class_names))
101
- labels = [target.labels.tolist() for target in dataset.targets]
102
-
103
- labels_2d = _check_labels_dimension(labels)
104
63
 
105
- for i, group in enumerate(labels_2d):
106
- group = as_numpy(group).tolist()
64
+ for i, target in enumerate(dataset.targets):
65
+ group = target.labels.tolist()
107
66
 
108
67
  # Count occurrences of each label in all sublists
109
68
  label_counts.update(group)
@@ -10,7 +10,7 @@ from scipy.stats import entropy, kurtosis, skew
10
10
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
11
11
  from dataeval.outputs import PixelStatsOutput
12
12
  from dataeval.outputs._base import set_metadata
13
- from dataeval.typing import ArrayLike, Dataset
13
+ from dataeval.typing import Array, Dataset
14
14
 
15
15
 
16
16
  class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
@@ -37,7 +37,7 @@ class PixelStatsProcessor(StatsProcessor[PixelStatsOutput]):
37
37
 
38
38
  @set_metadata
39
39
  def pixelstats(
40
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
40
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
41
41
  *,
42
42
  per_box: bool = False,
43
43
  per_channel: bool = False,
@@ -9,7 +9,7 @@ import numpy as np
9
9
  from dataeval.metrics.stats._base import StatsProcessor, run_stats
10
10
  from dataeval.outputs import VisualStatsOutput
11
11
  from dataeval.outputs._base import set_metadata
12
- from dataeval.typing import ArrayLike, Dataset
12
+ from dataeval.typing import Array, Dataset
13
13
  from dataeval.utils._image import edge_filter
14
14
 
15
15
  QUARTILES = (0, 25, 50, 75, 100)
@@ -44,7 +44,7 @@ class VisualStatsProcessor(StatsProcessor[VisualStatsOutput]):
44
44
 
45
45
  @set_metadata
46
46
  def visualstats(
47
- dataset: Dataset[ArrayLike] | Dataset[tuple[ArrayLike, Any, Any]],
47
+ dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]],
48
48
  *,
49
49
  per_box: bool = False,
50
50
  per_channel: bool = False,
@@ -4,11 +4,11 @@ as well as runtime metadata for reproducibility and logging.
4
4
  """
5
5
 
6
6
  from ._base import ExecutionMetadata
7
- from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
7
+ from ._bias import BalanceOutput, CompletenessOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
8
8
  from ._drift import DriftMMDOutput, DriftOutput
9
9
  from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
10
10
  from ._linters import DuplicatesOutput, OutliersOutput
11
- from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput
11
+ from ._metadata import MetadataDistanceOutput, MetadataDistanceValues, MostDeviatedFactorsOutput, OODPredictorOutput
12
12
  from ._ood import OODOutput, OODScoreOutput
13
13
  from ._stats import (
14
14
  ChannelStatsOutput,
@@ -29,6 +29,7 @@ __all__ = [
29
29
  "ChannelStatsOutput",
30
30
  "ClustererOutput",
31
31
  "CoverageOutput",
32
+ "CompletenessOutput",
32
33
  "DimensionStatsOutput",
33
34
  "DivergenceOutput",
34
35
  "DiversityOutput",
@@ -44,6 +45,7 @@ __all__ = [
44
45
  "MetadataDistanceValues",
45
46
  "MostDeviatedFactorsOutput",
46
47
  "OODOutput",
48
+ "OODPredictorOutput",
47
49
  "OODScoreOutput",
48
50
  "OutliersOutput",
49
51
  "ParityOutput",
dataeval/outputs/_bias.py CHANGED
@@ -14,9 +14,10 @@ with contextlib.suppress(ImportError):
14
14
  from matplotlib.figure import Figure
15
15
 
16
16
  from dataeval.outputs._base import Output
17
- from dataeval.typing import ArrayLike
18
- from dataeval.utils._array import to_numpy
17
+ from dataeval.typing import ArrayLike, Dataset
18
+ from dataeval.utils._array import as_numpy, channels_first_to_last
19
19
  from dataeval.utils._plot import heatmap
20
+ from dataeval.utils.data._images import Images
20
21
 
21
22
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
22
23
 
@@ -107,13 +108,13 @@ class CoverageOutput(Output):
107
108
  critical_value_radii: NDArray[np.float64]
108
109
  coverage_radius: float
109
110
 
110
- def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
111
+ def plot(self, images: Images[Any] | Dataset[Any], top_k: int = 6) -> Figure:
111
112
  """
112
113
  Plot the top k images together for visualization.
113
114
 
114
115
  Parameters
115
116
  ----------
116
- images : ArrayLike
117
+ images : Images or Dataset
117
118
  Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
118
119
  top_k : int, default 6
119
120
  Number of images to plot (plotting assumes groups of 3)
@@ -130,46 +131,54 @@ class CoverageOutput(Output):
130
131
  import matplotlib.pyplot as plt
131
132
 
132
133
  # Determine which images to plot
133
- highest_uncovered_indices = self.uncovered_indices[:top_k]
134
+ selected_indices = self.uncovered_indices[:top_k]
134
135
 
135
- # Grab the images
136
- selected_images = to_numpy(images)[highest_uncovered_indices]
136
+ images = Images(images) if isinstance(images, Dataset) else images
137
137
 
138
138
  # Plot the images
139
- num_images = min(top_k, len(images))
140
-
141
- ndim = selected_images.ndim
142
- if ndim == 4:
143
- selected_images = np.moveaxis(selected_images, 1, -1)
144
- elif ndim == 3:
145
- selected_images = np.repeat(selected_images[:, :, :, np.newaxis], 3, axis=-1)
146
- else:
147
- raise ValueError(
148
- f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {ndim}-dimensional set of images."
149
- )
139
+ num_images = min(top_k, len(selected_indices))
150
140
 
151
141
  rows = int(np.ceil(num_images / 3))
152
142
  fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
153
143
 
154
144
  if rows == 1:
155
145
  for j in range(3):
156
- if j >= len(selected_images):
146
+ if j >= len(selected_indices):
157
147
  continue
158
- axs[j].imshow(selected_images[j])
148
+ image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
149
+ axs[j].imshow(image)
159
150
  axs[j].axis("off")
160
151
  else:
161
152
  for i in range(rows):
162
153
  for j in range(3):
163
154
  i_j = i * 3 + j
164
- if i_j >= len(selected_images):
155
+ if i_j >= len(selected_indices):
165
156
  continue
166
- axs[i, j].imshow(selected_images[i_j])
157
+ image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
158
+ axs[i, j].imshow(image)
167
159
  axs[i, j].axis("off")
168
160
 
169
161
  fig.tight_layout()
170
162
  return fig
171
163
 
172
164
 
165
+ @dataclass(frozen=True)
166
+ class CompletenessOutput(Output):
167
+ """
168
+ Output from the completeness function.
169
+
170
+ Attributes
171
+ ----------
172
+ fraction_filled : float
173
+ Fraction of boxes that contain at least one data point
174
+ empty_box_centers : NDArray[np.float64]
175
+ Array of coordinates for centers of empty boxes
176
+ """
177
+
178
+ fraction_filled: float
179
+ empty_box_centers: NDArray[np.float64]
180
+
181
+
173
182
  @dataclass(frozen=True)
174
183
  class BalanceOutput(Output):
175
184
  """
@@ -52,3 +52,10 @@ class MetadataDistanceOutput(MappingOutput[str, MetadataDistanceValues]):
52
52
  value : :class:`.MetadataDistanceValues`
53
53
  Output per feature name containing the statistic, statistic location, distance, and pvalue.
54
54
  """
55
+
56
+
57
+ class OODPredictorOutput(MappingOutput[str, float]):
58
+ """
59
+ Output class for results of :func:`find_ood_predictors` for the
60
+ mutual information between factors and being out of distribution
61
+ """
@@ -4,7 +4,7 @@ __all__ = []
4
4
 
5
5
  import contextlib
6
6
  from dataclasses import dataclass
7
- from typing import Any, Iterable, Optional, Union
7
+ from typing import Any, Iterable, NamedTuple, Optional, Union
8
8
 
9
9
  import numpy as np
10
10
  from numpy.typing import NDArray
@@ -22,8 +22,7 @@ SOURCE_INDEX = "source_index"
22
22
  BOX_COUNT = "box_count"
23
23
 
24
24
 
25
- @dataclass(frozen=True)
26
- class SourceIndex:
25
+ class SourceIndex(NamedTuple):
27
26
  """
28
27
  The indices of the source image, box and channel.
29
28
 
dataeval/typing.py CHANGED
@@ -1,5 +1,5 @@
1
1
  """
2
- Common type hints used for interoperability with DataEval.
2
+ Common type protocols used for interoperability with DataEval.
3
3
  """
4
4
 
5
5
  __all__ = [
@@ -16,13 +16,14 @@ __all__ = [
16
16
  "SegmentationTarget",
17
17
  "SegmentationDatum",
18
18
  "SegmentationDataset",
19
+ "Transform",
19
20
  ]
20
21
 
21
22
 
22
23
  import sys
23
24
  from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
24
25
 
25
- from typing_extensions import NotRequired, Required
26
+ from typing_extensions import NotRequired, ReadOnly, Required
26
27
 
27
28
  if sys.version_info >= (3, 10):
28
29
  from typing import TypeAlias
@@ -66,6 +67,7 @@ class Array(Protocol):
66
67
  def __len__(self) -> int: ...
67
68
 
68
69
 
70
+ T = TypeVar("T")
69
71
  _T_co = TypeVar("_T_co", covariant=True)
70
72
  _ScalarType = Union[int, float, bool, str]
71
73
  ArrayLike: TypeAlias = Union[Sequence[_ScalarType], Sequence[Sequence[_ScalarType]], Sequence[Array], Array]
@@ -89,8 +91,8 @@ class DatasetMetadata(TypedDict, total=False):
89
91
  A lookup table converting label value to class name
90
92
  """
91
93
 
92
- id: Required[str]
93
- index2label: NotRequired[dict[int, str]]
94
+ id: Required[ReadOnly[str]]
95
+ index2label: NotRequired[ReadOnly[dict[int, str]]]
94
96
 
95
97
 
96
98
  @runtime_checkable
@@ -140,7 +142,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
140
142
 
141
143
  ImageClassificationDatum: TypeAlias = tuple[Array, Array, dict[str, Any]]
142
144
  """
143
- A type definition for an image classification datum tuple.
145
+ Type alias for an image classification datum tuple.
144
146
 
145
147
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
146
148
  - :class:`Array` of shape (N,) - Class label as one-hot encoded ground-truth or prediction confidences.
@@ -150,7 +152,7 @@ A type definition for an image classification datum tuple.
150
152
 
151
153
  ImageClassificationDataset: TypeAlias = AnnotatedDataset[ImageClassificationDatum]
152
154
  """
153
- A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
155
+ Type alias for an :class:`AnnotatedDataset` of :class:`ImageClassificationDatum` elements.
154
156
  """
155
157
 
156
158
  # ========== OBJECT DETECTION DATASETS ==========
@@ -159,7 +161,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ImageClassificatio
159
161
  @runtime_checkable
160
162
  class ObjectDetectionTarget(Protocol):
161
163
  """
162
- A protocol for targets in an Object Detection dataset.
164
+ Protocol for targets in an Object Detection dataset.
163
165
 
164
166
  Attributes
165
167
  ----------
@@ -180,7 +182,7 @@ class ObjectDetectionTarget(Protocol):
180
182
 
181
183
  ObjectDetectionDatum: TypeAlias = tuple[Array, ObjectDetectionTarget, dict[str, Any]]
182
184
  """
183
- A type definition for an object detection datum tuple.
185
+ Type alias for an object detection datum tuple.
184
186
 
185
187
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
186
188
  - :class:`ObjectDetectionTarget` - Object detection target information for the image.
@@ -190,7 +192,7 @@ A type definition for an object detection datum tuple.
190
192
 
191
193
  ObjectDetectionDataset: TypeAlias = AnnotatedDataset[ObjectDetectionDatum]
192
194
  """
193
- A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
195
+ Type alias for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDatum` elements.
194
196
  """
195
197
 
196
198
 
@@ -200,7 +202,7 @@ A type definition for an :class:`AnnotatedDataset` of :class:`ObjectDetectionDat
200
202
  @runtime_checkable
201
203
  class SegmentationTarget(Protocol):
202
204
  """
203
- A protocol for targets in a Segmentation dataset.
205
+ Protocol for targets in a Segmentation dataset.
204
206
 
205
207
  Attributes
206
208
  ----------
@@ -221,7 +223,7 @@ class SegmentationTarget(Protocol):
221
223
 
222
224
  SegmentationDatum: TypeAlias = tuple[Array, SegmentationTarget, dict[str, Any]]
223
225
  """
224
- A type definition for an image classification datum tuple.
226
+ Type alias for a segmentation datum tuple.
225
227
 
226
228
  - :class:`Array` of shape (C, H, W) - Image data in channel, height, width format.
227
229
  - :class:`SegmentationTarget` - Segmentation target information for the image.
@@ -230,5 +232,34 @@ A type definition for an image classification datum tuple.
230
232
 
231
233
  SegmentationDataset: TypeAlias = AnnotatedDataset[SegmentationDatum]
232
234
  """
233
- A type definition for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
235
+ Type alias for an :class:`AnnotatedDataset` of :class:`SegmentationDatum` elements.
234
236
  """
237
+
238
+
239
+ @runtime_checkable
240
+ class Transform(Generic[T], Protocol):
241
+ """
242
+ Protocol defining a transform function.
243
+
244
+ Requires a `__call__` method that returns transformed data.
245
+
246
+ Example
247
+ -------
248
+ >>> from typing import Any
249
+ >>> from numpy.typing import NDArray
250
+
251
+ >>> class MyTransform:
252
+ ... def __init__(self, divisor: float) -> None:
253
+ ... self.divisor = divisor
254
+ ...
255
+ ... def __call__(self, data: NDArray[Any], /) -> NDArray[Any]:
256
+ ... return data / self.divisor
257
+
258
+ >>> my_transform = MyTransform(divisor=255.0)
259
+ >>> isinstance(my_transform, Transform)
260
+ True
261
+ >>> my_transform(np.array([1, 2, 3]))
262
+ array([0.004, 0.008, 0.012])
263
+ """
264
+
265
+ def __call__(self, data: T, /) -> T: ...
dataeval/utils/_array.py CHANGED
@@ -13,7 +13,7 @@ import torch
13
13
  from numpy.typing import NDArray
14
14
 
15
15
  from dataeval._log import LogMessage
16
- from dataeval.typing import ArrayLike
16
+ from dataeval.typing import Array, ArrayLike
17
17
 
18
18
  _logger = logging.getLogger(__name__)
19
19
 
@@ -167,3 +167,28 @@ def flatten(array: ArrayLike) -> NDArray[Any]:
167
167
  """
168
168
  nparr = as_numpy(array)
169
169
  return nparr.reshape((nparr.shape[0], -1))
170
+
171
+
172
+ _TArray = TypeVar("_TArray", bound=Array)
173
+
174
+
175
+ def channels_first_to_last(array: _TArray) -> _TArray:
176
+ """
177
+ Converts an array from channels-first (C, H, W) to channels-last (H, W, C) format.
178
+
179
+ Parameters
180
+ ----------
181
+ array : Array
182
+ Input array
183
+
184
+ Returns
185
+ -------
186
+ Array
187
+ Converted array
188
+ """
189
+ if isinstance(array, np.ndarray):
190
+ return np.transpose(array, (1, 2, 0))
191
+ elif isinstance(array, torch.Tensor):
192
+ return torch.permute(array, (1, 2, 0))
193
+ else:
194
+ raise TypeError(f"Unsupported array type {type(array)} for conversion.")
dataeval/utils/_mst.py CHANGED
@@ -10,10 +10,9 @@ from scipy.sparse.csgraph import minimum_spanning_tree as mst
10
10
  from scipy.spatial.distance import pdist, squareform
11
11
  from sklearn.neighbors import NearestNeighbors
12
12
 
13
+ from dataeval.config import EPSILON
13
14
  from dataeval.utils._array import flatten
14
15
 
15
- EPSILON = 1e-5
16
-
17
16
 
18
17
  def minimum_spanning_tree(X: NDArray[Any]) -> Any:
19
18
  """