dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/dataeval/utils/data/_metadata.py
@@ -0,0 +1,360 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.typing import (
+    AnnotatedDataset,
+    Array,
+    ObjectDetectionTarget,
+)
+from dataeval.utils._array import as_numpy, to_numpy
+from dataeval.utils._bin import bin_data, digitize_data, is_continuous
+from dataeval.utils.metadata import merge
+
+if TYPE_CHECKING:
+    from dataeval.utils.data import Targets
+else:
+    from dataeval.utils.data._targets import Targets
+
+
+class Metadata:
+    """
+    Class containing binned metadata.
+
+    Attributes
+    ----------
+    discrete_factor_names : list[str]
+        List containing factor names for the original data that was discrete and
+        the binned continuous data
+    discrete_data : NDArray[np.int64]
+        Array containing values for the original data that was discrete and the
+        binned continuous data
+    continuous_factor_names : list[str]
+        List containing factor names for the original continuous data
+    continuous_data : NDArray[np.float64] | None
+        Array containing values for the original continuous data or None if there
+        was no continuous data
+    class_labels : NDArray[np.int]
+        Numerical class labels for the images/objects
+    class_names : list[str]
+        List of unique class names
+    total_num_factors : int
+        Sum of discrete_factor_names and continuous_factor_names plus 1 for class
+    image_indices : NDArray[np.intp]
+        Array of the image index that is mapped by the index of the factor
+
+    Parameters
+    ----------
+    dataset : ImageClassificationDataset or ObjectDetectionDataset
+        Dataset to access original targets and metadata from.
+    continuous_factor_bins : Mapping[str, int | Sequence[float]] | None, default None
+        Mapping from continuous factor name to the number of bins or bin edges
+    auto_bin_method : Literal["uniform_width", "uniform_count", "clusters"], default "uniform_width"
+        Method for automatically determining the number of bins for continuous factors
+    exclude : Sequence[str] | None, default None
+        Filter metadata factors to exclude the specified factors, cannot be set with `include`
+    include : Sequence[str] | None, default None
+        Filter metadata factors to include the specified factors, cannot be set with `exclude`
+    """
+
+    def __init__(
+        self,
+        dataset: AnnotatedDataset[tuple[Any, Any, dict[str, Any]]],
+        *,
+        continuous_factor_bins: Mapping[str, int | Sequence[float]] | None = None,
+        auto_bin_method: Literal["uniform_width", "uniform_count", "clusters"] = "uniform_width",
+        exclude: Sequence[str] | None = None,
+        include: Sequence[str] | None = None,
+    ) -> None:
+        self._collated = False
+        self._merged = None
+        self._processed = False
+
+        self._dataset = dataset
+        self._continuous_factor_bins = dict(continuous_factor_bins) if continuous_factor_bins else {}
+        self._auto_bin_method = auto_bin_method
+
+        if exclude is not None and include is not None:
+            raise ValueError("Filters for `exclude` and `include` are mutually exclusive.")
+
+        self._exclude = set(exclude or ())
+        self._include = set(include or ())
+
+    @property
+    def targets(self) -> Targets:
+        self._collate()
+        return self._targets
+
+    @property
+    def raw(self) -> list[dict[str, Any]]:
+        self._collate()
+        return self._raw
+
+    @property
+    def exclude(self) -> set[str]:
+        return self._exclude
+
+    @exclude.setter
+    def exclude(self, value: Sequence[str]) -> None:
+        exclude = set(value)
+        if self._exclude != exclude:
+            self._exclude = exclude
+            self._include = set()
+            self._processed = False
+
+    @property
+    def include(self) -> set[str]:
+        return self._include
+
+    @include.setter
+    def include(self, value: Sequence[str]) -> None:
+        include = set(value)
+        if self._include != include:
+            self._include = include
+            self._exclude = set()
+            self._processed = False
+
+    @property
+    def continuous_factor_bins(self) -> Mapping[str, int | Sequence[float]]:
+        return self._continuous_factor_bins
+
+    @continuous_factor_bins.setter
+    def continuous_factor_bins(self, bins: Mapping[str, int | Sequence[float]]) -> None:
+        if self._continuous_factor_bins != bins:
+            self._continuous_factor_bins = dict(bins)
+            self._processed = False
+
+    @property
+    def auto_bin_method(self) -> str:
+        return self._auto_bin_method
+
+    @auto_bin_method.setter
+    def auto_bin_method(self, method: Literal["uniform_width", "uniform_count", "clusters"]) -> None:
+        if self._auto_bin_method != method:
+            self._auto_bin_method = method
+            self._processed = False
+
+    @property
+    def merged(self) -> dict[str, Any]:
+        self._merge()
+        return {} if self._merged is None else self._merged[0]
+
+    @property
+    def dropped_factors(self) -> dict[str, list[str]]:
+        self._merge()
+        return {} if self._merged is None else self._merged[1]
+
+    @property
+    def discrete_factor_names(self) -> list[str]:
+        self._process()
+        return self._discrete_factor_names
+
+    @property
+    def discrete_data(self) -> NDArray[np.int64]:
+        self._process()
+        return self._discrete_data
+
+    @property
+    def continuous_factor_names(self) -> list[str]:
+        self._process()
+        return self._continuous_factor_names
+
+    @property
+    def continuous_data(self) -> NDArray[np.float64]:
+        self._process()
+        return self._continuous_data
+
+    @property
+    def class_labels(self) -> NDArray[np.intp]:
+        self._collate()
+        return self._class_labels
+
+    @property
+    def class_names(self) -> list[str]:
+        self._collate()
+        return self._class_names
+
+    @property
+    def total_num_factors(self) -> int:
+        self._process()
+        return self._total_num_factors
+
+    @property
+    def image_indices(self) -> NDArray[np.intp]:
+        self._process()
+        return self._image_indices
+
+    def _collate(self, force: bool = False):
+        if self._collated and not force:
+            return
+
+        raw: list[dict[str, Any]] = []
+
+        labels = []
+        bboxes = []
+        scores = []
+        srcidx = []
+        is_od = None
+        for i in range(len(self._dataset)):
+            _, target, metadata = self._dataset[i]
+
+            raw.append(metadata)
+
+            if is_od_target := isinstance(target, ObjectDetectionTarget):
+                target_len = len(target.labels)
+                labels.extend(as_numpy(target.labels).tolist())
+                bboxes.extend(as_numpy(target.boxes).tolist())
+                scores.extend(as_numpy(target.scores).tolist())
+                srcidx.extend([i] * target_len)
+            elif isinstance(target, Array):
+                target_len = 1
+                labels.append(int(np.argmax(as_numpy(target))))
+                scores.append(target)
+            else:
+                raise TypeError("Encountered unsupported target type in dataset")
+
+            is_od = is_od_target if is_od is None else is_od
+            if is_od != is_od_target:
+                raise ValueError("Encountered unexpected target type in dataset")
+
+        labels = as_numpy(labels).astype(np.intp)
+        scores = as_numpy(scores).astype(np.float32)
+        bboxes = as_numpy(bboxes).astype(np.float32) if is_od else None
+        srcidx = as_numpy(srcidx).astype(np.intp) if is_od else None
+
+        self._targets = Targets(labels, scores, bboxes, srcidx)
+        self._raw = raw
+
+        index2label = self._dataset.metadata.get("index2label", {})
+        self._class_labels = self._targets.labels
+        self._class_names = [index2label.get(i, str(i)) for i in np.unique(self._class_labels)]
+        self._collated = True
+
+    def _merge(self, force: bool = False):
+        if self._merged is not None and not force:
+            return
+
+        targets_per_image = (
+            None if self.targets.source is None else np.unique(self.targets.source, return_counts=True)[1].tolist()
+        )
+        self._merged = merge(self.raw, return_dropped=True, ignore_lists=False, targets_per_image=targets_per_image)
+
+    def _validate(self) -> None:
+        # Check that metadata is a single, flattened dictionary with uniform array lengths
+        check_length = None
+        if self._targets.labels.ndim > 1:
+            raise ValueError(
+                f"Got class labels with {self._targets.labels.ndim}-dimensional "
+                f"shape {self._targets.labels.shape}, but expected a 1-dimensional array."
+            )
+        for v in self.merged.values():
+            if not isinstance(v, (list, tuple, np.ndarray)):
+                raise TypeError(
+                    "Metadata dictionary needs to be a single dictionary whose values "
+                    "are arraylike containing the metadata on a per image or per object basis."
+                )
+            else:
+                check_length = len(v) if check_length is None else check_length
+                if check_length != len(v):
+                    raise ValueError(
+                        "The lists/arrays in the metadata dict have varying lengths. "
+                        "Metadata requires them to be uniform in length."
+                    )
+        if len(self._class_labels) != check_length:
+            raise ValueError(
+                f"The length of the label array {len(self._class_labels)} is not the same as "
+                f"the length of the metadata arrays {check_length}."
+            )
+
+    def _process(self, force: bool = False) -> None:
+        if self._processed and not force:
+            return
+
+        # Trigger collate and merge if not yet done
+        self._collate()
+        self._merge()
+
+        # Validate the metadata dimensions
+        self._validate()
+
+        # Create image indices from targets
+        self._image_indices = np.arange(len(self.raw)) if self.targets.source is None else self.targets.source
+
+        # Include specified metadata keys
+        if self.include:
+            metadata = {i: self.merged[i] for i in self.include if i in self.merged}
+            continuous_factor_bins = (
+                {i: self.continuous_factor_bins[i] for i in self.include if i in self.continuous_factor_bins}
+                if self.continuous_factor_bins
+                else {}
+            )
+        else:
+            metadata = self.merged
+            continuous_factor_bins = dict(self.continuous_factor_bins) if self.continuous_factor_bins else {}
+            for k in self.exclude:
+                metadata.pop(k, None)
+                continuous_factor_bins.pop(k, None)
+
+        # Remove generated "_image_index" if present
+        if "_image_index" in metadata:
+            metadata.pop("_image_index", None)
+
+        # Bin according to user supplied bins
+        continuous_metadata = {}
+        discrete_metadata = {}
+        if continuous_factor_bins:
+            invalid_keys = set(continuous_factor_bins.keys()) - set(metadata.keys())
+            if invalid_keys:
+                raise KeyError(
+                    f"The keys - {invalid_keys} - are present in the `continuous_factor_bins` dictionary "
+                    "but are not keys in the `metadata` dictionary. Delete these keys from `continuous_factor_bins` "
+                    "or add corresponding entries to the `metadata` dictionary."
+                )
+            for factor, bins in continuous_factor_bins.items():
+                discrete_metadata[factor] = digitize_data(metadata[factor], bins)
+                continuous_metadata[factor] = metadata[factor]
+
+        # Determine category of the rest of the keys
+        remaining_keys = set(metadata.keys()) - set(continuous_metadata.keys())
+        for key in remaining_keys:
+            data = to_numpy(metadata[key])
+            if np.issubdtype(data.dtype, np.number):
+                result = is_continuous(data, self._image_indices)
+                if result:
+                    continuous_metadata[key] = data
+                unique_samples, ordinal_data = np.unique(data, return_inverse=True)
+                if unique_samples.size <= np.max([20, data.size * 0.01]):
+                    discrete_metadata[key] = ordinal_data
+                else:
+                    warnings.warn(
+                        f"A user defined binning was not provided for {key}. "
+                        f"Using the {self.auto_bin_method} method to discretize the data. "
+                        "It is recommended that the user rerun and supply the desired "
+                        "bins using the continuous_factor_bins parameter.",
+                        UserWarning,
+                    )
+                    discrete_metadata[key] = bin_data(data, self.auto_bin_method)
+            else:
+                _, discrete_metadata[key] = np.unique(data, return_inverse=True)
+
+        # Split out the dictionaries into the keys and values
+        self._discrete_factor_names = list(discrete_metadata.keys())
+        self._discrete_data = (
+            np.stack(list(discrete_metadata.values()), axis=-1, dtype=np.int64)
+            if discrete_metadata
+            else np.array([], dtype=np.int64)
+        )
+        self._continuous_factor_names = list(continuous_metadata.keys())
+        self._continuous_data = (
+            np.stack(list(continuous_metadata.values()), axis=-1, dtype=np.float64)
+            if continuous_metadata
+            else np.array([], dtype=np.float64)
+        )
+        self._total_num_factors = len(self._discrete_factor_names + self._continuous_factor_names) + 1
+        self._processed = True
--- /dev/null
+++ b/dataeval/utils/data/_selection.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+__all__ = []
+
+from enum import IntEnum
+from typing import Any, Generic, Iterator, Sequence, TypeVar
+
+from dataeval.typing import AnnotatedDataset, DatasetMetadata
+
+_TDatum = TypeVar("_TDatum")
+
+
+class SelectionStage(IntEnum):
+    STATE = 0
+    FILTER = 1
+    ORDER = 2
+
+
+class Selection(Generic[_TDatum]):
+    stage: SelectionStage
+
+    def __call__(self, dataset: Select[_TDatum]) -> None: ...
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.__dict__.items()])})"
+
+
+class Select(AnnotatedDataset[_TDatum]):
+    """
+    Wraps a dataset and applies selection criteria to it.
+
+    Parameters
+    ----------
+    dataset : Dataset
+        The dataset to wrap.
+    selections : Selection or list[Selection], optional
+        The selection criteria to apply to the dataset.
+
+    Examples
+    --------
+    >>> from dataeval.utils.data.selections import ClassFilter, Limit
+
+    >>> # Construct a sample dataset with size of 100 and class count of 10
+    >>> # Elements at index `idx` are returned as tuples:
+    >>> # - f"data_{idx}", one_hot_encoded(idx % class_count), {"id": idx}
+    >>> dataset = SampleDataset(size=100, class_count=10)

+    >>> # Apply a selection criteria to the dataset
+    >>> selections = [Limit(size=5), ClassFilter(classes=[0, 2])]
+    >>> selected_dataset = Select(dataset, selections=selections)

+    >>> # Iterate over the selected dataset
+    >>> for data, target, meta in selected_dataset:
+    ...     print(f"({data}, {np.argmax(target)}, {meta})")
+    (data_0, 0, {'id': 0})
+    (data_2, 2, {'id': 2})
+    (data_10, 0, {'id': 10})
+    (data_12, 2, {'id': 12})
+    (data_20, 0, {'id': 20})
+    """
+
+    _dataset: AnnotatedDataset[_TDatum]
+    _selection: list[int]
+    _selections: Sequence[Selection[_TDatum]]
+    _size_limit: int
+
+    def __init__(
+        self,
+        dataset: AnnotatedDataset[_TDatum],
+        selections: Selection[_TDatum] | list[Selection[_TDatum]] | None = None,
+    ) -> None:
+        self._dataset = dataset
+        self._size_limit = len(dataset)
+        self._selection = list(range(self._size_limit))
+        self._selections = self._sort_selections(selections)
+        self.__dict__.update(dataset.__dict__)
+
+        # Ensure metadata is populated correctly as DatasetMetadata TypedDict
+        _metadata = getattr(dataset, "metadata", {})
+        if "id" not in _metadata:
+            _metadata["id"] = dataset.__class__.__name__
+        self._metadata = DatasetMetadata(**_metadata)
+
+        if self._selections:
+            self._apply_selections()
+
+    @property
+    def metadata(self) -> DatasetMetadata:
+        return self._metadata
+
+    def __str__(self) -> str:
+        nt = "\n "
+        title = f"{self.__class__.__name__} Dataset"
+        sep = "-" * len(title)
+        selections = f"Selections: [{', '.join([str(s) for s in self._sort_selections(self._selections)])}]"
+        return f"{title}\n{sep}{nt}{selections}\n\n{self._dataset}"
+
+    def _sort_selections(self, selections: Selection[_TDatum] | Sequence[Selection[_TDatum]] | None) -> list[Selection]:
+        if not selections:
+            return []
+
+        selections = [selections] if isinstance(selections, Selection) else selections
+        grouped: dict[int, list[Selection]] = {}
+        for selection in selections:
+            grouped.setdefault(selection.stage, []).append(selection)
+        selection_list = [selection for category in sorted(grouped) for selection in grouped[category]]
+        return selection_list
+
+    def _apply_selections(self) -> None:
+        for selection in self._selections:
+            selection(self)
+        self._selection = self._selection[: self._size_limit]
+
+    def __getattr__(self, name: str, /) -> Any:
+        selfattr = getattr(self._dataset, name, None)
+        return selfattr if selfattr is not None else getattr(self._dataset, name)
+
+    def __getitem__(self, index: int) -> _TDatum:
+        return self._dataset[self._selection[index]]
+
+    def __iter__(self) -> Iterator[_TDatum]:
+        for i in range(len(self)):
+            yield self[i]
+
+    def __len__(self) -> int:
+        return len(self._selection)
--- a/dataeval/utils/dataset/split.py
+++ b/dataeval/utils/data/_split.py
@@ -3,8 +3,7 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from dataclasses import dataclass
-from typing import Any, Iterator, NamedTuple, Protocol
+from typing import Any, Iterator, Protocol
 
 import numpy as np
 from numpy.typing import NDArray
@@ -13,31 +12,8 @@ from sklearn.metrics import silhouette_score
 from sklearn.model_selection import GroupKFold, KFold, StratifiedGroupKFold, StratifiedKFold
 from sklearn.utils.multiclass import type_of_target
 
-from dataeval.output import Output, set_metadata
-
-
-class TrainValSplit(NamedTuple):
-    """Tuple containing train and validation indices"""
-
-    train: NDArray[np.intp]
-    val: NDArray[np.intp]
-
-
-@dataclass(frozen=True)
-class SplitDatasetOutput(Output):
-    """
-    Output class containing test indices and a list of TrainValSplits.
-
-    Attributes
-    ----------
-    test: NDArray[np.intp]
-        Indices for the test set
-    folds: list[TrainValSplit]
-        List where each index contains the indices for the train and validation splits
-    """
-
-    test: NDArray[np.intp]
-    folds: list[TrainValSplit]
+from dataeval.outputs._base import set_metadata
+from dataeval.outputs._utils import SplitDatasetOutput, TrainValSplit
 
 
 class KFoldSplitter(Protocol):
@@ -274,8 +250,7 @@ def get_group_ids(metadata: dict[str, Any], group_names: list[str], num_samples:
     for name, feature in features2group.items():
         if len(feature) != num_samples:
             raise ValueError(
-                f"Feature length does not match number of labels. "
-                f"Got {len(feature)} features and {num_samples} samples"
+                f"Feature length does not match number of labels. Got {len(feature)} features and {num_samples} samples"
             )
 
         if type_of_target(feature) == "continuous":
@@ -505,23 +480,22 @@ def split_dataset(
     if is_groupable(possible_groups, group_partitions):
         groups = possible_groups
 
-    test_indices: NDArray[np.intp]
     index = np.arange(label_length)
 
-
+    tvs = (
         single_split(index=index, labels=labels, split_frac=test_frac, groups=groups, stratified=stratify)
         if test_frac
-        else (index, np.array([], dtype=np.intp))
+        else TrainValSplit(index, np.array([], dtype=np.intp))
     )
 
-    tv_labels = labels[
-    tv_groups = groups[
+    tv_labels = labels[tvs.train]
+    tv_groups = groups[tvs.train] if groups is not None else None
 
     if num_folds == 1:
-        tv_splits = [single_split(
+        tv_splits = [single_split(tvs.train, tv_labels, val_frac, tv_groups, stratify)]
     else:
-        tv_splits = make_splits(
+        tv_splits = make_splits(tvs.train, tv_labels, num_folds, tv_groups, stratify)
 
-    folds: list[TrainValSplit] = [TrainValSplit(
+    folds: list[TrainValSplit] = [TrainValSplit(tvs.train[split.train], tvs.train[split.val]) for split in tv_splits]
 
-    return SplitDatasetOutput(
+    return SplitDatasetOutput(tvs.val, folds)
--- /dev/null
+++ b/dataeval/utils/data/_targets.py
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Iterator
+
+__all__ = []
+
+from dataclasses import dataclass
+
+import numpy as np
+from numpy.typing import NDArray
+
+
+def _len(arr: NDArray, dim: int) -> int:
+    return 0 if len(arr) == 0 else len(np.atleast_1d(arr) if dim == 1 else np.atleast_2d(arr))
+
+
+@dataclass(frozen=True)
+class Targets:
+    """
+    Dataclass defining targets for image classification or object detection.
+
+    Attributes
+    ----------
+    labels : NDArray[np.intp]
+        Labels (N,) for N images or objects
+    scores : NDArray[np.float32]
+        Probability scores (N,M) for N images of M classes or confidence score (N,) of objects
+    bboxes : NDArray[np.float32] | None
+        Bounding boxes (N,4) for N objects in (x0,y0,x1,y1) format
+    source : NDArray[np.intp] | None
+        Source image index (N,) for N objects
+    """
+
+    labels: NDArray[np.intp]
+    scores: NDArray[np.float32]
+    bboxes: NDArray[np.float32] | None
+    source: NDArray[np.intp] | None
+
+    def __post_init__(self) -> None:
+        if (self.bboxes is None) != (self.source is None):
+            raise ValueError("Either both bboxes and source must be provided or neither.")
+
+        labels = _len(self.labels, 1)
+        scores = _len(self.scores, 2) if self.bboxes is None else _len(self.scores, 1)
+        bboxes = labels if self.bboxes is None else _len(self.bboxes, 2)
+        source = labels if self.source is None else _len(self.source, 1)
+
+        if labels != scores or labels != bboxes or labels != source:
+            raise ValueError(
+                "Labels, scores, bboxes and source must be the same length (if provided).\n"
+                + f"  labels: {self.labels.shape}\n"
+                + f"  scores: {self.scores.shape}\n"
+                + f"  bboxes: {None if self.bboxes is None else self.bboxes.shape}\n"
+                + f"  source: {None if self.source is None else self.source.shape}\n"
+            )
+
+        if self.bboxes is not None and len(self.bboxes) > 0 and self.bboxes.shape[-1] != 4:
+            raise ValueError("Bounding boxes must be in (x0,y0,x1,y1) format.")
+
+    def __len__(self) -> int:
+        if self.source is None:
+            return len(self.labels)
+        else:
+            return len(np.unique(self.source))
+
+    def __getitem__(self, idx: int, /) -> Targets:
+        if self.source is None or self.bboxes is None:
+            return Targets(
+                np.atleast_1d(self.labels[idx]),
+                np.atleast_2d(self.scores[idx]),
+                None,
+                None,
+            )
+        else:
+            mask = np.where(self.source == idx, True, False)
+            return Targets(
+                np.atleast_1d(self.labels[mask]),
+                np.atleast_1d(self.scores[mask]),
+                np.atleast_2d(self.bboxes[mask]),
+                np.atleast_1d(self.source[mask]),
+            )
+
+    def __iter__(self) -> Iterator[Targets]:
+        for i in range(len(self.labels)) if self.source is None else np.unique(self.source):
+            yield self[i]