dataeval 0.83.0__py3-none-any.whl → 0.84.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dataeval/__init__.py CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
  from __future__ import annotations
 
  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
- __version__ = "0.83.0"
+ __version__ = "0.84.0"
 
  import logging
 
dataeval/config.py CHANGED
@@ -45,13 +45,13 @@ def _todevice(device: DeviceLike) -> torch.device:
      return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
 
 
- def set_device(device: DeviceLike) -> None:
+ def set_device(device: DeviceLike | None) -> None:
      """
      Sets the default device to use when executing against a PyTorch backend.
 
      Parameters
      ----------
-     device : DeviceLike or None
+     device : DeviceLike or None
          The default device to use. See documentation for more information.
 
      See Also
@@ -59,7 +59,7 @@ def set_device(device: DeviceLike) -> None:
      `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
      """
      global _device
-     _device = _todevice(device)
+     _device = None if device is None else _todevice(device)
 
 
  def get_device(override: DeviceLike | None = None) -> torch.device:
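
With the widened signature, passing None clears a previously pinned device so that get_device falls back to its automatic default again. A minimal usage sketch (the device string is illustrative, not part of the diff):

    from dataeval import config

    config.set_device("cuda:0")  # pin DataEval's PyTorch execution to one device
    print(config.get_device())   # cuda:0
    config.set_device(None)      # new in 0.84.0: unpin, restoring automatic selection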
dataeval/metrics/bias/__init__.py CHANGED
@@ -6,10 +6,12 @@ representation which may impact model performance.
  __all__ = [
      "BalanceOutput",
      "CoverageOutput",
+     "CompletenessOutput",
      "DiversityOutput",
      "LabelParityOutput",
      "ParityOutput",
      "balance",
+     "completeness",
      "coverage",
      "diversity",
      "label_parity",
@@ -17,7 +19,15 @@ __all__ = [
  ]
 
  from dataeval.metrics.bias._balance import balance
+ from dataeval.metrics.bias._completeness import completeness
  from dataeval.metrics.bias._coverage import coverage
  from dataeval.metrics.bias._diversity import diversity
  from dataeval.metrics.bias._parity import label_parity, parity
- from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+ from dataeval.outputs._bias import (
+     BalanceOutput,
+     CompletenessOutput,
+     CoverageOutput,
+     DiversityOutput,
+     LabelParityOutput,
+     ParityOutput,
+ )
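
Together these __init__ changes expose the new metric and its output type on the public package surface:

    from dataeval.metrics.bias import completeness
    from dataeval.outputs import CompletenessOutput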
dataeval/metrics/bias/_completeness.py ADDED
@@ -0,0 +1,130 @@
+ from __future__ import annotations
+
+ import itertools
+
+ __all__ = []
+
+
+ import numpy as np
+
+ from dataeval.config import EPSILON
+ from dataeval.outputs import CompletenessOutput
+ from dataeval.typing import ArrayLike
+ from dataeval.utils._array import ensure_embeddings
+
+
+ def completeness(embeddings: ArrayLike, quantiles: int) -> CompletenessOutput:
+     """
+     Calculate the fraction of boxes in a grid defined by quantiles that
+     contain at least one data point.
+     Also returns the center coordinates of each empty box.
+
+     Parameters
+     ----------
+     embeddings : ArrayLike
+         Embedded dataset (or other low-dimensional data) (nxp)
+     quantiles : int
+         Number of quantile values to use for partitioning each dimension,
+         e.g., 1 creates a grid of 2^p boxes, 2 creates 3^p boxes, etc.
+
+     Returns
+     -------
+     CompletenessOutput
+         - fraction_filled: float - Fraction of boxes that contain at least one
+           data point
+         - empty_box_centers: List[np.ndarray] - List of coordinates for centers
+           of empty boxes
+
+     Raises
+     ------
+     ValueError
+         If embeddings are too high-dimensional (>10)
+     ValueError
+         If there are too many quantiles (>2)
+     ValueError
+         If embedding is invalid shape
+
+     Example
+     -------
+     >>> embs = np.array([[1, 0], [0, 1], [1, 1]])
+     >>> quantiles = 1
+     >>> result = completeness(embs, quantiles)
+     >>> result.fraction_filled
+     0.75
+
+     Reference
+     ---------
+     This implementation is based on https://arxiv.org/abs/2002.03147.
+
+     [1] Byun, Taejoon, and Sanjai Rayadurgam. “Manifold for Machine Learning Assurance.”
+     Proceedings of the ACM/IEEE 42nd International Conference on Software Engineering
+     """
+     # Ensure proper data format
+     embeddings = ensure_embeddings(embeddings, dtype=np.float64, unit_interval=False)
+
+     # Get data dimensions
+     n, p = embeddings.shape
+     if quantiles > 2 or quantiles <= 0:
+         raise ValueError(
+             f"Number of quantiles ({quantiles}) is greater than 2 or is nonpositive. \
+             The metric scales exponentially in this value. Please use 1 or 2 quantiles."
+         )
+     if p > 10:
+         raise ValueError(
+             f"Dimension of embeddings ({p}) is greater than 10. \
+             The metric scales exponentially in this value. Please reduce the embedding dimension."
+         )
+     if n == 0 or p == 0:
+         raise ValueError("Your provided embeddings do not contain any data!")
+     # quantiles + 2 edges partition each dimension (e.g. [0, 0.5, 1] for quantiles = 1)
+     quantile_vec = np.linspace(0, 1, quantiles + 2)
+
+     # Calculate the bin edges for each dimension based on quantiles
+     bin_edges = []
+     for dim in range(p):
+         # Calculate the quantile values for this feature
+         edges = np.array(np.quantile(embeddings[:, dim], quantile_vec))
+         # Make sure the last bin contains all the remaining points
+         edges[-1] += EPSILON
+         bin_edges.append(edges)
+     # Convert each data point into its corresponding grid cell indices
+     grid_indices = []
+     for dim in range(p):
+         # For each dimension, find which bin each data point belongs to
+         # Digitize is 1 indexed so we subtract 1
+         indices = np.digitize(embeddings[:, dim], bin_edges[dim]) - 1
+         grid_indices.append(indices)
+
+     # Make the rows the data point and the column the grid index
+     grid_coords = np.array(grid_indices).T
+
+     # Use set to find unique tuple of grid coordinates
+     occupied_cells = set(map(tuple, grid_coords))
+
+     # For the fraction
+     num_occupied_cells = len(occupied_cells)
+
+     # Calculate total possible cells in the grid
+     num_bins_per_dim = [len(edges) - 1 for edges in bin_edges]
+     total_possible_cells = np.prod(num_bins_per_dim)
+
+     # Generate all possible grid cells
+     all_cells = set(itertools.product(*[range(bins) for bins in num_bins_per_dim]))
+
+     # Find the empty cells (cells with no data points)
+     empty_cells = all_cells - occupied_cells
+
+     # Calculate center points of empty boxes
+     empty_box_centers = []
+     for cell in empty_cells:
+         center_coords = []
+         for dim, idx in enumerate(cell):
+             # Calculate center of the bin as midpoint between edges
+             center = (bin_edges[dim][idx] + bin_edges[dim][idx + 1]) / 2
+             center_coords.append(center)
+         empty_box_centers.append(np.array(center_coords))
+
+     # Calculate the fraction
+     fraction = float(num_occupied_cells / total_possible_cells)
+     empty_box_centers = np.array(empty_box_centers)
+     return CompletenessOutput(fraction, empty_box_centers)
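
As a sanity check on the grid arithmetic: quantiles + 2 edges give quantiles + 1 bins per dimension, so a p-dimensional embedding is partitioned into (quantiles + 1)**p cells. A minimal sketch reproducing the docstring example (three points in 2-D, quantiles=1, hence 4 cells with 3 occupied):

    import numpy as np

    from dataeval.metrics.bias import completeness

    embs = np.array([[1, 0], [0, 1], [1, 1]])
    result = completeness(embs, quantiles=1)
    print(result.fraction_filled)    # 0.75 == 3 / (1 + 1) ** 2
    print(result.empty_box_centers)  # center of the single unoccupied cell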
dataeval/metrics/stats/_base.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 
  __all__ = []
 
+ import math
  import re
  import warnings
  from collections import ChainMap
@@ -9,7 +10,7 @@ from copy import deepcopy
  from dataclasses import dataclass
  from functools import partial
  from multiprocessing import Pool
- from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar
+ from typing import Any, Callable, Generic, Iterable, Sequence, TypeVar, cast
 
  import numpy as np
  import tqdm
@@ -23,20 +24,7 @@ from dataeval.utils._image import normalize_image_shape, rescale
 
  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
-
- def normalize_box_shape(bounding_box: NDArray[Any]) -> NDArray[Any]:
-     """
-     Normalizes the bounding box shape into (N,4).
-     """
-     ndim = bounding_box.ndim
-     if ndim == 1:
-         return np.expand_dims(bounding_box, axis=0)
-     elif ndim > 2:
-         raise ValueError("Bounding boxes must have 2 dimensions: (# of boxes in an image, [X,Y,W,H]) -> (N,4)")
-     else:
-         return bounding_box
-
-
+ BoundingBox = tuple[float, float, float, float]
  TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
 
@@ -46,11 +34,15 @@ class StatsProcessor(Generic[TStatsOutput]):
      image_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
      channel_function_map: dict[str, Callable[[StatsProcessor[TStatsOutput]], Any]] = {}
 
-     def __init__(self, image: NDArray[Any], box: NDArray[Any] | None, per_channel: bool) -> None:
+     def __init__(self, image: NDArray[Any], box: BoundingBox | None, per_channel: bool) -> None:
          self.raw = image
          self.width: int = image.shape[-1]
          self.height: int = image.shape[-2]
-         self.box: NDArray[np.int64] = np.array([0, 0, self.width, self.height]) if box is None else box.astype(np.int64)
+         box = BoundingBox((0, 0, self.width, self.height)) if box is None else box
+         # Clip the bounding box to image
+         x0, y0 = (min(j, max(0, math.floor(box[i]))) for i, j in zip((0, 1), (self.width - 1, self.height - 1)))
+         x1, y1 = (min(j, max(1, math.ceil(box[i]))) for i, j in zip((2, 3), (self.width, self.height)))
+         self.box: NDArray[np.int64] = np.array([x0, y0, x1, y1], dtype=np.int64)
          self._per_channel = per_channel
          self._image = None
          self._shape = None
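
The rewritten __init__ no longer trusts the caller's box: fractional coordinates are floored (top-left) or ceiled (bottom-right) and clamped into the image extent, while the existing _is_valid_slice check still catches boxes that remain degenerate. A standalone sketch of the same clamping expressions (clip_box and the values are hypothetical):

    import math

    def clip_box(box, width, height):
        # Floor the top-left corner and clamp it into [0, dim - 1]
        x0 = min(width - 1, max(0, math.floor(box[0])))
        y0 = min(height - 1, max(0, math.floor(box[1])))
        # Ceil the bottom-right corner and clamp it into [1, dim]
        x1 = min(width, max(1, math.ceil(box[2])))
        y1 = min(height, max(1, math.ceil(box[3])))
        return x0, y0, x1, y1

    print(clip_box((-3.2, 0.4, 4.6, 100.0), width=64, height=64))  # (0, 0, 5, 64)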
@@ -123,18 +115,16 @@ class StatsProcessorOutput:
  def process_stats(
      i: int,
      image: ArrayLike,
-     target: Any,
-     per_box: bool,
+     boxes: list[BoundingBox] | None,
      per_channel: bool,
      stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
  ) -> StatsProcessorOutput:
      image = to_numpy(image)
-     boxes = to_numpy(target.boxes) if isinstance(target, ObjectDetectionTarget) else None
      results_list: list[dict[str, Any]] = []
      source_indices: list[SourceIndex] = []
      box_counts: list[int] = []
      warnings_list: list[str] = []
-     for i_b, box in [(None, None)] if boxes is None else enumerate(normalize_box_shape(boxes)):
+     for i_b, box in [(None, None)] if boxes is None else enumerate(boxes):
          processor_list = [p(image, box, per_channel) for p in stats_processor_cls]
          if any(not p._is_valid_slice for p in processor_list) and i_b is not None and box is not None:
              warnings_list.append(f"Bounding box [{i}][{i_b}]: {box} is out of bounds of {image.shape}.")
@@ -148,12 +138,24 @@ def process_stats(
 
 
  def process_stats_unpack(
-     args: tuple[int, ArrayLike, Any],
-     per_box: bool,
+     args: tuple[int, Array, list[BoundingBox] | None],
      per_channel: bool,
      stats_processor_cls: Iterable[type[StatsProcessor[TStatsOutput]]],
  ) -> StatsProcessorOutput:
-     return process_stats(*args, per_box=per_box, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+     return process_stats(*args, per_channel=per_channel, stats_processor_cls=stats_processor_cls)
+
+
+ def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
+     for i in range(len(dataset)):
+         d = dataset[i]
+         image = d[0] if isinstance(d, tuple) else d
+         if per_box and isinstance(d, tuple) and isinstance(d[1], ObjectDetectionTarget):
+             boxes = cast(Array, d[1].boxes)
+             target = [BoundingBox(float(box[i]) for i in range(4)) for box in boxes]
+         else:
+             target = None
+
+         yield i, image, target
 
 
  def run_stats(
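
Because BoundingBox is a subscripted tuple alias rather than a class, both construction sites above (the default box in StatsProcessor.__init__ and the per-box conversion in _enumerate) build ordinary tuples at runtime; the annotated element types are not enforced. A quick standalone sketch of that behavior:

    BoundingBox = tuple[float, float, float, float]

    # Calling a subscripted alias forwards to tuple(), so any iterable works:
    box = BoundingBox(float(v) for v in (1, 2, 3, 4))
    print(box)        # (1.0, 2.0, 3.0, 4.0)
    print(type(box))  # <class 'tuple'>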
@@ -202,17 +204,11 @@ def run_stats(
      warning_list = []
      stats_processor_cls = stats_processor_cls if isinstance(stats_processor_cls, Iterable) else [stats_processor_cls]
 
-     def _enumerate(dataset: Dataset[Array] | Dataset[tuple[Array, Any, Any]], per_box: bool):
-         for i in range(len(dataset)):
-             d = dataset[i]
-             yield i, d[0] if isinstance(d, tuple) else d, d[1] if isinstance(d, tuple) and per_box else None
-
      with Pool(processes=get_max_processes()) as p:
          for r in tqdm.tqdm(
              p.imap(
                  partial(
                      process_stats_unpack,
-                     per_box=per_box,
                      per_channel=per_channel,
                      stats_processor_cls=stats_processor_cls,
                  ),
dataeval/metrics/stats/_labelstats.py CHANGED
@@ -5,54 +5,16 @@ __all__ = []
  from collections import Counter, defaultdict
  from typing import Any, Mapping, TypeVar
 
- import numpy as np
-
  from dataeval.outputs import LabelStatsOutput
  from dataeval.outputs._base import set_metadata
- from dataeval.typing import AnnotatedDataset, ArrayLike
- from dataeval.utils._array import as_numpy
+ from dataeval.typing import AnnotatedDataset
  from dataeval.utils.data._metadata import Metadata
 
  TValue = TypeVar("TValue")
 
 
- def _ensure_2d(labels: ArrayLike) -> ArrayLike:
-     if isinstance(labels, np.ndarray):
-         return labels[:, None]
-     else:
-         return [[lbl] for lbl in labels]  # type: ignore
-
-
- def _get_list_depth(lst):
-     if isinstance(lst, list) and lst:
-         return 1 + max(_get_list_depth(item) for item in lst)
-     return 0
-
-
- def _check_labels_dimension(labels: ArrayLike) -> ArrayLike:
-     # Check for nested lists beyond 2 levels
-
-     if isinstance(labels, np.ndarray):
-         if labels.ndim == 1:
-             return _ensure_2d(labels)
-         elif labels.ndim == 2:
-             return labels
-         else:
-             raise ValueError("The label array must not have more than 2 dimensions.")
-     elif isinstance(labels, list):
-         depth = _get_list_depth(labels)
-         if depth == 1:
-             return _ensure_2d(labels)
-         elif depth == 2:
-             return labels
-         else:
-             raise ValueError("The label list must not be empty or have more than 2 levels of nesting.")
-     else:
-         raise TypeError("Labels must be either a NumPy array or a list.")
-
-
  def _sort_to_list(d: Mapping[int, TValue]) -> list[TValue]:
-     return [v for _, v in sorted(d.items())]
+     return [t[1] for t in sorted(d.items())]
 
 
  @set_metadata
@@ -98,12 +60,9 @@ def labelstats(dataset: Metadata | AnnotatedDataset[Any]) -> LabelStatsOutput:
      label_per_image: list[int] = []
 
      index2label = dict(enumerate(dataset.class_names))
-     labels = [target.labels.tolist() for target in dataset.targets]
-
-     labels_2d = _check_labels_dimension(labels)
 
-     for i, group in enumerate(labels_2d):
-         group = as_numpy(group).tolist()
+     for i, target in enumerate(dataset.targets):
+         group = target.labels.tolist()
 
          # Count occurrences of each label in all sublists
          label_counts.update(group)
dataeval/outputs/__init__.py CHANGED
@@ -4,7 +4,7 @@ as well as runtime metadata for reproducibility and logging.
  """
 
  from ._base import ExecutionMetadata
- from ._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
+ from ._bias import BalanceOutput, CompletenessOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
  from ._drift import DriftMMDOutput, DriftOutput
  from ._estimators import BEROutput, ClustererOutput, DivergenceOutput, UAPOutput
  from ._linters import DuplicatesOutput, OutliersOutput
@@ -29,6 +29,7 @@ __all__ = [
      "ChannelStatsOutput",
      "ClustererOutput",
      "CoverageOutput",
+     "CompletenessOutput",
      "DimensionStatsOutput",
      "DivergenceOutput",
      "DiversityOutput",
dataeval/outputs/_bias.py CHANGED
@@ -14,9 +14,10 @@ with contextlib.suppress(ImportError):
      from matplotlib.figure import Figure
 
  from dataeval.outputs._base import Output
- from dataeval.typing import ArrayLike
- from dataeval.utils._array import to_numpy
+ from dataeval.typing import ArrayLike, Dataset
+ from dataeval.utils._array import as_numpy, channels_first_to_last
  from dataeval.utils._plot import heatmap
+ from dataeval.utils.data._images import Images
 
  TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -107,13 +108,13 @@ class CoverageOutput(Output):
      critical_value_radii: NDArray[np.float64]
      coverage_radius: float
 
-     def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
+     def plot(self, images: Images[Any] | Dataset[Any], top_k: int = 6) -> Figure:
          """
          Plot the top k images together for visualization.
 
          Parameters
          ----------
-         images : ArrayLike
+         images : Images or Dataset
              Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
          top_k : int, default 6
              Number of images to plot (plotting assumes groups of 3)
@@ -130,46 +131,54 @@ class CoverageOutput(Output):
          import matplotlib.pyplot as plt
 
          # Determine which images to plot
-         highest_uncovered_indices = self.uncovered_indices[:top_k]
+         selected_indices = self.uncovered_indices[:top_k]
 
-         # Grab the images
-         selected_images = to_numpy(images)[highest_uncovered_indices]
+         images = Images(images) if isinstance(images, Dataset) else images
 
          # Plot the images
-         num_images = min(top_k, len(images))
-
-         ndim = selected_images.ndim
-         if ndim == 4:
-             selected_images = np.moveaxis(selected_images, 1, -1)
-         elif ndim == 3:
-             selected_images = np.repeat(selected_images[:, :, :, np.newaxis], 3, axis=-1)
-         else:
-             raise ValueError(
-                 f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {ndim}-dimensional set of images."
-             )
+         num_images = min(top_k, len(selected_indices))
 
          rows = int(np.ceil(num_images / 3))
          fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
 
          if rows == 1:
              for j in range(3):
-                 if j >= len(selected_images):
+                 if j >= len(selected_indices):
                      continue
-                 axs[j].imshow(selected_images[j])
+                 image = channels_first_to_last(as_numpy(images[selected_indices[j]]))
+                 axs[j].imshow(image)
                  axs[j].axis("off")
          else:
              for i in range(rows):
                  for j in range(3):
                      i_j = i * 3 + j
-                     if i_j >= len(selected_images):
+                     if i_j >= len(selected_indices):
                          continue
-                     axs[i, j].imshow(selected_images[i_j])
+                     image = channels_first_to_last(as_numpy(images[selected_indices[i_j]]))
+                     axs[i, j].imshow(image)
                      axs[i, j].axis("off")
 
          fig.tight_layout()
          return fig
 
 
+ @dataclass(frozen=True)
+ class CompletenessOutput(Output):
+     """
+     Output from the completeness function.
+
+     Attributes
+     ----------
+     fraction_filled : float
+         Fraction of boxes that contain at least one data point
+     empty_box_centers : List[np.ndarray]
+         List of coordinates for centers of empty boxes
+     """
+
+     fraction_filled: float
+     empty_box_centers: NDArray[np.float64]
+
+
  @dataclass(frozen=True)
  class BalanceOutput(Output):
      """
dataeval/outputs/_stats.py CHANGED
@@ -4,7 +4,7 @@ __all__ = []
 
  import contextlib
  from dataclasses import dataclass
- from typing import Any, Iterable, Optional, Union
+ from typing import Any, Iterable, NamedTuple, Optional, Union
 
  import numpy as np
  from numpy.typing import NDArray
@@ -22,8 +22,7 @@ SOURCE_INDEX = "source_index"
  BOX_COUNT = "box_count"
 
 
- @dataclass(frozen=True)
- class SourceIndex:
+ class SourceIndex(NamedTuple):
      """
      The indices of the source image, box and channel.
 
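Moving SourceIndex from a frozen dataclass to a NamedTuple keeps attribute access while also making instances indexable, iterable, and cheaper to allocate. A minimal sketch of the effect (field names assumed from the docstring above, not confirmed by the hunk):

    from typing import NamedTuple, Optional

    class SourceIndex(NamedTuple):
        image: int
        box: Optional[int]
        channel: Optional[int]

    idx = SourceIndex(3, 0, None)
    image, box, channel = idx   # tuple unpacking now works
    print(idx.image == idx[0])  # True: attribute and positional access agree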
29
28
 
dataeval/typing.py CHANGED
@@ -23,7 +23,7 @@ __all__ = [
  import sys
  from typing import Any, Generic, Iterator, Protocol, Sequence, TypedDict, TypeVar, Union, runtime_checkable
 
- from typing_extensions import NotRequired, Required
+ from typing_extensions import NotRequired, ReadOnly, Required
 
  if sys.version_info >= (3, 10):
      from typing import TypeAlias
@@ -91,8 +91,8 @@ class DatasetMetadata(TypedDict, total=False):
          A lookup table converting label value to class name
      """
 
-     id: Required[str]
-     index2label: NotRequired[dict[int, str]]
+     id: Required[ReadOnly[str]]
+     index2label: NotRequired[ReadOnly[dict[int, str]]]
 
 
  @runtime_checkable
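
ReadOnly (PEP 705, via a recent typing_extensions) has no runtime effect; it tells static type checkers that conforming code may read these keys but not reassign them. A trimmed sketch of what a checker now flags (Meta is a hypothetical stand-in for DatasetMetadata):

    from typing_extensions import ReadOnly, Required, TypedDict

    class Meta(TypedDict, total=False):
        id: Required[ReadOnly[str]]

    meta: Meta = {"id": "mnist"}
    name = meta["id"]      # reading is fine
    meta["id"] = "other"   # rejected by type checkers (still runs at runtime)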
dataeval/utils/_array.py CHANGED
@@ -13,7 +13,7 @@ import torch
  from numpy.typing import NDArray
 
  from dataeval._log import LogMessage
- from dataeval.typing import ArrayLike
+ from dataeval.typing import Array, ArrayLike
 
  _logger = logging.getLogger(__name__)
 
@@ -167,3 +167,28 @@ def flatten(array: ArrayLike) -> NDArray[Any]:
      """
      nparr = as_numpy(array)
      return nparr.reshape((nparr.shape[0], -1))
+
+
+ _TArray = TypeVar("_TArray", bound=Array)
+
+
+ def channels_first_to_last(array: _TArray) -> _TArray:
+     """
+     Converts array from channels first to channels last format
+
+     Parameters
+     ----------
+     array : ArrayLike
+         Input array
+
+     Returns
+     -------
+     ArrayLike
+         Converted array
+     """
+     if isinstance(array, np.ndarray):
+         return np.transpose(array, (1, 2, 0))
+     elif isinstance(array, torch.Tensor):
+         return torch.permute(array, (1, 2, 0))
+     else:
+         raise TypeError(f"Unsupported array type {type(array)} for conversion.")
@@ -47,6 +47,8 @@ def _validate_data(
          or not len(bboxes[0][0]) == 4
      ):
          raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+     else:
+         raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
 
 
  def _find_max(arr: ArrayLike) -> Any: