dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/{output.py → _output.py} +14 -0
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +41 -30
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
- dataeval/detectors/drift/updates.py +1 -1
- dataeval/detectors/linters/__init__.py +0 -3
- dataeval/detectors/linters/duplicates.py +17 -8
- dataeval/detectors/linters/outliers.py +23 -14
- dataeval/detectors/ood/ae.py +29 -8
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/metadata_ks_compare.py +1 -1
- dataeval/detectors/ood/mixin.py +20 -5
- dataeval/detectors/ood/output.py +1 -1
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +5 -0
- dataeval/metadata/_ood.py +238 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
- dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
- dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
- dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
- dataeval/metrics/estimators/__init__.py +14 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
- dataeval/metrics/estimators/_clusterer.py +104 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/metrics/stats/{base.py → _base.py} +52 -16
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
- dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
- dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
- dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
- dataeval/typing.py +54 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +18 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +4 -4
- dataeval/utils/data/__init__.py +22 -0
- dataeval/utils/data/_embeddings.py +105 -0
- dataeval/utils/data/_images.py +65 -0
- dataeval/utils/data/_metadata.py +352 -0
- dataeval/utils/data/_selection.py +119 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
- dataeval/utils/data/_targets.py +73 -0
- dataeval/utils/data/_types.py +58 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +60 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/sufficiency.py +10 -9
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
- dataeval-0.81.0.dist-info/RECORD +94 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,189 @@
|
|
1
|
+
# Adapted from fast_hdbscan python module
|
2
|
+
# Original Authors: Leland McInnes <https://github.com/TutteInstitute/fast_hdbscan>
|
3
|
+
# Adapted for DataEval by Ryan Wood
|
4
|
+
# License: BSD 2-Clause
|
5
|
+
|
6
|
+
__all__ = []
|
7
|
+
|
8
|
+
import warnings
|
9
|
+
|
10
|
+
import numba
|
11
|
+
import numpy as np
|
12
|
+
from sklearn.neighbors import NearestNeighbors
|
13
|
+
|
14
|
+
with warnings.catch_warnings():
|
15
|
+
warnings.simplefilter("ignore", category=FutureWarning)
|
16
|
+
from fast_hdbscan.disjoint_set import ds_find, ds_rank_create
|
17
|
+
|
18
|
+
|
19
|
+
@numba.njit()
def _ds_union_by_rank(disjoint_set, point, nbr):
    """Merge the sets containing ``point`` and ``nbr`` using union by rank.

    Returns 1 when two distinct sets were joined (a new tree edge) and 0 when
    both elements already shared a root.
    """
    absorbed = ds_find(disjoint_set, point)
    keeper = ds_find(disjoint_set, nbr)
    if absorbed == keeper:
        return 0

    rank = disjoint_set.rank
    # The higher-ranked root stays on top; on a tie the root of ``nbr`` is kept
    if rank[keeper] < rank[absorbed]:
        absorbed, keeper = keeper, absorbed

    disjoint_set.parent[absorbed] = keeper
    if rank[keeper] == rank[absorbed]:
        rank[keeper] += 1
    return 1
|
34
|
+
|
35
|
+
|
36
|
+
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32, "dist": numba.types.float32})
def _init_tree(n_neighbors, n_distance):
    """Seed the spanning tree from one neighbor per point.

    Each point ``i`` is linked to ``n_neighbors[i]`` at distance ``n_distance[i]``
    whenever the two are not already connected. Points are visited in index order.

    Returns
    -------
    tree : float32 array of shape (n_points - 1, 3)
        Rows of (point, neighbor, distance); only the first ``edge_count`` rows are filled.
    edge_count : int
        Number of tree rows written so far.
    disjoint_set
        Union-find structure describing the current connected components.
    cluster_points : uint32 array of shape (n_points,)
        Component root of every point.
    """
    # NOTE: the names i / nbr / dist are pinned to fixed numba types via `locals`
    n_points = n_neighbors.size
    tree = np.zeros((n_points - 1, 3), dtype=np.float32)
    disjoint_set = ds_rank_create(n_points)

    edge_count = 0
    for i in range(n_points):
        nbr = n_neighbors[i]
        if _ds_union_by_rank(disjoint_set, i, nbr) == 1:
            dist = n_distance[i]
            tree[edge_count] = (np.float32(i), np.float32(nbr), dist)
            edge_count += 1

    cluster_points = np.empty(n_points, dtype=np.uint32)
    for i in range(n_points):
        cluster_points[i] = ds_find(disjoint_set, i)

    return tree, edge_count, disjoint_set, cluster_points
|
56
|
+
|
57
|
+
|
58
|
+
@numba.njit(locals={"i": numba.types.uint32, "nbr": numba.types.uint32})
def _update_tree_by_distance(tree, int_tree, disjoint_set, n_neighbors, n_distance):
    """Add the candidate edges (point -> n_neighbors[point]) to the tree, shortest first.

    Candidates are visited in ascending ``n_distance`` order (Kruskal-style). An
    edge is kept only when it joins two different components.

    Returns
    -------
    tree, int_tree, disjoint_set
        The updated tree rows, the new count of filled rows, and the union-find state.
    cluster_points : uint32 array of shape (n_points,)
        Component root of every point after this pass.
    """
    # NOTE: the names i / nbr are pinned to fixed numba types via `locals`
    n_points = n_neighbors.size
    order = np.argsort(n_distance)
    dist_sorted = n_distance[order]
    nbrs_sorted = n_neighbors[order]
    point_sorted = np.arange(n_points)[order]

    for i in range(n_points):
        point = point_sorted[i]
        nbr = nbrs_sorted[i]
        if _ds_union_by_rank(disjoint_set, point, nbr) == 1:
            tree[int_tree] = (np.float32(point), np.float32(nbr), dist_sorted[i])
            int_tree += 1

    cluster_points = np.empty(n_points, dtype=np.uint32)
    for i in range(n_points):
        cluster_points[i] = ds_find(disjoint_set, i)

    return tree, int_tree, disjoint_set, cluster_points
|
80
|
+
|
81
|
+
|
82
|
+
@numba.njit(locals={"i": numba.types.uint32})
def _cluster_edges(tracker, last_idx, cluster_distances):
    """Pick, for each cluster, the points used to search for links to other clusters.

    Parameters
    ----------
    tracker : array of shape (n_points,)
        Cluster (component root) id of every point.
    last_idx : int
        Rows ``0..last_idx`` of ``cluster_distances`` are used.
    cluster_distances : 2D array
        Indexed as [row, point]. As passed by ``minimum_spanning_tree``,
        presumably one row per neighbor rank (TODO confirm against caller).

    Returns
    -------
    list of index arrays
        One array per cluster, in ``np.unique(tracker)`` order. Each array holds
        either the cluster's "far" points or, when too few qualify, all of its points.
    """
    cluster_ids = np.unique(tracker)
    edge_points = []
    for idx in range(cluster_ids.size):
        cluster_points = np.nonzero(tracker == cluster_ids[idx])[0]
        cluster_size = cluster_points.size
        # Distance statistics pooled over every member of the cluster
        cluster_mean = cluster_distances[: last_idx + 1, cluster_points].mean()
        cluster_std = cluster_distances[: last_idx + 1, cluster_points].std()
        threshold = cluster_mean + cluster_std
        # Mean distance of each individual member
        points_mean = np.empty_like(cluster_points, dtype=np.float32)
        for i in range(cluster_size):
            points_mean[i] = cluster_distances[: last_idx + 1, cluster_points[i]].mean()
        # Candidates: members whose mean distance exceeds the cluster mean + 1 std
        pts_to_add = cluster_points[np.nonzero(points_mean > threshold)[0]]
        # `threshold` is reused as a minimum candidate count: 1% of the cluster when it
        # has 1000+ points (floor(log10) > 2), otherwise 10%, and never below 10
        threshold = int(cluster_size * 0.01) if np.floor(np.log10(cluster_size)) > 2 else int(cluster_size * 0.1)
        threshold = max(10, threshold)
        if pts_to_add.size > threshold:
            edge_points.append(pts_to_add)
        else:
            # Too few candidates: fall back to every point of the cluster
            edge_points.append(cluster_points)
    return edge_points
|
103
|
+
|
104
|
+
|
105
|
+
def _compute_nn(dataA, dataB, k):
    """Find neighbors in ``dataA`` for every row of ``dataB``, skipping the closest one.

    ``k + 1`` neighbors are requested by brute force and column 0 is discarded.
    That column is the self-match when ``dataB`` rows are also present in ``dataA``.

    Returns
    -------
    neighbors : int32 array of shape (len(dataB), k)
        Row indices into ``dataA`` (2nd through (k+1)th nearest).
    distances : float32 array of shape (len(dataB), k)
        Matching distances.
    """
    knn = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(dataA)
    raw_distances, raw_neighbors = knn.kneighbors(dataB)
    kept = slice(1, k + 1)
    return raw_neighbors[:, kept].astype(np.int32), raw_distances[:, kept].astype(np.float32)
|
110
|
+
|
111
|
+
|
112
|
+
def _calculate_cluster_neighbors(data, groups, point_array):
    """Rerun nearest neighbor based on clusters.

    For every point in ``groups[g]``, search the points of all *other* groups and
    record one of them as its cross-cluster neighbor.

    Parameters
    ----------
    data : array indexed by point
        Feature vectors.
    groups : list of index arrays
        One array of point indices per cluster.
    point_array : array
        Only its size is used, to size the outputs.

    Returns
    -------
    cluster_neighbors : uint32 array of shape (point_array.size,)
    cluster_nbr_distances : float32 array of shape (point_array.size,)
        Points not in any group keep neighbor 0 and distance ``inf``.
    """
    n_points = point_array.size
    neighbors_out = np.zeros(n_points, dtype=np.uint32)
    distances_out = np.full(n_points, np.inf, dtype=np.float32)

    for idx, own in enumerate(groups):
        others = np.concatenate([grp for j, grp in enumerate(groups) if j != idx])
        # NOTE(review): the fitted set (others) and the query set (own) are disjoint.
        # So the column dropped by _compute_nn is not a self-match, and combined with
        # taking column 1 below, each point is linked to its 3rd-nearest point in the
        # other clusters rather than its nearest. This also needs len(others) >= 3.
        # Confirm this is intended.
        nbr_idx, nbr_dist = _compute_nn(data[others], data[own], 2)
        neighbors_out[own] = others[nbr_idx[:, 1]]
        distances_out[own] = nbr_dist[:, 1]

    return neighbors_out, distances_out
|
127
|
+
|
128
|
+
|
129
|
+
def minimum_spanning_tree(data, neighbors, distances):
    """Build a spanning tree over ``data`` from a k-nearest-neighbor graph.

    Edges are added one neighbor column at a time:
    - column 0 in point order;
    - each later column in ascending-distance order.
    Points are joined only when they are in different components. Because of this
    column-by-column acceptance, the result is not guaranteed to be an exact
    minimum spanning tree of the full data.

    If several components remain, extra cross-cluster candidate edges are computed
    from the cluster "edge" points and added in one final pass.

    Parameters
    ----------
    data : array of shape (n_points, n_features)
        Feature vectors; only used for the cross-cluster search.
    neighbors, distances : arrays of shape (n_points, k)
        Neighbor indices and distances per point. Presumably ordered nearest-first,
        as returned by ``calculate_neighbor_distances`` (TODO confirm for other callers).

    Returns
    -------
    tree : float32 array of shape (n_points - 1, 3)
        Rows of (point, neighbor, distance) in the order the edges were accepted.
    """
    # Transpose so each row is one neighbor column and each column is one sample
    k_neighbors = neighbors.T.astype(np.uint32).copy()
    k_distances = distances.T.astype(np.float32).copy()

    # Row r holds every point's component root after processing neighbor row r.
    # The extra (+1) row leaves room for the final cross-cluster pass.
    merge_tracker = np.full((k_neighbors.shape[0] + 1, k_neighbors.shape[1]), -1, dtype=np.int32)

    # Initialize tree
    tree, int_tree, tree_disjoint_set, merge_tracker[0] = _init_tree(k_neighbors[0], k_distances[0])

    # Loop through all of the neighbors, updating the tree
    last_idx = 0
    for i in range(1, k_neighbors.shape[0]):
        tree, int_tree, tree_disjoint_set, merge_tracker[i] = _update_tree_by_distance(
            tree, int_tree, tree_disjoint_set, k_neighbors[i], k_distances[i]
        )
        last_idx = i
        # No roots changed at this row (no merges): treat the previous row as final and stop
        if (merge_tracker[i] == merge_tracker[i - 1]).all():
            last_idx -= 1
            break

    # Identify final clusters
    cluster_ids = np.unique(merge_tracker[last_idx])
    if cluster_ids.size > 1:
        # Determining the edge points
        edge_points = _cluster_edges(merge_tracker[last_idx], last_idx, k_distances)

        # Run nearest neighbor again between clusters to reach single cluster
        additional_neighbors, additional_distances = _calculate_cluster_neighbors(
            data, edge_points, merge_tracker[last_idx]
        )

        # Update clusters; writes the next tracker row (the spare row if the loop ran to the end)
        last_idx += 1
        tree, int_tree, tree_disjoint_set, merge_tracker[last_idx] = _update_tree_by_distance(
            tree, int_tree, tree_disjoint_set, additional_neighbors, additional_distances
        )

    return tree
|
169
|
+
|
170
|
+
|
171
|
+
def calculate_neighbor_distances(data: np.ndarray, k: int = 10):
    """Compute the ``k`` nearest neighbors of every row of ``data``, excluding the row itself.

    Uses ``pynndescent.NNDescent`` (approximate) when it can be imported. Otherwise
    falls back to an exact brute-force ``sklearn.neighbors.NearestNeighbors`` search.
    Distances are Euclidean.

    Parameters
    ----------
    data : np.ndarray
        Samples along the first axis.
    k : int, default 10
        Number of neighbors to return per sample.

    Returns
    -------
    neighbors : np.ndarray of int32, shape (n_samples, k)
    distances : np.ndarray of float32, shape (n_samples, k)
    """
    # Have the potential to add in other distance calculations - supported calculations:
    # https://github.com/lmcinnes/pynndescent/blob/master/pynndescent/pynndescent_.py#L524
    try:
        from pynndescent import NNDescent
    except ImportError:
        # k + 1 because each query point is returned as its own nearest neighbor
        distances, neighbors = NearestNeighbors(n_neighbors=k + 1, algorithm="brute").fit(data).kneighbors(data)
    else:
        # Only the import is guarded: errors raised while building the index now
        # surface instead of silently falling back to brute force.
        # The graph is built wider than k, presumably to improve recall of the first k neighbors.
        max_descent = 30 if k <= 20 else k + 16
        index = NNDescent(
            data,
            metric="euclidean",
            n_neighbors=max_descent,
        )
        neighbors, distances = index.neighbor_graph

    # Drop column 0 (normally the query point itself) and keep the next k neighbors
    neighbors = np.array(neighbors[:, 1 : k + 1], dtype=np.int32)
    distances = np.array(distances[:, 1 : k + 1], dtype=np.float32)
    return neighbors, distances
|
@@ -2,17 +2,19 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from typing import Any
|
6
7
|
|
7
8
|
import numpy as np
|
8
|
-
from numpy.typing import
|
9
|
+
from numpy.typing import NDArray
|
9
10
|
from scipy.signal import convolve2d
|
10
11
|
|
11
12
|
EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
|
12
13
|
BIT_DEPTH = (1, 8, 12, 16, 32)
|
13
14
|
|
14
15
|
|
15
|
-
|
16
|
+
@dataclass
|
17
|
+
class BitDepth:
|
16
18
|
depth: int
|
17
19
|
pmin: float | int
|
18
20
|
pmax: float | int
|
@@ -59,7 +61,7 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
|
|
59
61
|
raise ValueError("Images must have 2 or more dimensions.")
|
60
62
|
|
61
63
|
|
62
|
-
def edge_filter(image:
|
64
|
+
def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
|
63
65
|
"""
|
64
66
|
Returns the image filtered using a 3x3 edge detection kernel:
|
65
67
|
[[ -1, -1, -1 ],
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import sys
|
4
|
+
from typing import Callable, TypeVar
|
5
|
+
|
6
|
+
if sys.version_info >= (3, 10):
|
7
|
+
from typing import ParamSpec
|
8
|
+
else:
|
9
|
+
from typing_extensions import ParamSpec
|
10
|
+
|
11
|
+
P = ParamSpec("P")
|
12
|
+
R = TypeVar("R")
|
13
|
+
|
14
|
+
|
15
|
+
def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
    """Look up a callable by name.

    Parameters
    ----------
    method_map : dict[str, Callable]
        Mapping of method names to callables.
    method : str
        Name of the method to retrieve.

    Returns
    -------
    Callable
        The callable registered under ``method``.

    Raises
    ------
    ValueError
        If ``method`` is not a key of ``method_map``.
    """
    try:
        return method_map[method]
    except KeyError:
        # List the valid names rather than the dict itself, whose values render as function reprs
        raise ValueError(f"Specified method {method} is not a valid method: {list(method_map)}.") from None
|
@@ -2,53 +2,17 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
import
|
6
|
-
from typing import Any, Callable, Literal, TypeVar
|
5
|
+
from typing import Any, Literal
|
7
6
|
|
8
|
-
|
9
|
-
from numpy.typing import ArrayLike, NDArray
|
7
|
+
from numpy.typing import NDArray
|
10
8
|
from scipy.sparse import csr_matrix
|
11
9
|
from scipy.sparse.csgraph import minimum_spanning_tree as mst
|
12
10
|
from scipy.spatial.distance import pdist, squareform
|
13
11
|
from sklearn.neighbors import NearestNeighbors
|
14
12
|
|
15
|
-
|
16
|
-
from typing import ParamSpec
|
17
|
-
else:
|
18
|
-
from typing_extensions import ParamSpec
|
19
|
-
|
20
|
-
from dataeval.interop import as_numpy
|
13
|
+
from dataeval.utils._array import flatten
|
21
14
|
|
22
15
|
EPSILON = 1e-5
|
23
|
-
HASH_SIZE = 8
|
24
|
-
MAX_FACTOR = 4
|
25
|
-
|
26
|
-
|
27
|
-
P = ParamSpec("P")
|
28
|
-
R = TypeVar("R")
|
29
|
-
|
30
|
-
|
31
|
-
def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
|
32
|
-
if method not in method_map:
|
33
|
-
raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
|
34
|
-
return method_map[method]
|
35
|
-
|
36
|
-
|
37
|
-
def flatten(array: ArrayLike) -> NDArray[Any]:
|
38
|
-
"""
|
39
|
-
Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
|
40
|
-
|
41
|
-
Parameters
|
42
|
-
----------
|
43
|
-
X : NDArray, shape - (N, ... )
|
44
|
-
Input array
|
45
|
-
|
46
|
-
Returns
|
47
|
-
-------
|
48
|
-
NDArray, shape - (N, -1)
|
49
|
-
"""
|
50
|
-
nparr = as_numpy(array)
|
51
|
-
return nparr.reshape((nparr.shape[0], -1))
|
52
16
|
|
53
17
|
|
54
18
|
def minimum_spanning_tree(X: NDArray[Any]) -> Any:
|
@@ -73,32 +37,6 @@ def minimum_spanning_tree(X: NDArray[Any]) -> Any:
|
|
73
37
|
return mst(eudist_csr)
|
74
38
|
|
75
39
|
|
76
|
-
def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
|
77
|
-
"""
|
78
|
-
Returns the classes and counts of from an array of labels
|
79
|
-
|
80
|
-
Parameters
|
81
|
-
----------
|
82
|
-
label : NDArray
|
83
|
-
Numpy labels array
|
84
|
-
|
85
|
-
Returns
|
86
|
-
-------
|
87
|
-
Classes and counts
|
88
|
-
|
89
|
-
Raises
|
90
|
-
------
|
91
|
-
ValueError
|
92
|
-
If the number of unique classes is less than 2
|
93
|
-
"""
|
94
|
-
classes, counts = np.unique(labels, return_counts=True)
|
95
|
-
M = len(classes)
|
96
|
-
if M < 2:
|
97
|
-
raise ValueError("Label vector contains less than 2 classes!")
|
98
|
-
N = int(np.sum(counts))
|
99
|
-
return M, N
|
100
|
-
|
101
|
-
|
102
40
|
def compute_neighbors(
|
103
41
|
A: NDArray[Any],
|
104
42
|
B: NDArray[Any],
|
@@ -6,9 +6,9 @@ import contextlib
|
|
6
6
|
from typing import Any
|
7
7
|
|
8
8
|
import numpy as np
|
9
|
-
from numpy.typing import ArrayLike
|
10
9
|
|
11
|
-
from dataeval.
|
10
|
+
from dataeval.typing import ArrayLike
|
11
|
+
from dataeval.utils._array import to_numpy
|
12
12
|
|
13
13
|
with contextlib.suppress(ImportError):
|
14
14
|
from matplotlib.figure import Figure
|
@@ -171,7 +171,7 @@ def histogram_plot(
|
|
171
171
|
data_dict,
|
172
172
|
):
|
173
173
|
# Plot the histogram for the chosen metric
|
174
|
-
ax.hist(data_dict[metric], bins=20, log=log)
|
174
|
+
ax.hist(data_dict[metric].astype(np.float64), bins=20, log=log)
|
175
175
|
|
176
176
|
# Add labels to the histogram
|
177
177
|
ax.set_title(metric)
|
@@ -229,7 +229,7 @@ def channel_histogram_plot(
|
|
229
229
|
# Plot the histogram for the chosen metric
|
230
230
|
data = data_dict[metric][ch_mask].reshape(-1, max_channels)
|
231
231
|
ax.hist(
|
232
|
-
data,
|
232
|
+
data.astype(np.float64),
|
233
233
|
bins=20,
|
234
234
|
density=True,
|
235
235
|
log=log,
|
@@ -0,0 +1,22 @@
|
|
1
|
+
"""Provides utility functions for interacting with Computer Vision datasets."""
|
2
|
+
|
3
|
+
__all__ = [
|
4
|
+
"collate",
|
5
|
+
"datasets",
|
6
|
+
"Embeddings",
|
7
|
+
"Images",
|
8
|
+
"Metadata",
|
9
|
+
"Select",
|
10
|
+
"SplitDatasetOutput",
|
11
|
+
"Targets",
|
12
|
+
"split_dataset",
|
13
|
+
]
|
14
|
+
|
15
|
+
from dataeval.utils.data._embeddings import Embeddings
|
16
|
+
from dataeval.utils.data._images import Images
|
17
|
+
from dataeval.utils.data._metadata import Metadata
|
18
|
+
from dataeval.utils.data._selection import Select
|
19
|
+
from dataeval.utils.data._split import SplitDatasetOutput, split_dataset
|
20
|
+
from dataeval.utils.data._targets import Targets
|
21
|
+
|
22
|
+
from . import collate, datasets
|
@@ -0,0 +1,105 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import math
|
6
|
+
from typing import Any, Iterator, Sequence
|
7
|
+
|
8
|
+
import torch
|
9
|
+
from torch.utils.data import DataLoader, Subset
|
10
|
+
from tqdm import tqdm
|
11
|
+
|
12
|
+
from dataeval.config import get_device
|
13
|
+
from dataeval.typing import TArray
|
14
|
+
from dataeval.utils.data._types import Dataset
|
15
|
+
from dataeval.utils.torch.models import SupportsEncode
|
16
|
+
|
17
|
+
|
18
|
+
class Embeddings:
|
19
|
+
"""
|
20
|
+
Collection of image embeddings from a dataset.
|
21
|
+
|
22
|
+
Embeddings are accessed by index or slice and are only loaded on-demand.
|
23
|
+
|
24
|
+
Parameters
|
25
|
+
----------
|
26
|
+
dataset : ImageClassificationDataset or ObjectDetectionDataset
|
27
|
+
Dataset to access original images from.
|
28
|
+
batch_size : int, optional
|
29
|
+
Batch size to use when encoding images.
|
30
|
+
model : torch.nn.Module, optional
|
31
|
+
Model to use for encoding images.
|
32
|
+
device : torch.device, optional
|
33
|
+
Device to use for encoding images.
|
34
|
+
verbose : bool, optional
|
35
|
+
Whether to print progress bar when encoding images.
|
36
|
+
"""
|
37
|
+
|
38
|
+
device: torch.device
|
39
|
+
batch_size: int
|
40
|
+
verbose: bool
|
41
|
+
|
42
|
+
def __init__(
|
43
|
+
self,
|
44
|
+
dataset: Dataset[TArray, Any],
|
45
|
+
batch_size: int,
|
46
|
+
indices: Sequence[int] | None = None,
|
47
|
+
model: torch.nn.Module | None = None,
|
48
|
+
device: torch.device | str | None = None,
|
49
|
+
verbose: bool = False,
|
50
|
+
) -> None:
|
51
|
+
self.device = get_device(device)
|
52
|
+
self.batch_size = batch_size
|
53
|
+
self.verbose = verbose
|
54
|
+
|
55
|
+
self._dataset = dataset
|
56
|
+
self._indices = indices if indices is not None else range(len(dataset))
|
57
|
+
model = torch.nn.Flatten() if model is None else model
|
58
|
+
self._model = model.to(self.device).eval()
|
59
|
+
self._encoder = model.encode if isinstance(model, SupportsEncode) else model
|
60
|
+
self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
|
61
|
+
|
62
|
+
def to_tensor(self) -> torch.Tensor:
|
63
|
+
"""
|
64
|
+
Converts entire dataset to embeddings.
|
65
|
+
|
66
|
+
Warning
|
67
|
+
-------
|
68
|
+
Will process the entire dataset in batches and return
|
69
|
+
embeddings as a single Tensor in memory.
|
70
|
+
|
71
|
+
Returns
|
72
|
+
-------
|
73
|
+
torch.Tensor
|
74
|
+
"""
|
75
|
+
return self[:]
|
76
|
+
|
77
|
+
# Reduce overhead cost by not tracking tensor gradients
|
78
|
+
@torch.no_grad
|
79
|
+
def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
|
80
|
+
# manual batching
|
81
|
+
dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)
|
82
|
+
for i, images in (
|
83
|
+
tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
|
84
|
+
if self.verbose
|
85
|
+
else enumerate(dataloader)
|
86
|
+
):
|
87
|
+
embeddings = self._encoder(torch.stack(images).to(self.device))
|
88
|
+
yield embeddings
|
89
|
+
|
90
|
+
def __getitem__(self, key: int | slice | list[int]) -> torch.Tensor:
|
91
|
+
if isinstance(key, list):
|
92
|
+
return torch.vstack(list(self._batch(key))).to(self.device)
|
93
|
+
if isinstance(key, slice):
|
94
|
+
return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
|
95
|
+
elif isinstance(key, int):
|
96
|
+
return self._encoder(torch.as_tensor(self._dataset[key][0]).to(self.device))
|
97
|
+
raise TypeError("Invalid argument type.")
|
98
|
+
|
99
|
+
def __iter__(self) -> Iterator[torch.Tensor]:
|
100
|
+
# process in batches while yielding individual embeddings
|
101
|
+
for batch in self._batch(range(len(self._dataset))):
|
102
|
+
yield from batch
|
103
|
+
|
104
|
+
def __len__(self) -> int:
|
105
|
+
return len(self._dataset)
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any, Generic, Iterator, Sequence, overload
|
6
|
+
|
7
|
+
from dataeval.typing import TArray
|
8
|
+
from dataeval.utils.data._types import Dataset
|
9
|
+
|
10
|
+
|
11
|
+
class Images(Generic[TArray]):
    """
    Lazy, index-addressable view of the images in a dataset.

    Nothing is loaded up front. Each access reads the requested items from the
    underlying dataset and returns element 0 of each item (the image).

    Parameters
    ----------
    dataset : ImageClassificationDataset or ObjectDetectionDataset
        Dataset to access images from.
    """

    def __init__(
        self,
        dataset: Dataset[TArray, Any],
    ) -> None:
        self._dataset = dataset

    def to_list(self) -> Sequence[TArray]:
        """
        Load every image of the dataset into a single list.

        Warning
        -------
        The whole dataset is read and all images are held in memory at once.

        Returns
        -------
        list[TArray]
        """
        return self[:]

    @overload
    def __getitem__(self, key: slice | list[int]) -> Sequence[TArray]: ...

    @overload
    def __getitem__(self, key: int) -> TArray: ...

    def __getitem__(self, key: int | slice | list[int]) -> Sequence[TArray] | TArray:
        if isinstance(key, int):
            return self._dataset[key][0]
        if isinstance(key, slice):
            positions = list(range(len(self._dataset))[key])
        elif isinstance(key, list):
            positions = key
        else:
            raise TypeError("Invalid argument type.")
        return [self._dataset[pos][0] for pos in positions]

    def __iter__(self) -> Iterator[TArray]:
        yield from (self._dataset[pos][0] for pos in range(len(self._dataset)))

    def __len__(self) -> int:
        return len(self._dataset)
|