dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/utils/{image.py → _image.py}

(In the hunks below, removed lines that appear bare or cut short after the `-` marker were truncated in the original diff view and are reproduced as-is.)

```diff
@@ -2,17 +2,19 @@ from __future__ import annotations
 
 __all__ = []
 
-from
+from dataclasses import dataclass
+from typing import Any
 
 import numpy as np
-from numpy.typing import
+from numpy.typing import NDArray
 from scipy.signal import convolve2d
 
 EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)
 BIT_DEPTH = (1, 8, 12, 16, 32)
 
 
-
+@dataclass
+class BitDepth:
     depth: int
     pmin: float | int
     pmax: float | int
@@ -59,7 +61,7 @@ def normalize_image_shape(image: NDArray[Any]) -> NDArray[Any]:
         raise ValueError("Images must have 2 or more dimensions.")
 
 
-def edge_filter(image:
+def edge_filter(image: NDArray[Any], offset: float = 0.5) -> NDArray[np.uint8]:
     """
     Returns the image filtered using a 3x3 edge detection kernel:
     [[ -1, -1, -1 ],
```
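The `edge_filter` signature gains an explicit `offset` parameter and return type. For context, `EDGE_KERNEL` is a standard 3x3 Laplacian-style edge detector. The body of `edge_filter` is not shown in this hunk, so the sketch below only illustrates applying the kernel directly with `convolve2d`, not the function's actual implementation:

```python
import numpy as np
from scipy.signal import convolve2d

# Same kernel as defined in _image.py above.
EDGE_KERNEL = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=np.int8)

# Flat regions sum to ~0; edges produce large positive/negative responses.
image = np.random.rand(32, 32)
edges = convolve2d(image, EDGE_KERNEL, mode="same", boundary="symm")
```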
dataeval/utils/_method.py (new file)

```python
from __future__ import annotations

from typing import Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
    if method not in method_map:
        raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
    return method_map[method]
```
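`get_method` is a small string-to-callable dispatcher, relocated here from `shared.py` (see its removal in the `_mst.py` hunk below). A minimal usage sketch; the `method_map` contents are illustrative, not part of the dataeval API:

```python
from dataeval.utils._method import get_method

# Hypothetical method map; any dict[str, Callable] works.
method_map = {
    "euclidean": lambda a, b: sum((x - y) ** 2 for x, y in zip(a, b)) ** 0.5,
    "manhattan": lambda a, b: sum(abs(x - y) for x, y in zip(a, b)),
}

distance = get_method(method_map, "euclidean")
print(distance([0, 0], [3, 4]))  # 5.0

# An unknown name raises ValueError listing the valid methods.
```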
dataeval/utils/{shared.py → _mst.py}

```diff
@@ -2,53 +2,17 @@ from __future__ import annotations
 
 __all__ = []
 
-import
-from typing import Any, Callable, Literal, TypeVar
+from typing import Any, Literal
 
-
-from numpy.typing import ArrayLike, NDArray
+from numpy.typing import NDArray
 from scipy.sparse import csr_matrix
 from scipy.sparse.csgraph import minimum_spanning_tree as mst
 from scipy.spatial.distance import pdist, squareform
 from sklearn.neighbors import NearestNeighbors
 
-
-    from typing import ParamSpec
-else:
-    from typing_extensions import ParamSpec
-
-from dataeval.interop import as_numpy
+from dataeval.utils._array import flatten
 
 EPSILON = 1e-5
-HASH_SIZE = 8
-MAX_FACTOR = 4
-
-
-P = ParamSpec("P")
-R = TypeVar("R")
-
-
-def get_method(method_map: dict[str, Callable[P, R]], method: str) -> Callable[P, R]:
-    if method not in method_map:
-        raise ValueError(f"Specified method {method} is not a valid method: {method_map}.")
-    return method_map[method]
-
-
-def flatten(array: ArrayLike) -> NDArray[Any]:
-    """
-    Flattens input array from (N, ... ) to (N, -1) where all samples N have all data in their last dimension
-
-    Parameters
-    ----------
-    X : NDArray, shape - (N, ... )
-        Input array
-
-    Returns
-    -------
-    NDArray, shape - (N, -1)
-    """
-    nparr = as_numpy(array)
-    return nparr.reshape((nparr.shape[0], -1))
 
 
 def minimum_spanning_tree(X: NDArray[Any]) -> Any:
@@ -73,32 +37,6 @@ def minimum_spanning_tree(X: NDArray[Any]) -> Any:
     return mst(eudist_csr)
 
 
-def get_classes_counts(labels: NDArray[np.int_]) -> tuple[int, int]:
-    """
-    Returns the classes and counts of from an array of labels
-
-    Parameters
-    ----------
-    label : NDArray
-        Numpy labels array
-
-    Returns
-    -------
-    Classes and counts
-
-    Raises
-    ------
-    ValueError
-        If the number of unique classes is less than 2
-    """
-    classes, counts = np.unique(labels, return_counts=True)
-    M = len(classes)
-    if M < 2:
-        raise ValueError("Label vector contains less than 2 classes!")
-    N = int(np.sum(counts))
-    return M, N
-
-
 def compute_neighbors(
     A: NDArray[Any],
     B: NDArray[Any],
```
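`flatten` now lives in `dataeval.utils._array` and is imported rather than defined here, and `get_classes_counts` is dropped along with the `get_method`/`ParamSpec` machinery that moved to `_method.py`. Per the removed docstring, `flatten` reshapes `(N, ...)` to `(N, -1)`; the equivalent NumPy operation:

```python
import numpy as np

# Equivalent of the relocated flatten(): collapse everything after the
# sample axis, (N, ...) -> (N, -1).
arr = np.zeros((10, 3, 16, 16))
flat = arr.reshape((arr.shape[0], -1))
print(flat.shape)  # (10, 768)
```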
dataeval/utils/{plot.py → _plot.py}

```diff
@@ -6,9 +6,9 @@ import contextlib
 from typing import Any
 
 import numpy as np
-from numpy.typing import ArrayLike
 
-from dataeval.
+from dataeval.typing import ArrayLike
+from dataeval.utils._array import to_numpy
 
 with contextlib.suppress(ImportError):
     from matplotlib.figure import Figure
@@ -49,8 +49,8 @@ def heatmap(
     from matplotlib.ticker import FuncFormatter
 
     np_data = to_numpy(data)
-    rows =
-    cols =
+    rows: list[str] = [str(n) for n in to_numpy(row_labels)]
+    cols: list[str] = [str(n) for n in to_numpy(col_labels)]
 
     fig, ax = plt.subplots(figsize=(10, 10))
 
@@ -171,7 +171,7 @@ def histogram_plot(
         data_dict,
     ):
         # Plot the histogram for the chosen metric
-        ax.hist(data_dict[metric], bins=20, log=log)
+        ax.hist(data_dict[metric].astype(np.float64), bins=20, log=log)
 
         # Add labels to the histogram
         ax.set_title(metric)
@@ -229,7 +229,7 @@ def channel_histogram_plot(
         # Plot the histogram for the chosen metric
         data = data_dict[metric][ch_mask].reshape(-1, max_channels)
         ax.hist(
-            data,
+            data.astype(np.float64),
             bins=20,
             density=True,
             log=log,
```
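All four hunks touch `_plot.py`: imports now come from `dataeval.typing` and `dataeval.utils._array`, `heatmap` stringifies its row and column labels, and both histogram helpers cast their data to `np.float64` before calling `ax.hist`, presumably to avoid precision and overflow issues when binning integer or half-precision stat arrays. A standalone sketch of that cast pattern (not the dataeval call site):

```python
import numpy as np
import matplotlib.pyplot as plt

# Low-precision stats (e.g. uint8 pixel values) are promoted to float64
# before binning, mirroring the change in _plot.py above.
values = np.random.randint(0, 256, size=1000, dtype=np.uint8)

fig, ax = plt.subplots()
ax.hist(values.astype(np.float64), bins=20, log=True)
```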
dataeval/utils/data/__init__.py (new file)

```python
"""Provides utility functions for interacting with Computer Vision datasets."""

__all__ = [
    "collate",
    "datasets",
    "Embeddings",
    "Images",
    "Metadata",
    "Select",
    "SplitDatasetOutput",
    "Targets",
    "split_dataset",
    "to_image_classification_dataset",
    "to_object_detection_dataset",
]

from dataeval.outputs._utils import SplitDatasetOutput
from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
from dataeval.utils.data._embeddings import Embeddings
from dataeval.utils.data._images import Images
from dataeval.utils.data._metadata import Metadata
from dataeval.utils.data._selection import Select
from dataeval.utils.data._split import split_dataset
from dataeval.utils.data._targets import Targets

from . import collate, datasets
```
dataeval/utils/data/_dataset.py (new file)

```python
from __future__ import annotations

__all__ = []

from typing import Any, Generic, Iterable, Literal, Sequence, TypeVar

from dataeval.typing import (
    Array,
    ArrayLike,
    DatasetMetadata,
    ImageClassificationDataset,
    ObjectDetectionDataset,
)
from dataeval.utils._array import as_numpy


def _validate_data(
    datum_type: Literal["ic", "od"],
    images: Array | Sequence[Array],
    labels: Sequence[int] | Sequence[Sequence[int]],
    bboxes: Sequence[Sequence[Sequence[float]]] | None,
    metadata: Sequence[dict[str, Any]] | None,
) -> None:
    # Validate inputs
    dataset_len = len(images)

    if not isinstance(images, (Sequence, Array)) or len(images[0].shape) != 3:
        raise ValueError("Images must be a sequence or array of 3 dimensional arrays (H, W, C).")
    if len(labels) != dataset_len:
        raise ValueError(f"Number of labels ({len(labels)}) does not match number of images ({dataset_len}).")
    if bboxes is not None and len(bboxes) != dataset_len:
        raise ValueError(f"Number of bboxes ({len(bboxes)}) does not match number of images ({dataset_len}).")
    if metadata is not None and len(metadata) != dataset_len:
        raise ValueError(f"Number of metadata ({len(metadata)}) does not match number of images ({dataset_len}).")

    if datum_type == "ic":
        if not isinstance(labels, Sequence) or not isinstance(labels[0], int):
            raise TypeError("Labels must be a sequence of integers for image classification.")
    elif datum_type == "od":
        if not isinstance(labels, Sequence) or not isinstance(labels[0], Sequence) or not isinstance(labels[0][0], int):
            raise TypeError("Labels must be a sequence of sequences of integers for object detection.")
        if (
            bboxes is None
            or not isinstance(bboxes, (Sequence, Array))
            or not isinstance(bboxes[0], (Sequence, Array))
            or not isinstance(bboxes[0][0], (Sequence, Array))
            or not len(bboxes[0][0]) == 4
        ):
            raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")


def _find_max(arr: ArrayLike) -> Any:
    if isinstance(arr[0], (Iterable, Sequence, Array)):
        return max([_find_max(x) for x in arr])  # type: ignore
    else:
        return max(arr)


_TLabels = TypeVar("_TLabels", Sequence[int], Sequence[Sequence[int]])


class BaseAnnotatedDataset(Generic[_TLabels]):
    def __init__(
        self,
        datum_type: Literal["ic", "od"],
        images: Array | Sequence[Array],
        labels: _TLabels,
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        self._classes = classes if classes is not None else [str(i) for i in range(_find_max(labels) + 1)]
        self._index2label = dict(enumerate(self._classes))
        self._images = images
        self._labels = labels
        self._metadata = metadata
        self._id = name or f"{len(self._images)}_image_{len(self._index2label)}_class_{datum_type}_dataset"

    @property
    def metadata(self) -> DatasetMetadata:
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __len__(self) -> int:
        return len(self._images)


class CustomImageClassificationDataset(BaseAnnotatedDataset[Sequence[int]], ImageClassificationDataset):
    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Sequence[int],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        super().__init__("ic", images, labels, metadata, classes)
        if name is not None:
            self.__name__ = name
            self.__class__.__name__ = name
            self.__class__.__qualname__ = name

    def __getitem__(self, idx: int, /) -> tuple[Array, Array, dict[str, Any]]:
        one_hot = [0.0] * len(self._index2label)
        one_hot[self._labels[idx]] = 1.0
        return (
            self._images[idx],
            as_numpy(one_hot),
            self._metadata[idx] if self._metadata is not None else {},
        )


class CustomObjectDetectionDataset(BaseAnnotatedDataset[Sequence[Sequence[int]]], ObjectDetectionDataset):
    class ObjectDetectionTarget:
        def __init__(self, labels: Sequence[int], bboxes: Sequence[Sequence[float]]) -> None:
            self._labels = labels
            self._bboxes = bboxes
            self._scores = [1.0] * len(labels)

        @property
        def labels(self) -> Sequence[int]:
            return self._labels

        @property
        def boxes(self) -> Sequence[Sequence[float]]:
            return self._bboxes

        @property
        def scores(self) -> Sequence[float]:
            return self._scores

    def __init__(
        self,
        images: Array | Sequence[Array],
        labels: Sequence[Sequence[int]],
        bboxes: Sequence[Sequence[Sequence[float]]],
        metadata: Sequence[dict[str, Any]] | None,
        classes: Sequence[str] | None,
        name: str | None = None,
    ) -> None:
        super().__init__("od", images, labels, metadata, classes)
        if name is not None:
            self.__name__ = name
            self.__class__.__name__ = name
            self.__class__.__qualname__ = name
        self._bboxes = bboxes

    @property
    def metadata(self) -> DatasetMetadata:
        return DatasetMetadata(id=self._id, index2label=self._index2label)

    def __getitem__(self, idx: int, /) -> tuple[Array, ObjectDetectionTarget, dict[str, Any]]:
        return (
            self._images[idx],
            self.ObjectDetectionTarget(self._labels[idx], self._bboxes[idx]),
            self._metadata[idx] if self._metadata is not None else {},
        )


def to_image_classification_dataset(
    images: Array | Sequence[Array],
    labels: Sequence[int],
    metadata: Sequence[dict[str, Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ImageClassificationDataset:
    """
    Helper function to create custom ImageClassificationDataset classes.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Sequence[int]
        The labels to use in the dataset.
    metadata : Sequence[dict[str, Any]] | None
        The metadata to use in the dataset.
    classes : Sequence[str] | None
        The classes to use in the dataset.

    Returns
    -------
    ImageClassificationDataset
    """
    _validate_data("ic", images, labels, None, metadata)
    return CustomImageClassificationDataset(images, labels, metadata, classes, name)


def to_object_detection_dataset(
    images: Array | Sequence[Array],
    labels: Sequence[Sequence[int]],
    bboxes: Sequence[Sequence[Sequence[float]]],
    metadata: Sequence[dict[str, Any]] | None,
    classes: Sequence[str] | None,
    name: str | None = None,
) -> ObjectDetectionDataset:
    """
    Helper function to create custom ObjectDetectionDataset classes.

    Parameters
    ----------
    images : Array | Sequence[Array]
        The images to use in the dataset.
    labels : Sequence[Sequence[int]]
        The labels to use in the dataset.
    bboxes : Sequence[Sequence[Sequence[float]]]
        The bounding boxes (x0,y0,x1,y0) to use in the dataset.
    metadata : Sequence[dict[str, Any]] | None
        The metadata to use in the dataset.
    classes : Sequence[str] | None
        The classes to use in the dataset.

    Returns
    -------
    ObjectDetectionDataset
    """
    _validate_data("od", images, labels, bboxes, metadata)
    return CustomObjectDetectionDataset(images, labels, bboxes, metadata, classes, name)
```
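A minimal usage sketch of the new factory, based on the signature and validation shown above; the random arrays and class names are illustrative:

```python
import numpy as np
from dataeval.utils.data import to_image_classification_dataset

# Three 3-dimensional (H, W, C) images with integer class labels.
images = [np.random.rand(16, 16, 3) for _ in range(3)]
labels = [0, 1, 0]

dataset = to_image_classification_dataset(
    images, labels, metadata=None, classes=["cat", "dog"], name="demo"
)

image, one_hot, datum_metadata = dataset[0]
print(len(dataset), one_hot)  # 3 [1. 0.]
```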
dataeval/utils/data/_embeddings.py (new file)

```python
from __future__ import annotations

__all__ = []

import math
from typing import Any, Iterator, Sequence

import torch
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from dataeval.config import get_device
from dataeval.typing import Array, Dataset
from dataeval.utils.torch.models import SupportsEncode


class Embeddings:
    """
    Collection of image embeddings from a dataset.

    Embeddings are accessed by index or slice and are only loaded on-demand.

    Parameters
    ----------
    dataset : ImageClassificationDataset or ObjectDetectionDataset
        Dataset to access original images from.
    batch_size : int, optional
        Batch size to use when encoding images.
    model : torch.nn.Module, optional
        Model to use for encoding images.
    device : torch.device, optional
        Device to use for encoding images.
    verbose : bool, optional
        Whether to print progress bar when encoding images.
    """

    device: torch.device
    batch_size: int
    verbose: bool

    def __init__(
        self,
        dataset: Dataset[tuple[Array, Any, Any]],
        batch_size: int,
        indices: Sequence[int] | None = None,
        model: torch.nn.Module | None = None,
        device: torch.device | str | None = None,
        verbose: bool = False,
    ) -> None:
        self.device = get_device(device)
        self.batch_size = batch_size
        self.verbose = verbose

        self._dataset = dataset
        self._indices = indices if indices is not None else range(len(dataset))
        model = torch.nn.Flatten() if model is None else model
        self._model = model.to(self.device).eval()
        self._encoder = model.encode if isinstance(model, SupportsEncode) else model
        self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]

    def to_tensor(self) -> torch.Tensor:
        """
        Converts entire dataset to embeddings.

        Warning
        -------
        Will process the entire dataset in batches and return
        embeddings as a single Tensor in memory.

        Returns
        -------
        torch.Tensor
        """
        return self[:]

    # Reduce overhead cost by not tracking tensor gradients
    @torch.no_grad
    def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
        # manual batching
        dataloader = DataLoader(Subset(self._dataset, indices), batch_size=self.batch_size, collate_fn=self._collate_fn)  # type: ignore
        for i, images in (
            tqdm(enumerate(dataloader), total=math.ceil(len(indices) / self.batch_size), desc="Batch processing")
            if self.verbose
            else enumerate(dataloader)
        ):
            embeddings = self._encoder(torch.stack(images).to(self.device))
            yield embeddings

    def __getitem__(self, key: int | slice | list[int], /) -> torch.Tensor:
        if isinstance(key, list):
            return torch.vstack(list(self._batch(key))).to(self.device)
        if isinstance(key, slice):
            return torch.vstack(list(self._batch(range(len(self._dataset))[key]))).to(self.device)
        elif isinstance(key, int):
            return self._encoder(torch.as_tensor(self._dataset[key][0]).to(self.device))
        raise TypeError("Invalid argument type.")

    def __iter__(self) -> Iterator[torch.Tensor]:
        # process in batches while yielding individual embeddings
        for batch in self._batch(range(len(self._dataset))):
            yield from batch

    def __len__(self) -> int:
        return len(self._dataset)
```
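A usage sketch for `Embeddings`, assuming a dataset that yields `(image, target, metadata)` tuples, such as one built with `to_image_classification_dataset` above. With no `model` argument, the constructor falls back to `torch.nn.Flatten()`, so embeddings are just flattened images:

```python
import numpy as np
from dataeval.utils.data import Embeddings, to_image_classification_dataset

images = [np.random.rand(16, 16, 3).astype(np.float32) for _ in range(8)]
dataset = to_image_classification_dataset(
    images, labels=[i % 2 for i in range(8)], metadata=None, classes=["a", "b"]
)

# Device resolution is delegated to dataeval.config.get_device(None),
# which is assumed to pick a sensible default.
embeddings = Embeddings(dataset, batch_size=4)

subset = embeddings[[0, 2, 5]]        # encode a selection of images
print(embeddings.to_tensor().shape)   # torch.Size([8, 768]); 16*16*3 flattened
```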
dataeval/utils/data/_images.py (new file)

```python
from __future__ import annotations

__all__ = []

from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload

from dataeval.typing import Dataset

T = TypeVar("T")


class Images(Generic[T]):
    """
    Collection of image data from a dataset.

    Images are accessed by index or slice and are only loaded on-demand.

    Parameters
    ----------
    dataset : Dataset[tuple[T, ...]] or Dataset[T]
        Dataset to access images from.
    """

    def __init__(self, dataset: Dataset[tuple[T, Any, Any] | T]) -> None:
        self._is_tuple_datum = isinstance(dataset[0], tuple)
        self._dataset = dataset

    def to_list(self) -> Sequence[T]:
        """
        Converts entire dataset to a sequence of images.

        Warning
        -------
        Will load the entire dataset and return the images as a
        single sequence of images in memory.

        Returns
        -------
        list[T]
        """
        return self[:]

    @overload
    def __getitem__(self, key: int, /) -> T: ...
    @overload
    def __getitem__(self, key: slice, /) -> Sequence[T]: ...

    def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
        if self._is_tuple_datum:
            dataset = cast(Dataset[tuple[T, Any, Any]], self._dataset)
            if isinstance(key, slice):
                return [dataset[k][0] for k in range(len(self._dataset))[key]]
            elif isinstance(key, int):
                return dataset[key][0]
        else:
            dataset = cast(Dataset[T], self._dataset)
            if isinstance(key, slice):
                return [dataset[k] for k in range(len(self._dataset))[key]]
            elif isinstance(key, int):
                return dataset[key]
        raise TypeError(f"Key must be integers or slices, not {type(key)}")

    def __iter__(self) -> Iterator[T]:
        for i in range(len(self._dataset)):
            yield self[i]

    def __len__(self) -> int:
        return len(self._dataset)
```
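A matching sketch for `Images`, which lazily pulls the image component out of each datum; whether the dataset yields tuples is detected once in the constructor:

```python
import numpy as np
from dataeval.utils.data import Images, to_image_classification_dataset

images = [np.random.rand(16, 16, 3) for _ in range(4)]
dataset = to_image_classification_dataset(
    images, labels=[0, 1, 0, 1], metadata=None, classes=["a", "b"]
)

imgs = Images(dataset)  # detects the (image, target, metadata) tuple datum
print(len(imgs))        # 4
first_two = imgs[0:2]   # slicing returns a list of images
for img in imgs:        # iteration loads one image at a time
    assert img.shape == (16, 16, 3)
```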