dataeval 0.82.1__py3-none-any.whl → 0.84.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- dataeval/__init__.py +7 -2
- dataeval/config.py +13 -3
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_ood.py +144 -27
- dataeval/metrics/bias/__init__.py +11 -1
- dataeval/metrics/bias/_balance.py +3 -3
- dataeval/metrics/bias/_completeness.py +130 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +31 -36
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +4 -45
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +4 -2
- dataeval/outputs/_bias.py +31 -22
- dataeval/outputs/_metadata.py +7 -0
- dataeval/outputs/_stats.py +2 -3
- dataeval/typing.py +43 -12
- dataeval/utils/_array.py +26 -1
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_dataset.py +2 -0
- dataeval/utils/data/_embeddings.py +115 -32
- dataeval/utils/data/_images.py +38 -15
- dataeval/utils/data/_selection.py +7 -8
- dataeval/utils/data/_split.py +76 -129
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +1 -1
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/metadata.py +1 -1
- dataeval/utils/torch/_gmm.py +3 -2
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/METADATA +4 -4
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/RECORD +44 -43
- dataeval/detectors/ood/metadata_ood_mi.py +0 -91
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.1.dist-info → dataeval-0.84.0.dist-info}/WHEEL +0 -0
dataeval/utils/data/_dataset.py
CHANGED
@@ -47,6 +47,8 @@ def _validate_data(
             or not len(bboxes[0][0]) == 4
         ):
             raise TypeError("Boxes must be a sequence of sequences of (x0, y0, x1, y1) for object detection.")
+    else:
+        raise ValueError(f"Unknown datum type '{datum_type}'. Must be 'ic' or 'od'.")
 
 
 def _find_max(arr: ArrayLike) -> Any:
dataeval/utils/data/_embeddings.py
CHANGED
@@ -3,14 +3,14 @@ from __future__ import annotations
 __all__ = []
 
 import math
-from typing import Any, Iterator, Sequence
+from typing import Any, Iterator, Sequence, cast
 
 import torch
 from torch.utils.data import DataLoader, Subset
 from tqdm import tqdm
 
 from dataeval.config import DeviceLike, get_device
-from dataeval.typing import Array, Dataset
+from dataeval.typing import Array, Dataset, Transform
 from dataeval.utils.torch.models import SupportsEncode
 
 
@@ -26,11 +26,15 @@ class Embeddings:
         Dataset to access original images from.
     batch_size : int
         Batch size to use when encoding images.
+    transforms : Transform or Sequence[Transform] or None, default None
+        Transforms to apply to images before encoding.
     model : torch.nn.Module or None, default None
        Model to use for encoding images.
    device : DeviceLike or None, default None
        The hardware device to use if specified, otherwise uses the DataEval
        default or torch default.
+    cache : bool, default False
+        Whether to cache the embeddings in memory.
    verbose : bool, default False
        Whether to print progress bar when encoding images.
    """
@@ -43,61 +47,140 @@
         self,
         dataset: Dataset[tuple[Array, Any, Any]],
         batch_size: int,
+        transforms: Transform[torch.Tensor] | Sequence[Transform[torch.Tensor]] | None = None,
         model: torch.nn.Module | None = None,
         device: DeviceLike | None = None,
+        cache: bool = False,
         verbose: bool = False,
     ) -> None:
         self.device = get_device(device)
-        self.
+        self.cache = cache
+        self.batch_size = batch_size if batch_size > 0 else 1
         self.verbose = verbose
 
         self._dataset = dataset
+        self._length = len(dataset)
         model = torch.nn.Flatten() if model is None else model
+        self._transforms = [transforms] if isinstance(transforms, Transform) else transforms
         self._model = model.to(self.device).eval()
         self._encoder = model.encode if isinstance(model, SupportsEncode) else model
         self._collate_fn = lambda datum: [torch.as_tensor(i) for i, _, _ in datum]
+        self._cached_idx = set()
+        self._embeddings: torch.Tensor = torch.empty(())
+        self._shallow: bool = False
 
-    def to_tensor(self) -> torch.Tensor:
+    def to_tensor(self, indices: Sequence[int] | None = None) -> torch.Tensor:
         """
-        Converts
+        Converts dataset to embeddings.
 
+        Parameters
+        ----------
+        indices : Sequence[int] or None, default None
+            The indices to convert to embeddings
 
         Returns
         -------
         torch.Tensor
+
+        Warning
+        -------
+        Processing large quantities of data can be resource intensive.
+        """
+        if indices is not None:
+            return torch.vstack(list(self._batch(indices))).to(self.device)
+        else:
+            return self[:]
+
+    @classmethod
+    def from_array(cls, array: Array, device: DeviceLike | None = None) -> Embeddings:
         """
+        Instantiates a shallow Embeddings object using an array.
 
+        Parameters
+        ----------
+        array : Array
+            The array to convert to embeddings.
+        device : DeviceLike or None, default None
+            The hardware device to use if specified, otherwise uses the DataEval
+            default or torch default.
+
+        Returns
+        -------
+        Embeddings
+
+        Example
+        -------
+        >>> import numpy as np
+        >>> from dataeval.utils.data._embeddings import Embeddings
+        >>> array = np.random.randn(100, 3, 224, 224)
+        >>> embeddings = Embeddings.from_array(array)
+        >>> print(embeddings.to_tensor().shape)
+        torch.Size([100, 3, 224, 224])
+        """
+        embeddings = Embeddings([], 0, None, None, device, True, False)
+        embeddings._length = len(array)
+        embeddings._cached_idx = set(range(len(array)))
+        embeddings._embeddings = torch.as_tensor(array).to(get_device(device))
+        embeddings._shallow = True
+        return embeddings
+
+    def _encode(self, images: list[torch.Tensor]) -> torch.Tensor:
+        if self._transforms:
+            images = [transform(image) for transform in self._transforms for image in images]
+        return self._encoder(torch.stack(images).to(self.device))
+
+    @torch.no_grad()  # Reduce overhead cost by not tracking tensor gradients
     def _batch(self, indices: Sequence[int]) -> Iterator[torch.Tensor]:
-        return
+        dataset = cast(torch.utils.data.Dataset[tuple[Array, Any, Any]], self._dataset)
+        total_batches = math.ceil(len(indices) / self.batch_size)
+
+        # If not caching, process all indices normally
+        if not self.cache:
+            for images in tqdm(
+                DataLoader(Subset(dataset, indices), self.batch_size, collate_fn=self._collate_fn),
+                total=total_batches,
+                desc="Batch embedding",
+                disable=not self.verbose,
+            ):
+                yield self._encode(images)
+            return
+
+        # If caching, process each batch of indices at a time, preserving original order
+        for i in tqdm(range(0, len(indices), self.batch_size), desc="Batch embedding", disable=not self.verbose):
+            batch = indices[i : i + self.batch_size]
+            uncached = [idx for idx in batch if idx not in self._cached_idx]
+
+            if uncached:
+                # Process uncached indices as as single batch
+                for images in DataLoader(Subset(dataset, uncached), len(uncached), collate_fn=self._collate_fn):
+                    embeddings = self._encode(images)
+
+                    if not self._embeddings.shape:
+                        full_shape = (len(self._dataset), *embeddings.shape[1:])
+                        self._embeddings = torch.empty(full_shape, dtype=embeddings.dtype, device=self.device)
+
+                    self._embeddings[uncached] = embeddings
+                    self._cached_idx.update(uncached)
+
+            yield self._embeddings[batch]
+
+    def __getitem__(self, key: int | slice, /) -> torch.Tensor:
+        if not isinstance(key, slice) and not hasattr(key, "__int__"):
+            raise TypeError("Invalid argument type.")
+
+        if self._shallow:
+            if not self._embeddings.shape:
+                raise ValueError("Embeddings not initialized.")
+            return self._embeddings[key]
+
+        indices = list(range(len(self._dataset))[key]) if isinstance(key, slice) else [int(key)]
+        result = torch.vstack(list(self._batch(indices))).to(self.device)
+        return result.squeeze(0) if len(indices) == 1 else result
 
     def __iter__(self) -> Iterator[torch.Tensor]:
         # process in batches while yielding individual embeddings
-        for batch in self._batch(range(
+        for batch in self._batch(range(self._length)):
             yield from batch
 
     def __len__(self) -> int:
-        return
+        return self._length
dataeval/utils/data/_images.py
CHANGED
@@ -2,11 +2,15 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Generic, Iterator, Sequence, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Any, Generic, Iterator, Sequence, TypeVar, cast, overload
 
-from dataeval.typing import Dataset
+from dataeval.typing import Array, Dataset
+from dataeval.utils._array import as_numpy, channels_first_to_last
 
+if TYPE_CHECKING:
+    from matplotlib.figure import Figure
+
+T = TypeVar("T", bound=Array)
 
 
 class Images(Generic[T]):
@@ -21,7 +25,10 @@ class Images(Generic[T]):
         Dataset to access images from.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        dataset: Dataset[tuple[T, Any, Any] | T],
+    ) -> None:
         self._is_tuple_datum = isinstance(dataset[0], tuple)
         self._dataset = dataset
 
@@ -40,25 +47,41 @@
         """
         return self[:]
 
+    def plot(
+        self,
+        indices: Sequence[int],
+        images_per_row: int = 3,
+        figsize: tuple[int, int] = (10, 10),
+    ) -> Figure:
+        import matplotlib.pyplot as plt
+
+        num_images = len(indices)
+        num_rows = (num_images + images_per_row - 1) // images_per_row
+        fig, axes = plt.subplots(num_rows, images_per_row, figsize=figsize)
+        for i, ax in enumerate(axes.flatten()):
+            image = channels_first_to_last(as_numpy(self[i]))
+            ax.imshow(image)
+            ax.axis("off")
+        plt.tight_layout()
+        return fig
+
     @overload
     def __getitem__(self, key: int, /) -> T: ...
     @overload
     def __getitem__(self, key: slice, /) -> Sequence[T]: ...
 
     def __getitem__(self, key: int | slice, /) -> Sequence[T] | T:
+        if isinstance(key, slice):
+            return [self._get_image(k) for k in range(len(self._dataset))[key]]
+        elif hasattr(key, "__int__"):
+            return self._get_image(int(key))
+        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+
+    def _get_image(self, index: int) -> T:
         if self._is_tuple_datum:
-            if isinstance(key, slice):
-                return [dataset[k][0] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key][0]
+            return cast(Dataset[tuple[T, Any, Any]], self._dataset)[index][0]
         else:
-            if isinstance(key, slice):
-                return [dataset[k] for k in range(len(self._dataset))[key]]
-            elif isinstance(key, int):
-                return dataset[key]
-        raise TypeError(f"Key must be integers or slices, not {type(key)}")
+            return cast(Dataset[T], self._dataset)[index]
 
     def __iter__(self) -> Iterator[T]:
         for i in range(len(self._dataset)):
dataeval/utils/data/_selection.py
CHANGED
@@ -7,7 +7,7 @@ from typing import Generic, Iterator, Sequence, TypeVar
 
 from dataeval.typing import AnnotatedDataset, DatasetMetadata
 
-_TDatum = TypeVar("_TDatum"
+_TDatum = TypeVar("_TDatum")
 
 
 class SelectionStage(IntEnum):
@@ -67,13 +67,13 @@ class Select(AnnotatedDataset[_TDatum]):
     def __init__(
         self,
         dataset: AnnotatedDataset[_TDatum],
-        selections: Selection[_TDatum] |
+        selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None = None,
     ) -> None:
         self.__dict__.update(dataset.__dict__)
         self._dataset = dataset
         self._size_limit = len(dataset)
         self._selection = list(range(self._size_limit))
-        self._selections = self.
+        self._selections = self._sort(selections)
 
         # Ensure metadata is populated correctly as DatasetMetadata TypedDict
         _metadata = getattr(dataset, "metadata", {})
@@ -81,8 +81,7 @@ class Select(AnnotatedDataset[_TDatum]):
             _metadata["id"] = dataset.__class__.__name__
         self._metadata = DatasetMetadata(**_metadata)
 
-
-        self._apply_selections()
+        self._select()
 
     @property
     def metadata(self) -> DatasetMetadata:
@@ -92,10 +91,10 @@ class Select(AnnotatedDataset[_TDatum]):
         nt = "\n "
         title = f"{self.__class__.__name__} Dataset"
         sep = "-" * len(title)
-        selections = f"Selections: [{', '.join([str(s) for s in self.
+        selections = f"Selections: [{', '.join([str(s) for s in self._selections])}]"
         return f"{title}\n{sep}{nt}{selections}{nt}Selected Size: {len(self)}\n\n{self._dataset}"
 
-    def
+    def _sort(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
         if not selections:
             return []
 
@@ -106,7 +105,7 @@ class Select(AnnotatedDataset[_TDatum]):
         selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
         return selection_list
 
-    def
+    def _select(self) -> None:
         for selection in self._selections:
             selection(self)
         self._selection = self._selection[: self._size_limit]