dataeval 0.86.0__py3-none-any.whl → 0.86.1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (62)
  1. dataeval/__init__.py +1 -1
  2. dataeval/_log.py +1 -1
  3. dataeval/config.py +21 -4
  4. dataeval/data/_embeddings.py +2 -2
  5. dataeval/data/_images.py +2 -3
  6. dataeval/data/_metadata.py +48 -37
  7. dataeval/data/_selection.py +1 -2
  8. dataeval/data/_split.py +2 -3
  9. dataeval/data/_targets.py +17 -13
  10. dataeval/data/selections/_classfilter.py +2 -5
  11. dataeval/data/selections/_prioritize.py +6 -9
  12. dataeval/data/selections/_shuffle.py +3 -1
  13. dataeval/detectors/drift/_base.py +4 -5
  14. dataeval/detectors/drift/_mmd.py +3 -6
  15. dataeval/detectors/drift/_nml/_base.py +4 -2
  16. dataeval/detectors/drift/_nml/_chunk.py +11 -19
  17. dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
  18. dataeval/detectors/drift/_nml/_result.py +8 -9
  19. dataeval/detectors/drift/_nml/_thresholds.py +66 -77
  20. dataeval/detectors/linters/outliers.py +7 -7
  21. dataeval/metrics/bias/_parity.py +10 -13
  22. dataeval/metrics/estimators/_divergence.py +2 -4
  23. dataeval/metrics/stats/_base.py +103 -42
  24. dataeval/metrics/stats/_boxratiostats.py +21 -19
  25. dataeval/metrics/stats/_dimensionstats.py +14 -10
  26. dataeval/metrics/stats/_hashstats.py +1 -1
  27. dataeval/metrics/stats/_pixelstats.py +6 -6
  28. dataeval/metrics/stats/_visualstats.py +3 -3
  29. dataeval/outputs/_base.py +22 -7
  30. dataeval/outputs/_bias.py +26 -28
  31. dataeval/outputs/_drift.py +1 -9
  32. dataeval/outputs/_linters.py +11 -11
  33. dataeval/outputs/_stats.py +82 -23
  34. dataeval/outputs/_workflows.py +2 -2
  35. dataeval/utils/_array.py +6 -9
  36. dataeval/utils/_bin.py +1 -2
  37. dataeval/utils/_clusterer.py +7 -4
  38. dataeval/utils/_fast_mst.py +27 -13
  39. dataeval/utils/_image.py +65 -11
  40. dataeval/utils/_mst.py +1 -3
  41. dataeval/utils/_plot.py +15 -10
  42. dataeval/utils/data/_dataset.py +32 -20
  43. dataeval/utils/data/metadata.py +104 -82
  44. dataeval/utils/datasets/__init__.py +2 -0
  45. dataeval/utils/datasets/_antiuav.py +189 -0
  46. dataeval/utils/datasets/_base.py +11 -8
  47. dataeval/utils/datasets/_cifar10.py +104 -45
  48. dataeval/utils/datasets/_fileio.py +21 -47
  49. dataeval/utils/datasets/_milco.py +19 -11
  50. dataeval/utils/datasets/_mixin.py +2 -4
  51. dataeval/utils/datasets/_mnist.py +3 -4
  52. dataeval/utils/datasets/_ships.py +14 -7
  53. dataeval/utils/datasets/_voc.py +229 -42
  54. dataeval/utils/torch/models.py +5 -10
  55. dataeval/utils/torch/trainer.py +3 -3
  56. dataeval/workflows/sufficiency.py +2 -2
  57. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +1 -1
  58. dataeval-0.86.1.dist-info/RECORD +114 -0
  59. dataeval/detectors/ood/vae.py +0 -74
  60. dataeval-0.86.0.dist-info/RECORD +0 -114
  61. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
  62. {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/utils/_clusterer.py CHANGED
@@ -4,6 +4,7 @@ __all__ = []
 
 import warnings
 from dataclasses import dataclass
+from typing import Any
 
 import numba
 import numpy as np
@@ -30,7 +31,9 @@ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spann
 
 
 @numba.njit(parallel=True, locals={"i": numba.types.int32})
-def compare_links_to_cluster_std(mst, clusters):
+def compare_links_to_cluster_std(
+    mst: NDArray[np.float32], clusters: NDArray[np.intp]
+) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
     cluster_ids = np.unique(clusters)
     cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
 
@@ -79,7 +82,7 @@ def cluster(data: ArrayLike) -> ClusterData:
     cluster_selection_epsilon = 0.0
     # cluster_selection_method = "eom"
 
-    x = flatten(to_numpy(data))
+    x: NDArray[Any] = flatten(to_numpy(data))
     samples, features = x.shape  # Due to flatten(), we know shape has a length of 2
     if samples < 2:
         raise ValueError(f"Data should have at least 2 samples; got {samples}")
@@ -125,9 +128,9 @@ def cluster(data: ArrayLike) -> ClusterData:
     return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
 
 
-def sorted_union_find(index_groups):
+def sorted_union_find(index_groups: NDArray[np.int32]) -> list[list[np.int32]]:
     """Merges and sorts groups of indices that share any common index"""
-    groups = [[np.int32(x) for x in range(0)] for y in range(0)]
+    groups: list[list[np.int32]] = [[np.int32(x) for x in range(0)] for y in range(0)]
     uniques, inverse = np.unique(index_groups, return_inverse=True)
     inverse = inverse.flatten()
     disjoint_set = ds_rank_create(uniques.size)
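
The hunk above only adds type annotations, but the behavior documented for sorted_union_find ("merges and sorts groups of indices that share any common index") is easiest to see by example. The sketch below is a standalone illustration of that documented behavior using a plain dict-based union-find; it is not the dataeval implementation, which relies on the numba-accelerated ds_rank_create/ds_find helpers referenced in this diff.

# Illustrative sketch only: merge index groups that share any index, then return
# the merged groups sorted, mirroring the docstring of sorted_union_find above.
import numpy as np

def merge_index_groups(index_groups: np.ndarray) -> list[list[int]]:
    parent = {int(i): int(i) for i in np.unique(index_groups)}

    def find(a: int) -> int:
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path compression
            a = parent[a]
        return a

    for group in index_groups:
        root = find(int(group[0]))
        for other in group[1:]:
            parent[find(int(other))] = root  # union the whole group under one root

    merged: dict[int, list[int]] = {}
    for i in parent:
        merged.setdefault(find(i), []).append(i)
    return sorted((sorted(g) for g in merged.values()), key=lambda g: g[0])

print(merge_index_groups(np.array([[0, 1], [1, 2], [3, 4]])))  # [[0, 1, 2], [3, 4]]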
dataeval/utils/_fast_mst.py CHANGED
@@ -6,9 +6,11 @@
 __all__ = []
 
 import warnings
+from typing import Any
 
 import numba
 import numpy as np
+from numpy.typing import NDArray
 from sklearn.neighbors import NearestNeighbors
 
 with warnings.catch_warnings():
@@ -17,24 +19,26 @@ with warnings.catch_warnings():
 
 
 @numba.njit()
-def _ds_union_by_rank(disjoint_set, point, nbr):
+def _ds_union_by_rank(disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]], point: int, nbr: int) -> int:
     y = ds_find(disjoint_set, point)
     x = ds_find(disjoint_set, nbr)
 
     if x == y:
         return 0
 
-    if disjoint_set.rank[x] < disjoint_set.rank[y]:
+    if disjoint_set[1][x] < disjoint_set[1][y]:
         x, y = y, x
 
-    disjoint_set.parent[y] = x
-    if disjoint_set.rank[x] == disjoint_set.rank[y]:
-        disjoint_set.rank[x] += 1
+    disjoint_set[0][y] = x
+    if disjoint_set[1][x] == disjoint_set[1][y]:
+        disjoint_set[1][x] += 1
     return 1
 
 
 @numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
-def _init_tree(n_neighbors, n_distance):
+def _init_tree(
+    n_neighbors: NDArray[np.intp], n_distance: NDArray[np.float32]
+) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
     # Initial graph to hold tree connections
     tree = np.zeros((n_neighbors.size - 1, 3), dtype=np.float32)
     disjoint_set = ds_rank_create(n_neighbors.size)
@@ -56,7 +60,13 @@ def _init_tree(n_neighbors, n_distance):
 
 
 @numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
-def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distance):
+def _update_tree_by_distance(
+    tree: NDArray[np.float32],
+    int_tree: int,
+    disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]],
+    n_neighbors: NDArray[np.uint32],
+    n_distance: NDArray[np.float32],
+) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
     cluster_points = np.empty(n_neighbors.size, dtype=np.uint32)
     sort_dist = np.argsort(n_distance)
     dist_sorted = n_distance[sort_dist]
@@ -80,9 +90,9 @@ def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distan
 
 
 @numba.njit(locals={"i": numba.types.uint32})
-def _cluster_edges(tracker, last_idx, cluster_distances):
+def _cluster_edges(tracker: NDArray[Any], last_idx: int, cluster_distances: NDArray[Any]) -> list[NDArray[np.intp]]:
     cluster_ids = np.unique(tracker)
-    edge_points = []
+    edge_points: list[NDArray[np.intp]] = []
     for idx in range(cluster_ids.size):
         cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
         cluster_size = cluster_points.size
@@ -102,14 +112,16 @@ def _cluster_edges(tracker, last_idx, cluster_distances):
     return edge_points
 
 
-def _compute_nn(dataA, dataB, k):
+def _compute_nn(dataA: NDArray[Any], dataB: NDArray[Any], k: int) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
     distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(dataA).kneighbors(dataB)
     neighbors = np.array(neighbors[:, 1 : k + 1], dtype=np.int32)
     distances = np.array(distances[:, 1 : k + 1], dtype=np.float32)
     return neighbors, distances
 
 
-def _calculate_cluster_neighbors(data, groups, point_array):
+def _calculate_cluster_neighbors(
+    data: NDArray[Any], groups: list[NDArray[np.intp]], point_array: NDArray[Any]
+) -> tuple[NDArray[np.uint32], NDArray[np.float32]]:
     """Rerun nearest neighbor based on clusters"""
     cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
     cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)
@@ -126,7 +138,9 @@ def _calculate_cluster_neighbors(data, groups, point_array):
     return cluster_neighbors, cluster_nbr_distances
 
 
-def minimum_spanning_tree(data, neighbors, distances):
+def minimum_spanning_tree(
+    data: NDArray[Any], neighbors: NDArray[np.int32], distances: NDArray[np.float32]
+) -> NDArray[np.float32]:
     # Transpose arrays to get number of samples along a row
     k_neighbors = neighbors.T.astype(np.uint32).copy()
     k_distances = distances.T.astype(np.float32).copy()
@@ -168,7 +182,7 @@ def minimum_spanning_tree(data, neighbors, distances):
     return tree
 
 
-def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
+def calculate_neighbor_distances(data: np.ndarray, k: int = 10) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
     # Have the potential to add in other distance calculations - supported calculations:
     # https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
     try:
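
The changes above replace attribute access on the disjoint set (disjoint_set.parent, disjoint_set.rank) with tuple indexing (disjoint_set[0], disjoint_set[1]) and annotate it as a pair of int32 arrays. The standalone sketch below shows that (parent, rank) representation with union by rank and path halving; the helper names are hypothetical, and it mirrors rather than reuses the _ds_union_by_rank logic in this diff.

# Sketch of a (parent, rank) disjoint set stored as a tuple of arrays, matching the
# tuple[NDArray[np.int32], NDArray[np.int32]] annotation introduced above.
import numpy as np

def ds_create(n: int) -> tuple[np.ndarray, np.ndarray]:
    # index 0: parent array, index 1: rank array
    return np.arange(n, dtype=np.int32), np.zeros(n, dtype=np.int32)

def ds_find(ds: tuple[np.ndarray, np.ndarray], i: int) -> int:
    while ds[0][i] != i:
        ds[0][i] = ds[0][ds[0][i]]  # path halving
        i = int(ds[0][i])
    return int(i)

def ds_union_by_rank(ds: tuple[np.ndarray, np.ndarray], a: int, b: int) -> int:
    x, y = ds_find(ds, a), ds_find(ds, b)
    if x == y:
        return 0
    if ds[1][x] < ds[1][y]:
        x, y = y, x
    ds[0][y] = x  # attach the lower-rank root under the higher-rank root
    if ds[1][x] == ds[1][y]:
        ds[1][x] += 1
    return 1

ds = ds_create(4)
ds_union_by_rank(ds, 0, 1)
ds_union_by_rank(ds, 2, 3)
print(ds_find(ds, 1) == ds_find(ds, 0), ds_find(ds, 1) == ds_find(ds, 3))  # True False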
dataeval/utils/_image.py CHANGED
@@ -12,6 +12,9 @@ from scipy.signal import convolve2d
 EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
 BIT_DEPTH = (1, 8, 12, 16, 32)
 
+Box = tuple[int, int, int, int]
+"""Bounding box as tuple of integers in x0, y0, x1, y1 format."""
+
 
 @dataclass
 class BitDepth:
@@ -25,12 +28,11 @@ def get_bitdepth(image: NDArray[Any]) -> BitDepth:
     Approximates the bit depth of the image using the
     min and max pixel values.
     """
-    pmin, pmax = np.min(image), np.max(image)
+    pmin, pmax = np.nanmin(image), np.nanmax(image)
     if pmin < 0:
         return BitDepth(0, pmin, pmax)
-    else:
-        depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
-        return BitDepth(depth, 0, 2**depth - 1)
+    depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
+    return BitDepth(depth, 0, 2**depth - 1)
 
 
 def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
@@ -40,9 +42,8 @@ def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
     bitdepth = get_bitdepth(image)
     if bitdepth.depth == depth:
         return image
-    else:
-        normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
-        return normalized * (2**depth - 1)
+    normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
+    return normalized * (2**depth - 1)
 
 
 def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
@@ -52,13 +53,12 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
     ndim = image.ndim
     if ndim == 2:
         return np.expand_dims(image, axis=0)
-    elif ndim == 3:
+    if ndim == 3:
        return image
-    elif ndim > 3:
+    if ndim > 3:
        # Slice all but the last 3 dimensions
        return image[(0,) * (ndim - 3)]
-    else:
-        raise ValueError("Images must have 2 or more dimensions.")
+    raise ValueError("Images must have 2 or more dimensions.")
 
 
 def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
@@ -71,3 +71,57 @@ def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
     edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
     np.clip(edges, 0, 255, edges)
     return edges
+
+
+def clip_box(image: NDArray[Any], box: Box) -> Box:
+    """
+    Clip the box to inside the provided image dimensions.
+    """
+    x0, y0, x1, y1 = box
+    h, w = image.shape[-2:]
+
+    return max(0, x0), max(0, y0), min(w, x1), min(h, y1)
+
+
+def is_valid_box(box: Box) -> bool:
+    """
+    Check if the box dimensions provided are a valid image.
+    """
+    return box[2] > box[0] and box[3] > box[1]
+
+
+def clip_and_pad(image: NDArray[Any], box: Box) -> NDArray[Any]:
+    """
+    Extract a region from an image based on a bounding box, clipping to image boundaries
+    and padding out-of-bounds areas with np.nan.
+
+    Parameters:
+    -----------
+    image : NDArray[Any]
+        Input image array in format C, H, W (channels first)
+    box : Box
+        Bounding box coordinates as (x0, y0, x1, y1) where (x0, y0) is top-left and (x1, y1) is bottom-right
+
+    Returns:
+    --------
+    NDArray[Any]
+        The extracted region with out-of-bounds areas padded with np.nan
+    """
+
+    # Create output array filled with NaN with a minimum size of 1x1
+    bw, bh = max(1, box[2] - box[0]), max(1, box[3] - box[1])
+
+    output = np.full((image.shape[-3] if image.ndim > 2 else 1, bh, bw), np.nan)
+
+    # Calculate source box
+    sbox = clip_box(image, box)
+
+    # Calculate destination box
+    x0, y0 = sbox[0] - box[0], sbox[1] - box[1]
+    x1, y1 = x0 + (sbox[2] - sbox[0]), y0 + (sbox[3] - sbox[1])
+
+    # Copy the source if valid from the image to the output
+    if is_valid_box(sbox):
+        output[:, y0:y1, x0:x1] = image[:, sbox[1] : sbox[3], sbox[0] : sbox[2]]
+
+    return output
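
A short usage sketch for the new Box helpers follows. Importing from the private module dataeval.utils._image is an assumption made purely for illustration; these helpers are internal utilities rather than public API.

import numpy as np
from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box

image = np.arange(2 * 4 * 4, dtype=np.float64).reshape(2, 4, 4)  # C, H, W

# Box extends 2 pixels past the right and bottom edges: the in-bounds 2x2 region is
# copied and the out-of-bounds remainder stays np.nan.
patch = clip_and_pad(image, (2, 2, 6, 6))
print(patch.shape)                    # (2, 4, 4)
print(int(np.isnan(patch).sum()))     # 24 -> 2 channels * (16 - 4 in-bounds pixels)
print(clip_box(image, (2, 2, 6, 6)))  # (2, 2, 4, 4)
print(is_valid_box((2, 2, 2, 6)))     # False (zero-width box)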
dataeval/utils/_mst.py CHANGED
@@ -83,6 +83,4 @@ def compute_neighbors(
 
     nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=algorithm).fit(B)
     nns = nbrs.kneighbors(A)[1]
-    nns = nns[:, 1:].squeeze()
-
-    return nns
+    return nns[:, 1:].squeeze()
dataeval/utils/_plot.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import contextlib
+import math
 from typing import Any
 
 import numpy as np
@@ -160,11 +161,9 @@ def histogram_plot(
     import matplotlib.pyplot as plt
 
     num_metrics = len(data_dict)
-    if num_metrics > 2:
-        rows = int(len(data_dict) / 3)
-        fig, axs = plt.subplots(rows, 3, figsize=(10, rows * 2.5))
-    else:
-        fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
+    rows = math.ceil(num_metrics / 3)
+    cols = min(num_metrics, 3)
+    fig, axs = plt.subplots(rows, 3, figsize=(cols * 3 + 1, rows * 3))
 
     for ax, metric in zip(
         axs.flat,
@@ -178,6 +177,10 @@ def histogram_plot(
         ax.set_ylabel(ylabel)
         ax.set_xlabel(xlabel)
 
+    for ax in axs.flat[num_metrics:]:
+        ax.axis("off")
+        ax.set_visible(False)
+
     fig.tight_layout()
     return fig
 
@@ -216,11 +219,9 @@ def channel_histogram_plot(
     label_kwargs = {"label": [f"Channel {i}" for i in range(max_channels)]}
 
     num_metrics = len(data_keys)
-    if num_metrics > 2:
-        rows = int(len(data_keys) / 3)
-        fig, axs = plt.subplots(rows, 3, figsize=(10, rows * 2.5))
-    else:
-        fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
+    rows = math.ceil(num_metrics / 3)
+    cols = min(num_metrics, 3)
+    fig, axs = plt.subplots(rows, 3, figsize=(cols * 3 + 1, rows * 3))
 
     for ax, metric in zip(
         axs.flat,
@@ -245,5 +246,9 @@ def channel_histogram_plot(
         ax.set_ylabel(ylabel)
         ax.set_xlabel(xlabel)
 
+    for ax in axs.flat[num_metrics:]:
+        ax.axis("off")
+        ax.set_visible(False)
+
     fig.tight_layout()
     return fig
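
Both plotting helpers now share the same layout rule: always three columns, math.ceil(num_metrics / 3) rows, and any surplus axes switched off so a metric count that is not a multiple of three no longer leaves empty frames. The standalone sketch below reproduces that pattern with made-up data; it is an illustration of the layout logic, not a call into dataeval.

import math

import matplotlib.pyplot as plt
import numpy as np

data = {f"metric_{i}": np.random.default_rng(i).normal(size=100) for i in range(4)}

num_metrics = len(data)
rows = math.ceil(num_metrics / 3)  # 4 metrics -> 2 rows
cols = min(num_metrics, 3)         # figure width scales with up to 3 columns
fig, axs = plt.subplots(rows, 3, figsize=(cols * 3 + 1, rows * 3))

for ax, (name, values) in zip(axs.flat, data.items()):
    ax.hist(values, bins=20)
    ax.set_title(name)

for ax in axs.flat[num_metrics:]:  # turn off the 2 unused axes in the 2x3 grid
    ax.axis("off")
    ax.set_visible(False)

fig.tight_layout()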
dataeval/utils/data/_dataset.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar
+from typing import Any, Generic, Iterable, Literal, Sequence, SupportsFloat, SupportsInt, TypeVar, cast
 
 from dataeval.typing import (
     Array,
@@ -17,8 +17,8 @@ from dataeval.utils._array import as_numpy
 def _validate_data(
     datum_type: Literal["ic", "od"],
     images: Array | Sequence[Array],
-    labels: Sequence[int] | Sequence[Sequence[int]],
-    bboxes: Sequence[Sequence[Sequence[float]]] | None,
+    labels: Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]],
+    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]] | None,
     metadata: Sequence[dict[str, Any]] | None,
 ) -> None:
     # Validate inputs
@@ -34,16 +34,21 @@
         raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")
 
     if datum_type == "ic":
-        if not isinstance(labels, Sequence) or not isinstance(labels[0], int):
+        if not isinstance(labels, (Sequence, Array)) or not isinstance(labels[0], (int, SupportsInt)):
            raise TypeError("Labels must be a sequence of integers for image classification.")
     elif datum_type == "od":
-        if not isinstance(labels, Sequence) or not isinstance(labels[0], Sequence) or not isinstance(labels[0][0], int):
+        if (
+            not isinstance(labels, (Sequence, Array))
+            or not isinstance(labels[0], (Sequence, Array))
+            or not isinstance(cast(Sequence[Any], labels[0])[0], (int, SupportsInt))
+        ):
            raise TypeError("Labels must be a sequence of sequences of integers for object detection.")
        if (
            bboxes is None
            or not isinstance(bboxes, (Sequence, Array))
            or not isinstance(bboxes[0], (Sequence, Array))
            or not isinstance(bboxes[0][0], (Sequence, Array))
+            or not isinstance(bboxes[0][0][0], (float, SupportsFloat))
            or not len(bboxes[0][0]) == 4
        ):
            raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
@@ -52,11 +57,10 @@
 
 
 def _find_max(arr: ArrayLike) -> Any:
-    if isinstance(arr, (Iterable, Sequence, Array)):
+    if not isinstance(arr, (bytes, str)) and isinstance(arr, (Iterable, Sequence, Array)):
        if isinstance(arr[0], (Iterable, Sequence, Array)):
            return max([_find_max(x) for x in arr])  # type: ignore
-        else:
-            return max(arr)
+        return max(arr)
     return arr
 
 
@@ -92,12 +96,14 @@ class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], Imag
     def __init__(
         self,
         images: Array | Sequence[Array],
-        labels: Sequence[int],
+        labels: Array | Sequence[int],
         metadata: Sequence[dict[str, Any]] | None,
         classes: Sequence[str] | None,
         name: str | None = None,
     ) -> None:
-        super().__init__("ic", images, labels, metadata, classes)
+        super().__init__(
+            "ic", images, as_numpy(labels).tolist() if isinstance(labels, Array) else labels, metadata, classes
+        )
        if name is not None:
            self.__name__ = name
            self.__class__.__name__ = name
@@ -135,18 +141,24 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
     def __init__(
         self,
         images: Array | Sequence[Array],
-        labels: Sequence[Sequence[int]],
-        bboxes: Sequence[Sequence[Sequence[float]]],
+        labels: Array | Sequence[Array] | Sequence[Sequence[int]],
+        bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
         metadata: Sequence[dict[str, Any]] | None,
         classes: Sequence[str] | None,
         name: str | None = None,
     ) -> None:
-        super().__init__("od", images, labels, metadata, classes)
+        super().__init__(
+            "od",
+            images,
+            [as_numpy(label).tolist() if isinstance(label, Array) else label for label in labels],
+            metadata,
+            classes,
+        )
        if name is not None:
            self.__name__ = name
            self.__class__.__name__ = name
            self.__class__.__qualname__ = name
-        self._bboxes = bboxes
+        self._bboxes = [[as_numpy(box).tolist() if isinstance(box, Array) else box for box in bbox] for bbox in bboxes]
 
     @property
     def metadata(self) -> DatasetMetadata:
@@ -162,7 +174,7 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
 
 def to_image_classification_dataset(
     images: Array | Sequence[Array],
-    labels: Sequence[int],
+    labels: Array | Sequence[int],
     metadata: Sequence[dict[str, Any]] | None,
     classes: Sequence[str] | None,
     name: str | None = None,
@@ -174,7 +186,7 @@ def to_image_classification_dataset(
     ----------
     images : Array | Sequence[Array]
        The images to use in the dataset.
-    labels : Sequence[int]
+    labels : Array | Sequence[int]
        The labels to use in the dataset.
     metadata : Sequence[dict[str, Any]] | None
        The metadata to use in the dataset.
@@ -191,8 +203,8 @@
 
 def to_object_detection_dataset(
     images: Array | Sequence[Array],
-    labels: Sequence[Sequence[int]],
-    bboxes: Sequence[Sequence[Sequence[float]]],
+    labels: Array | Sequence[Array] | Sequence[Sequence[int]],
+    bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
     metadata: Sequence[dict[str, Any]] | None,
     classes: Sequence[str] | None,
     name: str | None = None,
@@ -204,9 +216,9 @@
     ----------
     images : Array | Sequence[Array]
        The images to use in the dataset.
-    labels : Sequence[Sequence[int]]
+    labels : Array | Sequence[Array] | Sequence[Sequence[int]]
        The labels to use in the dataset.
-    bboxes : Sequence[Sequence[Sequence[float]]]
+    bboxes : Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]]
        The bounding boxes (x0,y0,x1,y0) to use in the dataset.
     metadata : Sequence[dict[str, Any]] | None
        The metadata to use in the dataset.
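
The widened signatures above mean that to_image_classification_dataset and to_object_detection_dataset now accept array inputs (anything satisfying dataeval's Array protocol, such as NumPy arrays) for labels and bounding boxes, not just nested Python sequences. The sketch below illustrates this for object detection; importing from the private module path shown in this diff is an assumption made for illustration only.

import numpy as np
from dataeval.utils.data._dataset import to_object_detection_dataset

images = [np.zeros((3, 16, 16), dtype=np.float32) for _ in range(2)]
labels = [np.array([0, 1]), np.array([1])]  # per-image label arrays instead of lists of ints
bboxes = [
    np.array([[0.0, 0.0, 4.0, 4.0], [5.0, 5.0, 9.0, 9.0]]),
    np.array([[2.0, 2.0, 8.0, 8.0]]),
]  # per-image (x0, y0, x1, y1) boxes as float arrays

dataset = to_object_detection_dataset(images, labels, bboxes, metadata=None, classes=["cat", "dog"])
print(dataset.metadata)  # DatasetMetadata for the constructed dataset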