dataeval 0.86.0__py3-none-any.whl → 0.86.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/config.py +21 -4
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_images.py +2 -3
- dataeval/data/_metadata.py +48 -37
- dataeval/data/_selection.py +1 -2
- dataeval/data/_split.py +2 -3
- dataeval/data/_targets.py +17 -13
- dataeval/data/selections/_classfilter.py +2 -5
- dataeval/data/selections/_prioritize.py +6 -9
- dataeval/data/selections/_shuffle.py +3 -1
- dataeval/detectors/drift/_base.py +4 -5
- dataeval/detectors/drift/_mmd.py +3 -6
- dataeval/detectors/drift/_nml/_base.py +4 -2
- dataeval/detectors/drift/_nml/_chunk.py +11 -19
- dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
- dataeval/detectors/drift/_nml/_result.py +8 -9
- dataeval/detectors/drift/_nml/_thresholds.py +66 -77
- dataeval/detectors/linters/outliers.py +7 -7
- dataeval/metrics/bias/_parity.py +10 -13
- dataeval/metrics/estimators/_divergence.py +2 -4
- dataeval/metrics/stats/_base.py +103 -42
- dataeval/metrics/stats/_boxratiostats.py +21 -19
- dataeval/metrics/stats/_dimensionstats.py +14 -10
- dataeval/metrics/stats/_hashstats.py +1 -1
- dataeval/metrics/stats/_pixelstats.py +6 -6
- dataeval/metrics/stats/_visualstats.py +3 -3
- dataeval/outputs/_base.py +22 -7
- dataeval/outputs/_bias.py +26 -28
- dataeval/outputs/_drift.py +1 -9
- dataeval/outputs/_linters.py +11 -11
- dataeval/outputs/_stats.py +82 -23
- dataeval/outputs/_workflows.py +2 -2
- dataeval/utils/_array.py +6 -9
- dataeval/utils/_bin.py +1 -2
- dataeval/utils/_clusterer.py +7 -4
- dataeval/utils/_fast_mst.py +27 -13
- dataeval/utils/_image.py +65 -11
- dataeval/utils/_mst.py +1 -3
- dataeval/utils/_plot.py +15 -10
- dataeval/utils/data/_dataset.py +32 -20
- dataeval/utils/data/metadata.py +104 -82
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +189 -0
- dataeval/utils/datasets/_base.py +11 -8
- dataeval/utils/datasets/_cifar10.py +104 -45
- dataeval/utils/datasets/_fileio.py +21 -47
- dataeval/utils/datasets/_milco.py +19 -11
- dataeval/utils/datasets/_mixin.py +2 -4
- dataeval/utils/datasets/_mnist.py +3 -4
- dataeval/utils/datasets/_ships.py +14 -7
- dataeval/utils/datasets/_voc.py +229 -42
- dataeval/utils/torch/models.py +5 -10
- dataeval/utils/torch/trainer.py +3 -3
- dataeval/workflows/sufficiency.py +2 -2
- {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/METADATA +1 -1
- dataeval-0.86.1.dist-info/RECORD +114 -0
- dataeval/detectors/ood/vae.py +0 -74
- dataeval-0.86.0.dist-info/RECORD +0 -114
- {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.86.0.dist-info → dataeval-0.86.1.dist-info}/WHEEL +0 -0
dataeval/utils/_clusterer.py
CHANGED
@@ -4,6 +4,7 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import warnings
|
6
6
|
from dataclasses import dataclass
|
7
|
+
from typing import Any
|
7
8
|
|
8
9
|
import numba
|
9
10
|
import numpy as np
|
@@ -30,7 +31,9 @@ from dataeval.utils._fast_mst import calculate_neighbor_distances, minimum_spann
|
|
30
31
|
|
31
32
|
|
32
33
|
@numba.njit(parallel=True, locals={"i": numba.types.int32})
|
33
|
-
def compare_links_to_cluster_std(
|
34
|
+
def compare_links_to_cluster_std(
|
35
|
+
mst: NDArray[np.float32], clusters: NDArray[np.intp]
|
36
|
+
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
|
34
37
|
cluster_ids = np.unique(clusters)
|
35
38
|
cluster_grouping = np.full(mst.shape[0], -1, dtype=np.int16)
|
36
39
|
|
@@ -79,7 +82,7 @@ def cluster(data: ArrayLike) -> ClusterData:
|
|
79
82
|
cluster_selection_epsilon = 0.0
|
80
83
|
# cluster_selection_method = "eom"
|
81
84
|
|
82
|
-
x = flatten(to_numpy(data))
|
85
|
+
x: NDArray[Any] = flatten(to_numpy(data))
|
83
86
|
samples, features = x.shape # Due to flatten(), we know shape has a length of 2
|
84
87
|
if samples < 2:
|
85
88
|
raise ValueError(f"Data should have at least 2 samples; got {samples}")
|
@@ -125,9 +128,9 @@ def cluster(data: ArrayLike) -> ClusterData:
|
|
125
128
|
return ClusterData(clusters, mst, linkage_tree, condensed_tree, membership_strengths, kneighbors, kdistances)
|
126
129
|
|
127
130
|
|
128
|
-
def sorted_union_find(index_groups):
|
131
|
+
def sorted_union_find(index_groups: NDArray[np.int32]) -> list[list[np.int32]]:
|
129
132
|
"""Merges and sorts groups of indices that share any common index"""
|
130
|
-
groups = [[np.int32(x) for x in range(0)] for y in range(0)]
|
133
|
+
groups: list[list[np.int32]] = [[np.int32(x) for x in range(0)] for y in range(0)]
|
131
134
|
uniques, inverse = np.unique(index_groups, return_inverse=True)
|
132
135
|
inverse = inverse.flatten()
|
133
136
|
disjoint_set = ds_rank_create(uniques.size)
|
dataeval/utils/_fast_mst.py
CHANGED
@@ -6,9 +6,11 @@
|
|
6
6
|
__all__ = []
|
7
7
|
|
8
8
|
import warnings
|
9
|
+
from typing import Any
|
9
10
|
|
10
11
|
import numba
|
11
12
|
import numpy as np
|
13
|
+
from numpy.typing import NDArray
|
12
14
|
from sklearn.neighbors import NearestNeighbors
|
13
15
|
|
14
16
|
with warnings.catch_warnings():
|
@@ -17,24 +19,26 @@ with warnings.catch_warnings():
|
|
17
19
|
|
18
20
|
|
19
21
|
@numba.njit()
|
20
|
-
def _ds_union_by_rank(disjoint_set, point, nbr):
|
22
|
+
def _ds_union_by_rank(disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]], point: int, nbr: int) -> int:
|
21
23
|
y = ds_find(disjoint_set, point)
|
22
24
|
x = ds_find(disjoint_set, nbr)
|
23
25
|
|
24
26
|
if x == y:
|
25
27
|
return 0
|
26
28
|
|
27
|
-
if disjoint_set
|
29
|
+
if disjoint_set[1][x] < disjoint_set[1][y]:
|
28
30
|
x, y = y, x
|
29
31
|
|
30
|
-
disjoint_set
|
31
|
-
if disjoint_set
|
32
|
-
disjoint_set
|
32
|
+
disjoint_set[0][y] = x
|
33
|
+
if disjoint_set[1][x] == disjoint_set[1][y]:
|
34
|
+
disjoint_set[1][x] += 1
|
33
35
|
return 1
|
34
36
|
|
35
37
|
|
36
38
|
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
|
37
|
-
def _init_tree(
|
39
|
+
def _init_tree(
|
40
|
+
n_neighbors: NDArray[np.intp], n_distance: NDArray[np.float32]
|
41
|
+
) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
|
38
42
|
# Initial graph to hold tree connections
|
39
43
|
tree = np.zeros((n_neighbors.size - 1, 3), dtype=np.float32)
|
40
44
|
disjoint_set = ds_rank_create(n_neighbors.size)
|
@@ -56,7 +60,13 @@ def _init_tree(n_neighbors, n_distance):
|
|
56
60
|
|
57
61
|
|
58
62
|
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
|
59
|
-
def _update_tree_by_distance(
|
63
|
+
def _update_tree_by_distance(
|
64
|
+
tree: NDArray[np.float32],
|
65
|
+
int_tree: int,
|
66
|
+
disjoint_set: tuple[NDArray[np.int32], NDArray[np.int32]],
|
67
|
+
n_neighbors: NDArray[np.uint32],
|
68
|
+
n_distance: NDArray[np.float32],
|
69
|
+
) -> tuple[NDArray[np.float32], int, tuple[NDArray[np.int32], NDArray[np.int32]], NDArray[np.uint32]]:
|
60
70
|
cluster_points = np.empty(n_neighbors.size, dtype=np.uint32)
|
61
71
|
sort_dist = np.argsort(n_distance)
|
62
72
|
dist_sorted = n_distance[sort_dist]
|
@@ -80,9 +90,9 @@ def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distan
|
|
80
90
|
|
81
91
|
|
82
92
|
@numba.njit(locals={"i": numba.types.uint32})
|
83
|
-
def _cluster_edges(tracker, last_idx, cluster_distances):
|
93
|
+
def _cluster_edges(tracker: NDArray[Any], last_idx: int, cluster_distances: NDArray[Any]) -> list[NDArray[np.intp]]:
|
84
94
|
cluster_ids = np.unique(tracker)
|
85
|
-
edge_points = []
|
95
|
+
edge_points: list[NDArray[np.intp]] = []
|
86
96
|
for idx in range(cluster_ids.size):
|
87
97
|
cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
|
88
98
|
cluster_size = cluster_points.size
|
@@ -102,14 +112,16 @@ def _cluster_edges(tracker, last_idx, cluster_distances):
|
|
102
112
|
return edge_points
|
103
113
|
|
104
114
|
|
105
|
-
def _compute_nn(dataA, dataB, k):
|
115
|
+
def _compute_nn(dataA: NDArray[Any], dataB: NDArray[Any], k: int) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
|
106
116
|
distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(dataA).kneighbors(dataB)
|
107
117
|
neighbors = np.array(neighbors[:, 1 : k + 1], dtype=np.int32)
|
108
118
|
distances = np.array(distances[:, 1 : k + 1], dtype=np.float32)
|
109
119
|
return neighbors, distances
|
110
120
|
|
111
121
|
|
112
|
-
def _calculate_cluster_neighbors(
|
122
|
+
def _calculate_cluster_neighbors(
|
123
|
+
data: NDArray[Any], groups: list[NDArray[np.intp]], point_array: NDArray[Any]
|
124
|
+
) -> tuple[NDArray[np.uint32], NDArray[np.float32]]:
|
113
125
|
"""Rerun nearest neighbor based on clusters"""
|
114
126
|
cluster_neighbors = np.zeros(point_array.size, dtype=np.uint32)
|
115
127
|
cluster_nbr_distances = np.full(point_array.size, np.inf, dtype=np.float32)
|
@@ -126,7 +138,9 @@ def _calculate_cluster_neighbors(data, groups, point_array):
|
|
126
138
|
return cluster_neighbors, cluster_nbr_distances
|
127
139
|
|
128
140
|
|
129
|
-
def minimum_spanning_tree(
|
141
|
+
def minimum_spanning_tree(
|
142
|
+
data: NDArray[Any], neighbors: NDArray[np.int32], distances: NDArray[np.float32]
|
143
|
+
) -> NDArray[np.float32]:
|
130
144
|
# Transpose arrays to get number of samples along a row
|
131
145
|
k_neighbors = neighbors.T.astype(np.uint32).copy()
|
132
146
|
k_distances = distances.T.astype(np.float32).copy()
|
@@ -168,7 +182,7 @@ def minimum_spanning_tree(data, neighbors, distances):
|
|
168
182
|
return tree
|
169
183
|
|
170
184
|
|
171
|
-
def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
|
185
|
+
def calculate_neighbor_distances(data: np.ndarray, k: int = 10) -> tuple[NDArray[np.int32], NDArray[np.float32]]:
|
172
186
|
# Have the potential to add in other distance calculations - supported calculations:
|
173
187
|
# https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
|
174
188
|
try:
|
dataeval/utils/_image.py
CHANGED
@@ -12,6 +12,9 @@ from scipy.signal import convolve2d
|
|
12
12
|
EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
|
13
13
|
BIT_DEPTH = (1, 8, 12, 16, 32)
|
14
14
|
|
15
|
+
Box = tuple[int, int, int, int]
|
16
|
+
"""Bounding box as tuple of integers in x0, y0, x1, y1 format."""
|
17
|
+
|
15
18
|
|
16
19
|
@dataclass
|
17
20
|
class BitDepth:
|
@@ -25,12 +28,11 @@ def get_bitdepth(image: NDArray[Any]) -> BitDepth:
|
|
25
28
|
Approximates the bit depth of the image using the
|
26
29
|
min and max pixel values.
|
27
30
|
"""
|
28
|
-
pmin, pmax = np.
|
31
|
+
pmin, pmax = np.nanmin(image), np.nanmax(image)
|
29
32
|
if pmin < 0:
|
30
33
|
return BitDepth(0, pmin, pmax)
|
31
|
-
|
32
|
-
|
33
|
-
return BitDepth(depth, 0, 2**depth - 1)
|
34
|
+
depth = ([x for x in BIT_DEPTH if 2**x > pmax] or [max(BIT_DEPTH)])[0]
|
35
|
+
return BitDepth(depth, 0, 2**depth - 1)
|
34
36
|
|
35
37
|
|
36
38
|
def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
|
@@ -40,9 +42,8 @@ def rescale(image: NDArray[Any], depth: int = 1) -> NDArray[Any]:
|
|
40
42
|
bitdepth = get_bitdepth(image)
|
41
43
|
if bitdepth.depth == depth:
|
42
44
|
return image
|
43
|
-
|
44
|
-
|
45
|
-
return normalized * (2**depth - 1)
|
45
|
+
normalized = (image + bitdepth.pmin) / (bitdepth.pmax - bitdepth.pmin)
|
46
|
+
return normalized * (2**depth - 1)
|
46
47
|
|
47
48
|
|
48
49
|
def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
|
@@ -52,13 +53,12 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
|
|
52
53
|
ndim = image.ndim
|
53
54
|
if ndim == 2:
|
54
55
|
return np.expand_dims(image, axis=0)
|
55
|
-
|
56
|
+
if ndim == 3:
|
56
57
|
return image
|
57
|
-
|
58
|
+
if ndim > 3:
|
58
59
|
# Slice all but the last 3 dimensions
|
59
60
|
return image[(0,) * (ndim - 3)]
|
60
|
-
|
61
|
-
raise ValueError("Images must have 2 or more dimensions.")
|
61
|
+
raise ValueError("Images must have 2 or more dimensions.")
|
62
62
|
|
63
63
|
|
64
64
|
def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
|
@@ -71,3 +71,57 @@ def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
|
|
71
71
|
edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm") + offset
|
72
72
|
np.clip(edges, 0, 255, edges)
|
73
73
|
return edges
|
74
|
+
|
75
|
+
|
76
|
+
def clip_box(image: NDArray[Any], box: Box) -> Box:
|
77
|
+
"""
|
78
|
+
Clip the box to inside the provided image dimensions.
|
79
|
+
"""
|
80
|
+
x0, y0, x1, y1 = box
|
81
|
+
h, w = image.shape[-2:]
|
82
|
+
|
83
|
+
return max(0, x0), max(0, y0), min(w, x1), min(h, y1)
|
84
|
+
|
85
|
+
|
86
|
+
def is_valid_box(box: Box) -> bool:
|
87
|
+
"""
|
88
|
+
Check if the box dimensions provided are a valid image.
|
89
|
+
"""
|
90
|
+
return box[2] > box[0] and box[3] > box[1]
|
91
|
+
|
92
|
+
|
93
|
+
def clip_and_pad(image: NDArray[Any], box: Box) -> NDArray[Any]:
|
94
|
+
"""
|
95
|
+
Extract a region from an image based on a bounding box, clipping to image boundaries
|
96
|
+
and padding out-of-bounds areas with np.nan.
|
97
|
+
|
98
|
+
Parameters:
|
99
|
+
-----------
|
100
|
+
image : NDArray[Any]
|
101
|
+
Input image array in format C, H, W (channels first)
|
102
|
+
box : Box
|
103
|
+
Bounding box coordinates as (x0, y0, x1, y1) where (x0, y0) is top-left and (x1, y1) is bottom-right
|
104
|
+
|
105
|
+
Returns:
|
106
|
+
--------
|
107
|
+
NDArray[Any]
|
108
|
+
The extracted region with out-of-bounds areas padded with np.nan
|
109
|
+
"""
|
110
|
+
|
111
|
+
# Create output array filled with NaN with a minimum size of 1x1
|
112
|
+
bw, bh = max(1, box[2] - box[0]), max(1, box[3] - box[1])
|
113
|
+
|
114
|
+
output = np.full((image.shape[-3] if image.ndim > 2 else 1, bh, bw), np.nan)
|
115
|
+
|
116
|
+
# Calculate source box
|
117
|
+
sbox = clip_box(image, box)
|
118
|
+
|
119
|
+
# Calculate destination box
|
120
|
+
x0, y0 = sbox[0] - box[0], sbox[1] - box[1]
|
121
|
+
x1, y1 = x0 + (sbox[2] - sbox[0]), y0 + (sbox[3] - sbox[1])
|
122
|
+
|
123
|
+
# Copy the source if valid from the image to the output
|
124
|
+
if is_valid_box(sbox):
|
125
|
+
output[:, y0:y1, x0:x1] = image[:, sbox[1] : sbox[3], sbox[0] : sbox[2]]
|
126
|
+
|
127
|
+
return output
|
dataeval/utils/_mst.py
CHANGED
dataeval/utils/_plot.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import contextlib
|
6
|
+
import math
|
6
7
|
from typing import Any
|
7
8
|
|
8
9
|
import numpy as np
|
@@ -160,11 +161,9 @@ def histogram_plot(
|
|
160
161
|
import matplotlib.pyplot as plt
|
161
162
|
|
162
163
|
num_metrics = len(data_dict)
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
else:
|
167
|
-
fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
|
164
|
+
rows = math.ceil(num_metrics / 3)
|
165
|
+
cols = min(num_metrics, 3)
|
166
|
+
fig, axs = plt.subplots(rows, 3, figsize=(cols * 3 + 1, rows * 3))
|
168
167
|
|
169
168
|
for ax, metric in zip(
|
170
169
|
axs.flat,
|
@@ -178,6 +177,10 @@ def histogram_plot(
|
|
178
177
|
ax.set_ylabel(ylabel)
|
179
178
|
ax.set_xlabel(xlabel)
|
180
179
|
|
180
|
+
for ax in axs.flat[num_metrics:]:
|
181
|
+
ax.axis("off")
|
182
|
+
ax.set_visible(False)
|
183
|
+
|
181
184
|
fig.tight_layout()
|
182
185
|
return fig
|
183
186
|
|
@@ -216,11 +219,9 @@ def channel_histogram_plot(
|
|
216
219
|
label_kwargs = {"label": [f"Channel {i}" for i in range(max_channels)]}
|
217
220
|
|
218
221
|
num_metrics = len(data_keys)
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
else:
|
223
|
-
fig, axs = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 4))
|
222
|
+
rows = math.ceil(num_metrics / 3)
|
223
|
+
cols = min(num_metrics, 3)
|
224
|
+
fig, axs = plt.subplots(rows, 3, figsize=(cols * 3 + 1, rows * 3))
|
224
225
|
|
225
226
|
for ax, metric in zip(
|
226
227
|
axs.flat,
|
@@ -245,5 +246,9 @@ def channel_histogram_plot(
|
|
245
246
|
ax.set_ylabel(ylabel)
|
246
247
|
ax.set_xlabel(xlabel)
|
247
248
|
|
249
|
+
for ax in axs.flat[num_metrics:]:
|
250
|
+
ax.axis("off")
|
251
|
+
ax.set_visible(False)
|
252
|
+
|
248
253
|
fig.tight_layout()
|
249
254
|
return fig
|
dataeval/utils/data/_dataset.py
CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar
|
5
|
+
from typing import Any, Generic, Iterable, Literal, Sequence, SupportsFloat, SupportsInt, TypeVar, cast
|
6
6
|
|
7
7
|
from dataeval.typing import (
|
8
8
|
Array,
|
@@ -17,8 +17,8 @@ from dataeval.utils._array import as_numpy
|
|
17
17
|
def _validate_data(
|
18
18
|
datum_type: Literal["ic", "od"],
|
19
19
|
images: Array | Sequence[Array],
|
20
|
-
labels: Sequence[int] | Sequence[Sequence[int]],
|
21
|
-
bboxes: Sequence[Sequence[Sequence[float]]] | None,
|
20
|
+
labels: Array | Sequence[int] | Sequence[Array] | Sequence[Sequence[int]],
|
21
|
+
bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]] | None,
|
22
22
|
metadata: Sequence[dict[str, Any]] | None,
|
23
23
|
) -> None:
|
24
24
|
# Validate inputs
|
@@ -34,16 +34,21 @@ def _validate_data(
|
|
34
34
|
raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")
|
35
35
|
|
36
36
|
if datum_type == "ic":
|
37
|
-
if not isinstance(labels, Sequence) or not isinstance(labels[0], int):
|
37
|
+
if not isinstance(labels, (Sequence, Array)) or not isinstance(labels[0], (int, SupportsInt)):
|
38
38
|
raise TypeError("Labels must be a sequence of integers for image classification.")
|
39
39
|
elif datum_type == "od":
|
40
|
-
if
|
40
|
+
if (
|
41
|
+
not isinstance(labels, (Sequence, Array))
|
42
|
+
or not isinstance(labels[0], (Sequence, Array))
|
43
|
+
or not isinstance(cast(Sequence[Any], labels[0])[0], (int, SupportsInt))
|
44
|
+
):
|
41
45
|
raise TypeError("Labels must be a sequence of sequences of integers for object detection.")
|
42
46
|
if (
|
43
47
|
bboxes is None
|
44
48
|
or not isinstance(bboxes, (Sequence, Array))
|
45
49
|
or not isinstance(bboxes[0], (Sequence, Array))
|
46
50
|
or not isinstance(bboxes[0][0], (Sequence, Array))
|
51
|
+
or not isinstance(bboxes[0][0][0], (float, SupportsFloat))
|
47
52
|
or not len(bboxes[0][0]) == 4
|
48
53
|
):
|
49
54
|
raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
|
@@ -52,11 +57,10 @@ def _validate_data(
|
|
52
57
|
|
53
58
|
|
54
59
|
def _find_max(arr: ArrayLike) -> Any:
|
55
|
-
if isinstance(arr, (Iterable, Sequence, Array)):
|
60
|
+
if not isinstance(arr, (bytes, str)) and isinstance(arr, (Iterable, Sequence, Array)):
|
56
61
|
if isinstance(arr[0], (Iterable, Sequence, Array)):
|
57
62
|
return max([_find_max(x) for x in arr]) # type: ignore
|
58
|
-
|
59
|
-
return max(arr)
|
63
|
+
return max(arr)
|
60
64
|
return arr
|
61
65
|
|
62
66
|
|
@@ -92,12 +96,14 @@ class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], Imag
|
|
92
96
|
def __init__(
|
93
97
|
self,
|
94
98
|
images: Array | Sequence[Array],
|
95
|
-
labels: Sequence[int],
|
99
|
+
labels: Array | Sequence[int],
|
96
100
|
metadata: Sequence[dict[str, Any]] | None,
|
97
101
|
classes: Sequence[str] | None,
|
98
102
|
name: str | None = None,
|
99
103
|
) -> None:
|
100
|
-
super().__init__(
|
104
|
+
super().__init__(
|
105
|
+
"ic", images, as_numpy(labels).tolist() if isinstance(labels, Array) else labels, metadata, classes
|
106
|
+
)
|
101
107
|
if name is not None:
|
102
108
|
self.__name__ = name
|
103
109
|
self.__class__.__name__ = name
|
@@ -135,18 +141,24 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
|
|
135
141
|
def __init__(
|
136
142
|
self,
|
137
143
|
images: Array | Sequence[Array],
|
138
|
-
labels: Sequence[Sequence[int]],
|
139
|
-
bboxes: Sequence[Sequence[Sequence[float]]],
|
144
|
+
labels: Array | Sequence[Array] | Sequence[Sequence[int]],
|
145
|
+
bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
|
140
146
|
metadata: Sequence[dict[str, Any]] | None,
|
141
147
|
classes: Sequence[str] | None,
|
142
148
|
name: str | None = None,
|
143
149
|
) -> None:
|
144
|
-
super().__init__(
|
150
|
+
super().__init__(
|
151
|
+
"od",
|
152
|
+
images,
|
153
|
+
[as_numpy(label).tolist() if isinstance(label, Array) else label for label in labels],
|
154
|
+
metadata,
|
155
|
+
classes,
|
156
|
+
)
|
145
157
|
if name is not None:
|
146
158
|
self.__name__ = name
|
147
159
|
self.__class__.__name__ = name
|
148
160
|
self.__class__.__qualname__ = name
|
149
|
-
self._bboxes = bboxes
|
161
|
+
self._bboxes = [[as_numpy(box).tolist() if isinstance(box, Array) else box for box in bbox] for bbox in bboxes]
|
150
162
|
|
151
163
|
@property
|
152
164
|
def metadata(self) -> DatasetMetadata:
|
@@ -162,7 +174,7 @@ class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]]
|
|
162
174
|
|
163
175
|
def to_image_classification_dataset(
|
164
176
|
images: Array | Sequence[Array],
|
165
|
-
labels: Sequence[int],
|
177
|
+
labels: Array | Sequence[int],
|
166
178
|
metadata: Sequence[dict[str, Any]] | None,
|
167
179
|
classes: Sequence[str] | None,
|
168
180
|
name: str | None = None,
|
@@ -174,7 +186,7 @@ def to_image_classification_dataset(
|
|
174
186
|
----------
|
175
187
|
images : Array | Sequence[Array]
|
176
188
|
The images to use in the dataset.
|
177
|
-
labels : Sequence[int]
|
189
|
+
labels : Array | Sequence[int]
|
178
190
|
The labels to use in the dataset.
|
179
191
|
metadata : Sequence[dict[str, Any]] | None
|
180
192
|
The metadata to use in the dataset.
|
@@ -191,8 +203,8 @@ def to_image_classification_dataset(
|
|
191
203
|
|
192
204
|
def to_object_detection_dataset(
|
193
205
|
images: Array | Sequence[Array],
|
194
|
-
labels: Sequence[Sequence[int]],
|
195
|
-
bboxes: Sequence[Sequence[Sequence[float]]],
|
206
|
+
labels: Array | Sequence[Array] | Sequence[Sequence[int]],
|
207
|
+
bboxes: Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]],
|
196
208
|
metadata: Sequence[dict[str, Any]] | None,
|
197
209
|
classes: Sequence[str] | None,
|
198
210
|
name: str | None = None,
|
@@ -204,9 +216,9 @@ def to_object_detection_dataset(
|
|
204
216
|
----------
|
205
217
|
images : Array | Sequence[Array]
|
206
218
|
The images to use in the dataset.
|
207
|
-
labels : Sequence[Sequence[int]]
|
219
|
+
labels : Array | Sequence[Array] | Sequence[Sequence[int]]
|
208
220
|
The labels to use in the dataset.
|
209
|
-
bboxes : Sequence[Sequence[Sequence[float]]]
|
221
|
+
bboxes : Array | Sequence[Array] | Sequence[Sequence[Array]] | Sequence[Sequence[Sequence[float]]]
|
210
222
|
The bounding boxes (x0,y0,x1,y0) to use in the dataset.
|
211
223
|
metadata : Sequence[dict[str, Any]] | None
|
212
224
|
The metadata to use in the dataset.
|