dataeval 0.86.9__py3-none-any.whl → 0.88.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/_version.py +2 -2
- dataeval/config.py +4 -19
- dataeval/data/_embeddings.py +78 -35
- dataeval/data/_images.py +41 -8
- dataeval/data/_metadata.py +348 -66
- dataeval/data/_selection.py +22 -7
- dataeval/data/_split.py +3 -2
- dataeval/data/selections/_classbalance.py +4 -3
- dataeval/data/selections/_classfilter.py +9 -8
- dataeval/data/selections/_indices.py +4 -3
- dataeval/data/selections/_prioritize.py +249 -29
- dataeval/data/selections/_reverse.py +1 -1
- dataeval/data/selections/_shuffle.py +5 -4
- dataeval/detectors/drift/_base.py +2 -1
- dataeval/detectors/drift/_mmd.py +2 -1
- dataeval/detectors/drift/_nml/_base.py +1 -1
- dataeval/detectors/drift/_nml/_chunk.py +2 -1
- dataeval/detectors/drift/_nml/_result.py +3 -2
- dataeval/detectors/drift/_nml/_thresholds.py +6 -5
- dataeval/detectors/drift/_uncertainty.py +2 -1
- dataeval/detectors/linters/duplicates.py +2 -1
- dataeval/detectors/linters/outliers.py +4 -3
- dataeval/detectors/ood/__init__.py +2 -1
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/detectors/ood/base.py +39 -1
- dataeval/detectors/ood/knn.py +95 -0
- dataeval/detectors/ood/mixin.py +2 -1
- dataeval/metadata/_utils.py +1 -1
- dataeval/metrics/bias/_balance.py +29 -22
- dataeval/metrics/bias/_diversity.py +4 -4
- dataeval/metrics/bias/_parity.py +2 -2
- dataeval/metrics/stats/_base.py +3 -29
- dataeval/metrics/stats/_boxratiostats.py +2 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -1
- dataeval/metrics/stats/_hashstats.py +21 -3
- dataeval/metrics/stats/_pixelstats.py +2 -1
- dataeval/metrics/stats/_visualstats.py +2 -1
- dataeval/outputs/_base.py +2 -3
- dataeval/outputs/_bias.py +2 -1
- dataeval/outputs/_estimators.py +1 -1
- dataeval/outputs/_linters.py +3 -3
- dataeval/outputs/_stats.py +3 -3
- dataeval/outputs/_utils.py +1 -1
- dataeval/outputs/_workflows.py +49 -31
- dataeval/typing.py +23 -9
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +3 -2
- dataeval/utils/_bin.py +9 -7
- dataeval/utils/_method.py +2 -3
- dataeval/utils/_multiprocessing.py +34 -0
- dataeval/utils/_plot.py +2 -1
- dataeval/utils/data/__init__.py +6 -5
- dataeval/utils/data/{metadata.py → _merge.py} +3 -2
- dataeval/utils/data/_validate.py +170 -0
- dataeval/utils/data/collate.py +2 -1
- dataeval/utils/torch/_internal.py +2 -1
- dataeval/utils/torch/trainer.py +1 -1
- dataeval/workflows/sufficiency.py +13 -9
- {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/METADATA +8 -21
- dataeval-0.88.0.dist-info/RECORD +105 -0
- dataeval/utils/data/_dataset.py +0 -246
- dataeval/utils/datasets/__init__.py +0 -21
- dataeval/utils/datasets/_antiuav.py +0 -189
- dataeval/utils/datasets/_base.py +0 -266
- dataeval/utils/datasets/_cifar10.py +0 -201
- dataeval/utils/datasets/_fileio.py +0 -142
- dataeval/utils/datasets/_milco.py +0 -197
- dataeval/utils/datasets/_mixin.py +0 -54
- dataeval/utils/datasets/_mnist.py +0 -202
- dataeval/utils/datasets/_seadrone.py +0 -512
- dataeval/utils/datasets/_ships.py +0 -144
- dataeval/utils/datasets/_types.py +0 -48
- dataeval/utils/datasets/_voc.py +0 -583
- dataeval-0.86.9.dist-info/RECORD +0 -115
- {dataeval-0.86.9.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
- /dataeval-0.86.9.dist-info/licenses/LICENSE.txt → /dataeval-0.88.0.dist-info/licenses/LICENSE +0 -0
dataeval/outputs/_linters.py
CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Generic, Mapping, Sequence, TypeVar, Union
+from typing import Generic, TypeAlias, TypeVar
 
 import pandas as pd
-from typing_extensions import TypeAlias
 
 from dataeval.outputs._base import Output
 from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
@@ -16,7 +16,7 @@ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
 TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
 
 IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
-OutlierStatsOutput: TypeAlias = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
+OutlierStatsOutput: TypeAlias = DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
 TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
 
 
dataeval/outputs/_stats.py
CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
 
 import numpy as np
 import polars as pl
 from numpy.typing import NDArray
-from typing_extensions import TypeAlias
 
 from dataeval.outputs._base import Output
 from dataeval.utils._plot import channel_histogram_plot, histogram_plot
@@ -16,7 +16,7 @@ from dataeval.utils._plot import channel_histogram_plot, histogram_plot
 if TYPE_CHECKING:
     from matplotlib.figure import Figure
 
-OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
+OptionalRange: TypeAlias = int | Iterable[int] | None
 
 SOURCE_INDEX = "source_index"
 OBJECT_COUNT = "object_count"
dataeval/outputs/_utils.py
CHANGED
dataeval/outputs/_workflows.py
CHANGED
@@ -4,8 +4,9 @@ __all__ = []
 
 import contextlib
 import warnings
-from dataclasses import dataclass
-from typing import Any, Iterable, Mapping, Sequence, cast
+from collections.abc import Iterable, Mapping, MutableMapping, Sequence
+from dataclasses import dataclass, field
+from typing import Any, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -76,10 +77,8 @@ def plot_measure(
     ax.set_title(f"{name} Sufficiency")
     ax.set_ylabel(f"{name}")
     ax.set_xlabel("Steps")
-
     # Plot measure over each step
     ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
-
     # Plot extrapolation
     ax.plot(
         projection,
@@ -92,7 +91,7 @@ def plot_measure(
     return fig
 
 
-def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
+def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.int64]:
     """
     Inverse function for f_out()
 
@@ -106,13 +105,27 @@ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
     Returns
     -------
     NDArray
-        Array of sample sizes
+        Sample size or -1 if unachievable for each data point
     """
-    n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
-    return np.asarray(n_i, dtype=np.uint64)
+    with np.errstate(invalid="ignore"):
+        n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
+    unachievable_targets = np.isnan(n_i) | np.any(n_i > np.iinfo(np.int64).max)
+    if any(unachievable_targets):
+        with np.printoptions(suppress=True):
+            warnings.warn(
+                "Number of samples could not be determined for target(s): "
+                f"""{
+                    np.array2string(
+                        1 - y_i[unachievable_targets], separator=", ", formatter={"float": lambda x: f"{x}"}
+                    )
+                }""",
+                UserWarning,
+            )
+        n_i[unachievable_targets] = -1
+    return np.asarray(n_i, dtype=np.int64)
 
 
-def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
+def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.int64]:
     """Inverse function for project_steps()
 
     Parameters
@@ -125,14 +138,13 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
     Returns
     -------
     NDArray
-        Array of sample sizes, or 0 if overflow
+        Samples required or -1 if unachievable for each target value
     """
     steps = f_inv_out(1 - np.array(targets), params)
-    steps[np.isnan(steps)] = 0
     return np.ceil(steps)
 
 
-def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
+def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.float64]:
     """
     Retrieves the inverse power curve coefficients for the line of best fit.
     Global minimization is done via basin hopping. More info on this algorithm
@@ -178,11 +190,11 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
 
 
 def get_curve_params(
-    measures: Mapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
-) -> Mapping[str, NDArray[Any]]:
+    averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
+) -> Mapping[str, NDArray[np.float64]]:
     """Calculates and aggregates parameters for both single and multi-class metrics"""
     output = {}
-    for name, measure in measures.items():
+    for name, measure in averaged_measures.items():
         measure = cast(np.ndarray, measure)
         if measure.ndim > 1:
             result = []
@@ -203,19 +215,25 @@ class SufficiencyOutput(Output):
     ----------
     steps : NDArray
         Array of sample sizes
-    measures : Mapping[str, NDArray]
-        Average of values for all runs observed for each sample size step for each measure
+    measures : dict[str, NDArray]
+        3D array [runs, substep, classes] of values for all runs observed for each sample size step for each measure
+    averaged_measures : dict[str, NDArray]
+        Average of values for all runs observed for each sample size step for each measure
     n_iter : int, default 1000
         Number of iterations to perform in the basin-hopping curve-fit process
     """
 
     steps: NDArray[np.uint32]
-    measures: Mapping[str, NDArray[np.float64]]
+    measures: Mapping[str, NDArray[Any]]
+    averaged_measures: MutableMapping[str, NDArray[Any]] = field(default_factory=lambda: {})
     n_iter: int = 1000
 
     def __post_init__(self) -> None:
+        if len(self.averaged_measures) == 0:
+            for metric, values in self.measures.items():
+                self.averaged_measures[metric] = np.asarray(np.mean(values, axis=0)).T
         c = len(self.steps)
-        for m, v in self.measures.items():
+        for m, v in self.averaged_measures.items():
             c_v = v.shape[1] if v.ndim > 1 else len(v)
             if c != c_v:
                 raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
@@ -226,7 +244,7 @@ class SufficiencyOutput(Output):
         if self._params is None:
             self._params = {}
         if self.n_iter not in self._params:
-            self._params[self.n_iter] = get_curve_params(self.measures, self.steps, self.n_iter)
+            self._params[self.n_iter] = get_curve_params(self.averaged_measures, self.steps, self.n_iter)
         return self._params[self.n_iter]
 
     @set_metadata
@@ -259,16 +277,16 @@ class SufficiencyOutput(Output):
             raise ValueError("'projection' must consist of numerical values")
 
         output = {}
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
+        for name, averaged_measures in self.averaged_measures.items():
+            if averaged_measures.ndim > 1:
                 result = []
-                for i in range(len(measures)):
+                for i in range(len(averaged_measures)):
                     projected = project_steps(self.params[name][i], projection)
                     result.append(projected)
                 output[name] = np.array(result)
             else:
                 output[name] = project_steps(self.params[name], projection)
-        proj = SufficiencyOutput(projection, output, self.n_iter)
+        proj = SufficiencyOutput(projection, measures=self.measures, averaged_measures=output, n_iter=self.n_iter)
        proj._params = self._params
         return proj
 
@@ -304,11 +322,11 @@ class SufficiencyOutput(Output):
         plots = []
 
         # Create a plot for each measure on one figure
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
-                if class_names is not None and len(measures) != len(class_names):
+        for name, averaged_measures in self.averaged_measures.items():
+            if averaged_measures.ndim > 1:
+                if class_names is not None and len(averaged_measures) != len(class_names):
                     raise IndexError("Class name count does not align with measures")
-                for i, measure in enumerate(measures):
+                for i, measure in enumerate(averaged_measures):
                     class_name = str(i) if class_names is None else class_names[i]
                     fig = plot_measure(
                         f"{name}_{class_name}",
@@ -320,7 +338,7 @@ class SufficiencyOutput(Output):
                     plots.append(fig)
 
             else:
-                fig = plot_measure(name, self.steps, measures, self.params[name], extrapolated)
+                fig = plot_measure(name, self.steps, averaged_measures, self.params[name], extrapolated)
                 plots.append(fig)
 
         return plots
@@ -350,10 +368,10 @@ class SufficiencyOutput(Output):
 
         for name, target in targets.items():
             tarray = as_numpy(target)
-            if name not in self.measures:
+            if name not in self.averaged_measures:
                 continue
 
-            measure = self.measures[name]
+            measure = self.averaged_measures[name]
             if measure.ndim > 1:
                 projection[name] = np.zeros((len(measure), len(tarray)))
                 for i in range(len(measure)):
dataeval/typing.py
CHANGED
@@ -3,11 +3,12 @@ Common type protocols used for interoperability with DataEval.
 """
 
 __all__ = [
+    "AnnotatedDataset",
     "Array",
     "ArrayLike",
     "Dataset",
-    "AnnotatedDataset",
     "DatasetMetadata",
+    "DeviceLike",
     "ImageClassificationDatum",
     "ImageClassificationDataset",
     "ObjectDetectionTarget",
@@ -20,18 +21,21 @@ __all__ = [
 ]
 
 
-import sys
-from typing import Any, Generic, Iterator, Mapping, Protocol, TypedDict, TypeVar, runtime_checkable
+from collections.abc import Iterator, Mapping
+from typing import (
+    Any,
+    Generic,
+    Protocol,
+    TypeAlias,
+    TypedDict,
+    TypeVar,
+    runtime_checkable,
+)
 
 import numpy.typing
+import torch
 from typing_extensions import NotRequired, ReadOnly, Required
 
-if sys.version_info >= (3, 10):
-    from typing import TypeAlias
-else:
-    from typing_extensions import TypeAlias
-
-
 ArrayLike: TypeAlias = numpy.typing.ArrayLike
 """
 Type alias for a `Union` representing objects that can be coerced into an array.
@@ -42,6 +46,16 @@ See Also
 """
 
 
+DeviceLike: TypeAlias = int | str | tuple[str, int] | torch.device
+"""
+Type alias for a `Union` representing types that specify a torch.device.
+
+See Also
+--------
+`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
+"""
+
+
 @runtime_checkable
 class Array(Protocol):
     """
dataeval/utils/__init__.py
CHANGED
dataeval/utils/_array.py
CHANGED
@@ -4,9 +4,10 @@ __all__ = []
 
 import logging
 import warnings
+from collections.abc import Iterable, Iterator
 from importlib import import_module
 from types import ModuleType
-from typing import Any, Iterable, Iterator, Literal, TypeVar, overload
+from typing import Any, Literal, TypeVar, overload
 
 import numpy as np
 import torch
@@ -79,7 +80,7 @@ def rescale_array(array: NDArray[_np_dtype]) -> NDArray[_np_dtype]: ...
 def rescale_array(array: torch.Tensor) -> torch.Tensor: ...
 def rescale_array(array: Array | NDArray[_np_dtype] | torch.Tensor) -> Array | NDArray[_np_dtype] | torch.Tensor:
     """Rescale an array to the range [0, 1]"""
-    if isinstance(array, (np.ndarray, torch.Tensor)):
+    if isinstance(array, np.ndarray | torch.Tensor):
         arr_min = array.min()
         arr_max = array.max()
         return (array - arr_min) / (arr_max - arr_min)
dataeval/utils/_bin.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import Any, Iterable
+from collections.abc import Iterable
+from typing import Any
 
 import numpy as np
 from numpy.typing import NDArray
@@ -94,7 +95,7 @@ def bin_data(data: NDArray[Any], bin_method: str) -> NDArray[np.int_]:
     return np.digitize(data, bin_edges)
 
 
-def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]]) -> bool:
+def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]] | None = None) -> bool:
     """
     Determines whether the data is continuous or discrete using the Wasserstein distance.
 
@@ -113,11 +114,12 @@ def is_continuous(data: NDArray[np.number[Any]], image_indices: NDArray[np.number[Any]]) -> bool:
     measured from a uniform distribution is greater or less than 0.054, respectively.
     """
     # Check if the metadata is image specific
-    _, data_indices_unsorted = np.unique(data, return_index=True)
-    if data_indices_unsorted.size == image_indices.size:
-        data_indices = np.sort(data_indices_unsorted)
-        if (data_indices == image_indices).all():
-            data = data[data_indices]
+    if image_indices is not None:
+        _, data_indices_unsorted = np.unique(data, return_index=True)
+        if data_indices_unsorted.size == image_indices.size:
+            data_indices = np.sort(data_indices_unsorted)
+            if (data_indices == image_indices).all():
+                data = data[data_indices]
 
     n_examples = len(data)
 
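Note: image_indices is now optional; the de-duplication branch above only runs when indices are supplied. A sketch of the simpler call form (this is an internal utility, so the import path may change; values are illustrative):

import numpy as np
from dataeval.utils._bin import is_continuous

rng = np.random.default_rng(0)
print(is_continuous(rng.normal(size=1000)))          # continuous draw, expected True
print(is_continuous(rng.integers(0, 5, size=1000)))  # 5 repeated values, expected False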
dataeval/utils/_method.py
CHANGED
dataeval/utils/_multiprocessing.py
ADDED
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Callable, Iterable, Iterator
+from multiprocessing import Pool
+from typing import Any, TypeVar
+
+_S = TypeVar("_S")
+_T = TypeVar("_T")
+
+
+class PoolWrapper:
+    """
+    Wraps `multiprocessing.Pool` to allow for easy switching between
+    multiprocessing and single-threaded execution.
+
+    This helps with debugging and profiling, as well as usage with Jupyter notebooks
+    in VS Code, which does not support subprocess debugging.
+    """
+
+    def __init__(self, processes: int | None) -> None:
+        self.pool = Pool(processes) if processes is None or processes > 1 else None
+
+    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
+        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
+
+    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        if self.pool is not None:
+            self.pool.close()
+            self.pool.join()
dataeval/utils/_plot.py
CHANGED
dataeval/utils/data/__init__.py
CHANGED
@@ -1,11 +1,12 @@
 """Provides access to common Computer Vision datasets."""
 
-from dataeval.utils.data import collate, metadata
-from dataeval.utils.data._dataset import to_image_classification_dataset, to_object_detection_dataset
+from dataeval.utils.data import collate
+from dataeval.utils.data._merge import flatten, merge
+from dataeval.utils.data._validate import validate_dataset
 
 __all__ = [
     "collate",
-    "metadata",
-    "to_image_classification_dataset",
-    "to_object_detection_dataset",
+    "flatten",
+    "merge",
+    "validate_dataset",
 ]
dataeval/utils/data/{metadata.py → _merge.py}
RENAMED
@@ -7,8 +7,9 @@ from __future__ import annotations
 __all__ = ["merge", "flatten"]
 
 import warnings
+from collections.abc import Iterable, Mapping, Sequence
 from enum import Enum
-from typing import Any, Iterable, Literal, Mapping, Sequence, overload
+from typing import Any, Literal, overload
 
 import numpy as np
 from numpy.typing import NDArray
@@ -132,7 +133,7 @@ def _flatten_dict_inner(
         if isinstance(v, dict):
             fd, size = _flatten_dict_inner(v, dropped, new_keys, size=size, nested=nested)
             items.update(fd)
-        elif isinstance(v, (list, tuple)):
+        elif isinstance(v, list | tuple):
             if nested:
                 dropped.setdefault(parent_keys + (k,), set()).add(DropReason.NESTED_LIST)
             elif size is not None and size != len(v):
dataeval/utils/data/_validate.py
ADDED
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Sequence, Sized
+from typing import Any, Literal
+
+from dataeval.config import EPSILON
+from dataeval.typing import Array, ObjectDetectionTarget
+from dataeval.utils._array import as_numpy
+
+
+class ValidationMessages:
+    DATASET_SIZED = "Dataset must be sized."
+    DATASET_INDEXABLE = "Dataset must be indexable."
+    DATASET_NONEMPTY = "Dataset must be non-empty."
+    DATASET_METADATA = "Dataset must have a 'metadata' attribute."
+    DATASET_METADATA_TYPE = "Dataset metadata must be a dictionary."
+    DATASET_METADATA_FORMAT = "Dataset metadata must contain an 'id' key."
+    DATUM_TYPE = "Dataset datum must be a tuple."
+    DATUM_FORMAT = "Dataset datum must contain 3 elements: image, target, metadata."
+    DATUM_IMAGE_TYPE = "Images must be 3-dimensional arrays."
+    DATUM_IMAGE_FORMAT = "Images must be in CHW format."
+    DATUM_TARGET_IC_TYPE = "ImageClassificationDataset targets must be one-dimensional arrays."
+    DATUM_TARGET_IC_FORMAT = "ImageClassificationDataset targets must be one-hot encoded or pseudo-probabilities."
+    DATUM_TARGET_OD_TYPE = "ObjectDetectionDataset targets must be have 'boxes', 'labels' and 'scores'."
+    DATUM_TARGET_OD_LABELS_TYPE = "ObjectDetectionTarget labels must be one-dimensional (N,) arrays."
+    DATUM_TARGET_OD_BOXES_TYPE = "ObjectDetectionTarget boxes must be two-dimensional (N, 4) arrays in xxyy format."
+    DATUM_TARGET_OD_SCORES_TYPE = "ObjectDetectionTarget scores must be one (N,) or two-dimensional (N, M) arrays."
+    DATUM_TARGET_TYPE = "Target is not a valid ImageClassification or ObjectDetection target type."
+    DATUM_METADATA_TYPE = "Datum metadata must be a dictionary."
+    DATUM_METADATA_FORMAT = "Datum metadata must contain an 'id' key."
+
+
+def _validate_dataset_type(dataset: Any) -> list[str]:
+    issues = []
+    is_sized = isinstance(dataset, Sized)
+    is_indexable = hasattr(dataset, "__getitem__")
+    if not is_sized:
+        issues.append(ValidationMessages.DATASET_SIZED)
+    if not is_indexable:
+        issues.append(ValidationMessages.DATASET_INDEXABLE)
+    if is_sized and len(dataset) == 0:
+        issues.append(ValidationMessages.DATASET_NONEMPTY)
+    return issues
+
+
+def _validate_dataset_metadata(dataset: Any) -> list[str]:
+    issues = []
+    if not hasattr(dataset, "metadata"):
+        issues.append(ValidationMessages.DATASET_METADATA)
+    metadata = getattr(dataset, "metadata", None)
+    if not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATASET_METADATA_TYPE)
+    if not isinstance(metadata, dict) or "id" not in metadata:
+        issues.append(ValidationMessages.DATASET_METADATA_FORMAT)
+    return issues
+
+
+def _validate_datum_type(datum: Any) -> list[str]:
+    issues = []
+    if not isinstance(datum, tuple):
+        issues.append(ValidationMessages.DATUM_TYPE)
+    if datum is None or isinstance(datum, Sized) and len(datum) != 3:
+        issues.append(ValidationMessages.DATUM_FORMAT)
+    return issues
+
+
+def _validate_datum_image(image: Any) -> list[str]:
+    issues = []
+    if not isinstance(image, Array) or len(image.shape) != 3:
+        issues.append(ValidationMessages.DATUM_IMAGE_TYPE)
+    if (
+        not isinstance(image, Array)
+        or len(image.shape) == 3
+        and (image.shape[0] > image.shape[1] or image.shape[0] > image.shape[2])
+    ):
+        issues.append(ValidationMessages.DATUM_IMAGE_FORMAT)
+    return issues
+
+
+def _validate_datum_target_ic(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, Array) or len(target.shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_TYPE)
+    if target is None or sum(target) > 1 + EPSILON or sum(target) < 1 - EPSILON:
+        issues.append(ValidationMessages.DATUM_TARGET_IC_FORMAT)
+    return issues
+
+
+def _validate_datum_target_od(target: Any) -> list[str]:
+    issues = []
+    if not isinstance(target, ObjectDetectionTarget):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_TYPE)
+    od_target: ObjectDetectionTarget | None = target if isinstance(target, ObjectDetectionTarget) else None
+    if od_target is None or len(as_numpy(od_target.labels).shape) != 1:
+        issues.append(ValidationMessages.DATUM_TARGET_OD_LABELS_TYPE)
+    if (
+        od_target is None
+        or len(as_numpy(od_target.boxes).shape) != 2
+        or (len(as_numpy(od_target.boxes).shape) == 2 and as_numpy(od_target.boxes).shape[1] != 4)
+    ):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_BOXES_TYPE)
+    if od_target is None or len(as_numpy(od_target.scores).shape) not in (1, 2):
+        issues.append(ValidationMessages.DATUM_TARGET_OD_SCORES_TYPE)
+    return issues
+
+
+def _detect_target_type(target: Any) -> Literal["ic", "od", "auto"]:
+    if isinstance(target, Array):
+        return "ic"
+    if isinstance(target, ObjectDetectionTarget):
+        return "od"
+    return "auto"
+
+
+def _validate_datum_target(target: Any, target_type: Literal["ic", "od", "auto"]) -> list[str]:
+    issues = []
+    target_type = _detect_target_type(target) if target_type == "auto" else target_type
+    if target_type == "ic":
+        issues.extend(_validate_datum_target_ic(target))
+    elif target_type == "od":
+        issues.extend(_validate_datum_target_od(target))
+    else:
+        issues.append(ValidationMessages.DATUM_TARGET_TYPE)
+    return issues
+
+
+def _validate_datum_metadata(metadata: Any) -> list[str]:
+    issues = []
+    if metadata is None or not isinstance(metadata, dict):
+        issues.append(ValidationMessages.DATUM_METADATA_TYPE)
+    if metadata is None or isinstance(metadata, dict) and "id" not in metadata:
+        issues.append(ValidationMessages.DATUM_METADATA_FORMAT)
+    return issues
+
+
+def validate_dataset(dataset: Any, dataset_type: Literal["ic", "od", "auto"] = "auto") -> None:
+    """
+    Validate a dataset for compliance with MAITE protocol.
+
+    Parameters
+    ----------
+    dataset: Any
+        Dataset to validate.
+    dataset_type: "ic", "od", or "auto", default "auto"
+        Dataset type, if known.
+
+    Raises
+    ------
+    ValueError
+        Raises exception if dataset is invalid with a list of validation issues.
+    """
+    issues = []
+    issues.extend(_validate_dataset_type(dataset))
+    datum = None if issues else dataset[0]  # type: ignore
+    issues.extend(_validate_dataset_metadata(dataset))
+    issues.extend(_validate_datum_type(datum))
+
+    is_seq = isinstance(datum, Sequence)
+    datum_len = len(datum) if is_seq else 0
+    image = datum[0] if is_seq and datum_len > 0 else None
+    target = datum[1] if is_seq and datum_len > 1 else None
+    metadata = datum[2] if is_seq and datum_len > 2 else None
+    issues.extend(_validate_datum_image(image))
+    issues.extend(_validate_datum_target(target, dataset_type))
+    issues.extend(_validate_datum_metadata(metadata))
+
+    if issues:
+        raise ValueError("Dataset validation issues found:\n - " + "\n - ".join(issues))
dataeval/utils/data/collate.py
CHANGED
@@ -6,7 +6,8 @@ from __future__ import annotations
 
 __all__ = ["list_collate_fn", "numpy_collate_fn", "torch_collate_fn"]
 
-from typing import Any, Iterable, Sequence, TypeVar
+from collections.abc import Iterable, Sequence
+from typing import Any, TypeVar
 
 import numpy as np
 import torch
|