dataeval 0.87.0__py3-none-any.whl → 0.88.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/_log.py +1 -1
- dataeval/_version.py +2 -2
- dataeval/data/_embeddings.py +78 -35
- dataeval/data/_images.py +41 -8
- dataeval/data/_metadata.py +294 -41
- dataeval/data/_selection.py +22 -7
- dataeval/data/_split.py +2 -1
- dataeval/data/selections/_classfilter.py +4 -3
- dataeval/data/selections/_indices.py +2 -1
- dataeval/data/selections/_shuffle.py +3 -2
- dataeval/detectors/drift/_base.py +2 -1
- dataeval/detectors/drift/_mmd.py +2 -1
- dataeval/detectors/drift/_nml/_base.py +1 -1
- dataeval/detectors/drift/_nml/_chunk.py +2 -1
- dataeval/detectors/drift/_nml/_result.py +3 -2
- dataeval/detectors/drift/_nml/_thresholds.py +6 -5
- dataeval/detectors/drift/_uncertainty.py +2 -1
- dataeval/detectors/linters/duplicates.py +2 -1
- dataeval/detectors/linters/outliers.py +4 -3
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/detectors/ood/base.py +2 -1
- dataeval/detectors/ood/mixin.py +2 -1
- dataeval/metadata/_utils.py +1 -1
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/stats/_base.py +3 -29
- dataeval/metrics/stats/_boxratiostats.py +2 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -1
- dataeval/metrics/stats/_hashstats.py +2 -1
- dataeval/metrics/stats/_pixelstats.py +2 -1
- dataeval/metrics/stats/_visualstats.py +2 -1
- dataeval/outputs/_base.py +2 -3
- dataeval/outputs/_bias.py +2 -1
- dataeval/outputs/_estimators.py +1 -1
- dataeval/outputs/_linters.py +3 -3
- dataeval/outputs/_stats.py +3 -3
- dataeval/outputs/_utils.py +1 -1
- dataeval/outputs/_workflows.py +29 -24
- dataeval/typing.py +11 -9
- dataeval/utils/_array.py +3 -2
- dataeval/utils/_bin.py +2 -1
- dataeval/utils/_method.py +2 -3
- dataeval/utils/_multiprocessing.py +34 -0
- dataeval/utils/_plot.py +2 -1
- dataeval/utils/data/__init__.py +4 -5
- dataeval/utils/data/{metadata.py → _merge.py} +3 -2
- dataeval/utils/data/_validate.py +2 -1
- dataeval/utils/data/collate.py +2 -1
- dataeval/utils/torch/_internal.py +2 -1
- dataeval/utils/torch/trainer.py +1 -1
- dataeval/workflows/sufficiency.py +13 -9
- {dataeval-0.87.0.dist-info → dataeval-0.88.0.dist-info}/METADATA +4 -5
- dataeval-0.88.0.dist-info/RECORD +105 -0
- dataeval/utils/data/_dataset.py +0 -253
- dataeval-0.87.0.dist-info/RECORD +0 -105
- {dataeval-0.87.0.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
- {dataeval-0.87.0.dist-info → dataeval-0.88.0.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,8 @@ import copy
|
|
13
13
|
import logging
|
14
14
|
import warnings
|
15
15
|
from abc import ABC, abstractmethod
|
16
|
-
from
|
16
|
+
from collections.abc import Sequence
|
17
|
+
from typing import Any, Generic, Literal, TypeVar, cast
|
17
18
|
|
18
19
|
import pandas as pd
|
19
20
|
from pandas import Index, Period
|
@@ -11,7 +11,8 @@ from __future__ import annotations
|
|
11
11
|
|
12
12
|
import copy
|
13
13
|
from abc import ABC, abstractmethod
|
14
|
-
from
|
14
|
+
from collections.abc import Sequence
|
15
|
+
from typing import NamedTuple
|
15
16
|
|
16
17
|
import pandas as pd
|
17
18
|
from typing_extensions import Self
|
@@ -52,7 +53,7 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
|
|
52
53
|
|
53
54
|
def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
|
54
55
|
"""Returns filtered result metric data."""
|
55
|
-
if metrics and not isinstance(metrics,
|
56
|
+
if metrics and not isinstance(metrics, str | Sequence):
|
56
57
|
raise ValueError("metrics value provided is not a valid metric or sequence of metrics")
|
57
58
|
if isinstance(metrics, str):
|
58
59
|
metrics = [metrics]
|
@@ -9,7 +9,8 @@ from __future__ import annotations
|
|
9
9
|
|
10
10
|
import logging
|
11
11
|
from abc import ABC, abstractmethod
|
12
|
-
from
|
12
|
+
from collections.abc import Callable
|
13
|
+
from typing import Any, ClassVar
|
13
14
|
|
14
15
|
import numpy as np
|
15
16
|
|
@@ -169,10 +170,10 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
|
|
169
170
|
|
170
171
|
@staticmethod
|
171
172
|
def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
|
172
|
-
if lower is not None and not isinstance(lower,
|
173
|
+
if lower is not None and not isinstance(lower, float | int) or isinstance(lower, bool):
|
173
174
|
raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
|
174
175
|
|
175
|
-
if upper is not None and not isinstance(upper,
|
176
|
+
if upper is not None and not isinstance(upper, float | int) or isinstance(upper, bool):
|
176
177
|
raise ValueError(f"expected type of 'upper' to be 'float', 'int' or None but got '{type(upper).__name__}'")
|
177
178
|
|
178
179
|
# explicit None check is required due to special interpretation of the value 0.0 as False
|
@@ -244,7 +245,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
|
|
244
245
|
) -> None:
|
245
246
|
if (
|
246
247
|
std_lower_multiplier is not None
|
247
|
-
and not isinstance(std_lower_multiplier,
|
248
|
+
and not isinstance(std_lower_multiplier, float | int)
|
248
249
|
or isinstance(std_lower_multiplier, bool)
|
249
250
|
):
|
250
251
|
raise ValueError(
|
@@ -257,7 +258,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
|
|
257
258
|
|
258
259
|
if (
|
259
260
|
std_upper_multiplier is not None
|
260
|
-
and not isinstance(std_upper_multiplier,
|
261
|
+
and not isinstance(std_upper_multiplier, float | int)
|
261
262
|
or isinstance(std_upper_multiplier, bool)
|
262
263
|
):
|
263
264
|
raise ValueError(
|
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from
|
5
|
+
from collections.abc import Sequence
|
6
|
+
from typing import Any, Literal, overload
|
6
7
|
|
7
8
|
import numpy as np
|
8
9
|
from numpy.typing import NDArray
|
@@ -201,7 +202,7 @@ class Outliers:
|
|
201
202
|
>>> results.issues[1]
|
202
203
|
{}
|
203
204
|
"""
|
204
|
-
if isinstance(stats,
|
205
|
+
if isinstance(stats, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput):
|
205
206
|
return OutliersOutput(self._get_outliers(stats.data()))
|
206
207
|
|
207
208
|
if not isinstance(stats, Sequence):
|
@@ -212,7 +213,7 @@ class Outliers:
|
|
212
213
|
stats_map: dict[type, list[int]] = {}
|
213
214
|
for i, stats_output in enumerate(stats):
|
214
215
|
if not isinstance(
|
215
|
-
stats_output,
|
216
|
+
stats_output, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
|
216
217
|
):
|
217
218
|
raise TypeError(
|
218
219
|
"Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
|
dataeval/detectors/ood/ae.py
CHANGED
dataeval/detectors/ood/base.py
CHANGED
dataeval/detectors/ood/mixin.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
-
from
|
6
|
+
from collections.abc import Callable
|
7
|
+
from typing import Generic, Literal, TypeVar
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy.typing import NDArray
|
dataeval/metadata/_utils.py
CHANGED
@@ -16,7 +16,7 @@ from dataeval.utils._bin import get_counts
|
|
16
16
|
|
17
17
|
|
18
18
|
def _validate_num_neighbors(num_neighbors: int) -> int:
|
19
|
-
if not isinstance(num_neighbors,
|
19
|
+
if not isinstance(num_neighbors, int | float):
|
20
20
|
raise TypeError(
|
21
21
|
f"Variable {num_neighbors} is not real-valued numeric type."
|
22
22
|
"num_neighbors should be an int, greater than 0 and less than"
|
dataeval/metrics/stats/_base.py
CHANGED
@@ -6,11 +6,11 @@ import math
|
|
6
6
|
import re
|
7
7
|
import warnings
|
8
8
|
from collections import ChainMap
|
9
|
+
from collections.abc import Callable, Iterable, Iterator, Sequence
|
9
10
|
from copy import deepcopy
|
10
11
|
from dataclasses import dataclass
|
11
12
|
from functools import partial
|
12
|
-
from
|
13
|
-
from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
|
13
|
+
from typing import Any, Generic, TypeVar
|
14
14
|
|
15
15
|
import numpy as np
|
16
16
|
from numpy.typing import NDArray
|
@@ -21,14 +21,12 @@ from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
|
|
21
21
|
from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
|
22
22
|
from dataeval.utils._array import as_numpy, to_numpy
|
23
23
|
from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
|
24
|
+
from dataeval.utils._multiprocessing import PoolWrapper
|
24
25
|
|
25
26
|
DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
|
26
27
|
|
27
28
|
TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
|
28
29
|
|
29
|
-
_S = TypeVar("_S")
|
30
|
-
_T = TypeVar("_T")
|
31
|
-
|
32
30
|
|
33
31
|
@dataclass
|
34
32
|
class BoundingBox:
|
@@ -67,30 +65,6 @@ class BoundingBox:
|
|
67
65
|
return x0_int, y0_int, x1_int, y1_int
|
68
66
|
|
69
67
|
|
70
|
-
class PoolWrapper:
|
71
|
-
"""
|
72
|
-
Wraps `multiprocessing.Pool` to allow for easy switching between
|
73
|
-
multiprocessing and single-threaded execution.
|
74
|
-
|
75
|
-
This helps with debugging and profiling, as well as usage with Jupyter notebooks
|
76
|
-
in VS Code, which does not support subprocess debugging.
|
77
|
-
"""
|
78
|
-
|
79
|
-
def __init__(self, processes: int | None) -> None:
|
80
|
-
self.pool = Pool(processes) if processes is None or processes > 1 else None
|
81
|
-
|
82
|
-
def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
|
83
|
-
return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
|
84
|
-
|
85
|
-
def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
|
86
|
-
return self
|
87
|
-
|
88
|
-
def __exit__(self, *args: Any) -> None:
|
89
|
-
if self.pool is not None:
|
90
|
-
self.pool.close()
|
91
|
-
self.pool.join()
|
92
|
-
|
93
|
-
|
94
68
|
class StatsProcessor(Generic[TStatsOutput]):
|
95
69
|
output_class: type[TStatsOutput]
|
96
70
|
cache_keys: set[str] = set()
|
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import copy
|
6
|
-
from
|
6
|
+
from collections.abc import Callable
|
7
|
+
from typing import Any, Generic, TypeVar, cast
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy.typing import NDArray
|
dataeval/outputs/_base.py
CHANGED
@@ -4,14 +4,13 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import inspect
|
6
6
|
import logging
|
7
|
-
from collections.abc import Collection, Mapping, Sequence
|
7
|
+
from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
|
8
8
|
from dataclasses import dataclass
|
9
9
|
from datetime import datetime, timezone
|
10
10
|
from functools import partial, wraps
|
11
|
-
from typing import Any,
|
11
|
+
from typing import Any, Generic, ParamSpec, TypeVar, overload
|
12
12
|
|
13
13
|
import numpy as np
|
14
|
-
from typing_extensions import ParamSpec
|
15
14
|
|
16
15
|
from dataeval import __version__
|
17
16
|
|
dataeval/outputs/_bias.py
CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import contextlib
|
6
|
+
from collections.abc import Mapping, Sequence
|
6
7
|
from dataclasses import asdict, dataclass
|
7
|
-
from typing import Any,
|
8
|
+
from typing import Any, TypeVar
|
8
9
|
|
9
10
|
import numpy as np
|
10
11
|
import pandas as pd
|
dataeval/outputs/_estimators.py
CHANGED
dataeval/outputs/_linters.py
CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
from collections.abc import Mapping, Sequence
|
5
6
|
from dataclasses import dataclass
|
6
|
-
from typing import Generic,
|
7
|
+
from typing import Generic, TypeAlias, TypeVar
|
7
8
|
|
8
9
|
import pandas as pd
|
9
|
-
from typing_extensions import TypeAlias
|
10
10
|
|
11
11
|
from dataeval.outputs._base import Output
|
12
12
|
from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
|
@@ -16,7 +16,7 @@ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
|
|
16
16
|
TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
|
17
17
|
|
18
18
|
IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
|
19
|
-
OutlierStatsOutput: TypeAlias =
|
19
|
+
OutlierStatsOutput: TypeAlias = DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
|
20
20
|
TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
|
21
21
|
|
22
22
|
|
dataeval/outputs/_stats.py
CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
from collections.abc import Iterable, Mapping, Sequence
|
5
6
|
from dataclasses import dataclass
|
6
|
-
from typing import TYPE_CHECKING, Any,
|
7
|
+
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
import polars as pl
|
10
11
|
from numpy.typing import NDArray
|
11
|
-
from typing_extensions import TypeAlias
|
12
12
|
|
13
13
|
from dataeval.outputs._base import Output
|
14
14
|
from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
@@ -16,7 +16,7 @@ from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
|
16
16
|
if TYPE_CHECKING:
|
17
17
|
from matplotlib.figure import Figure
|
18
18
|
|
19
|
-
OptionalRange: TypeAlias =
|
19
|
+
OptionalRange: TypeAlias = int | Iterable[int] | None
|
20
20
|
|
21
21
|
SOURCE_INDEX = "source_index"
|
22
22
|
OBJECT_COUNT = "object_count"
|
dataeval/outputs/_utils.py
CHANGED
dataeval/outputs/_workflows.py
CHANGED
@@ -4,8 +4,9 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import contextlib
|
6
6
|
import warnings
|
7
|
-
from
|
8
|
-
from
|
7
|
+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
|
8
|
+
from dataclasses import dataclass, field
|
9
|
+
from typing import Any, cast
|
9
10
|
|
10
11
|
import numpy as np
|
11
12
|
from numpy.typing import NDArray
|
@@ -76,10 +77,8 @@ def plot_measure(
|
|
76
77
|
ax.set_title(f"{name} Sufficiency")
|
77
78
|
ax.set_ylabel(f"{name}")
|
78
79
|
ax.set_xlabel("Steps")
|
79
|
-
|
80
80
|
# Plot measure over each step
|
81
81
|
ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
|
82
|
-
|
83
82
|
# Plot extrapolation
|
84
83
|
ax.plot(
|
85
84
|
projection,
|
@@ -145,7 +144,7 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np
|
|
145
144
|
return np.ceil(steps)
|
146
145
|
|
147
146
|
|
148
|
-
def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[
|
147
|
+
def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.float64]:
|
149
148
|
"""
|
150
149
|
Retrieves the inverse power curve coefficients for the line of best fit.
|
151
150
|
Global minimization is done via basin hopping. More info on this algorithm
|
@@ -191,11 +190,11 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
|
|
191
190
|
|
192
191
|
|
193
192
|
def get_curve_params(
|
194
|
-
|
195
|
-
) -> Mapping[str, NDArray[
|
193
|
+
averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
|
194
|
+
) -> Mapping[str, NDArray[np.float64]]:
|
196
195
|
"""Calculates and aggregates parameters for both single and multi-class metrics"""
|
197
196
|
output = {}
|
198
|
-
for name, measure in
|
197
|
+
for name, measure in averaged_measures.items():
|
199
198
|
measure = cast(np.ndarray, measure)
|
200
199
|
if measure.ndim > 1:
|
201
200
|
result = []
|
@@ -216,19 +215,25 @@ class SufficiencyOutput(Output):
|
|
216
215
|
----------
|
217
216
|
steps : NDArray
|
218
217
|
Array of sample sizes
|
219
|
-
measures :
|
220
|
-
|
218
|
+
measures : dict[str, NDArray]
|
219
|
+
3D array [runs, substep, classes] of values for all runs observed for each sample size step for each measure
|
220
|
+
averaged_measures : dict[str, NDArray]
|
221
|
+
Average of values for all runs observed for each sample size step for each measure
|
221
222
|
n_iter : int, default 1000
|
222
223
|
Number of iterations to perform in the basin-hopping curve-fit process
|
223
224
|
"""
|
224
225
|
|
225
226
|
steps: NDArray[np.uint32]
|
226
|
-
measures: Mapping[str, NDArray[
|
227
|
+
measures: Mapping[str, NDArray[Any]]
|
228
|
+
averaged_measures: MutableMapping[str, NDArray[Any]] = field(default_factory=lambda: {})
|
227
229
|
n_iter: int = 1000
|
228
230
|
|
229
231
|
def __post_init__(self) -> None:
|
232
|
+
if len(self.averaged_measures) == 0:
|
233
|
+
for metric, values in self.measures.items():
|
234
|
+
self.averaged_measures[metric] = np.asarray(np.mean(values, axis=0)).T
|
230
235
|
c = len(self.steps)
|
231
|
-
for m, v in self.
|
236
|
+
for m, v in self.averaged_measures.items():
|
232
237
|
c_v = v.shape[1] if v.ndim > 1 else len(v)
|
233
238
|
if c != c_v:
|
234
239
|
raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
|
@@ -239,7 +244,7 @@ class SufficiencyOutput(Output):
|
|
239
244
|
if self._params is None:
|
240
245
|
self._params = {}
|
241
246
|
if self.n_iter not in self._params:
|
242
|
-
self._params[self.n_iter] = get_curve_params(self.
|
247
|
+
self._params[self.n_iter] = get_curve_params(self.averaged_measures, self.steps, self.n_iter)
|
243
248
|
return self._params[self.n_iter]
|
244
249
|
|
245
250
|
@set_metadata
|
@@ -272,16 +277,16 @@ class SufficiencyOutput(Output):
|
|
272
277
|
raise ValueError("'projection' must consist of numerical values")
|
273
278
|
|
274
279
|
output = {}
|
275
|
-
for name,
|
276
|
-
if
|
280
|
+
for name, averaged_measures in self.averaged_measures.items():
|
281
|
+
if averaged_measures.ndim > 1:
|
277
282
|
result = []
|
278
|
-
for i in range(len(
|
283
|
+
for i in range(len(averaged_measures)):
|
279
284
|
projected = project_steps(self.params[name][i], projection)
|
280
285
|
result.append(projected)
|
281
286
|
output[name] = np.array(result)
|
282
287
|
else:
|
283
288
|
output[name] = project_steps(self.params[name], projection)
|
284
|
-
proj = SufficiencyOutput(projection, output, self.n_iter)
|
289
|
+
proj = SufficiencyOutput(projection, measures=self.measures, averaged_measures=output, n_iter=self.n_iter)
|
285
290
|
proj._params = self._params
|
286
291
|
return proj
|
287
292
|
|
@@ -317,11 +322,11 @@ class SufficiencyOutput(Output):
|
|
317
322
|
plots = []
|
318
323
|
|
319
324
|
# Create a plot for each measure on one figure
|
320
|
-
for name,
|
321
|
-
if
|
322
|
-
if class_names is not None and len(
|
325
|
+
for name, averaged_measures in self.averaged_measures.items():
|
326
|
+
if averaged_measures.ndim > 1:
|
327
|
+
if class_names is not None and len(averaged_measures) != len(class_names):
|
323
328
|
raise IndexError("Class name count does not align with measures")
|
324
|
-
for i, measure in enumerate(
|
329
|
+
for i, measure in enumerate(averaged_measures):
|
325
330
|
class_name = str(i) if class_names is None else class_names[i]
|
326
331
|
fig = plot_measure(
|
327
332
|
f"{name}_{class_name}",
|
@@ -333,7 +338,7 @@ class SufficiencyOutput(Output):
|
|
333
338
|
plots.append(fig)
|
334
339
|
|
335
340
|
else:
|
336
|
-
fig = plot_measure(name, self.steps,
|
341
|
+
fig = plot_measure(name, self.steps, averaged_measures, self.params[name], extrapolated)
|
337
342
|
plots.append(fig)
|
338
343
|
|
339
344
|
return plots
|
@@ -363,10 +368,10 @@ class SufficiencyOutput(Output):
|
|
363
368
|
|
364
369
|
for name, target in targets.items():
|
365
370
|
tarray = as_numpy(target)
|
366
|
-
if name not in self.
|
371
|
+
if name not in self.averaged_measures:
|
367
372
|
continue
|
368
373
|
|
369
|
-
measure = self.
|
374
|
+
measure = self.averaged_measures[name]
|
370
375
|
if measure.ndim > 1:
|
371
376
|
projection[name] = np.zeros((len(measure), len(tarray)))
|
372
377
|
for i in range(len(measure)):
|
dataeval/typing.py
CHANGED
@@ -21,19 +21,21 @@ __all__ = [
|
|
21
21
|
]
|
22
22
|
|
23
23
|
|
24
|
-
import
|
25
|
-
from typing import
|
24
|
+
from collections.abc import Iterator, Mapping
|
25
|
+
from typing import (
|
26
|
+
Any,
|
27
|
+
Generic,
|
28
|
+
Protocol,
|
29
|
+
TypeAlias,
|
30
|
+
TypedDict,
|
31
|
+
TypeVar,
|
32
|
+
runtime_checkable,
|
33
|
+
)
|
26
34
|
|
27
35
|
import numpy.typing
|
28
36
|
import torch
|
29
37
|
from typing_extensions import NotRequired, ReadOnly, Required
|
30
38
|
|
31
|
-
if sys.version_info >= (3, 10):
|
32
|
-
from typing import TypeAlias
|
33
|
-
else:
|
34
|
-
from typing_extensions import TypeAlias
|
35
|
-
|
36
|
-
|
37
39
|
ArrayLike: TypeAlias = numpy.typing.ArrayLike
|
38
40
|
"""
|
39
41
|
Type alias for a `Union` representing objects that can be coerced into an array.
|
@@ -44,7 +46,7 @@ See Also
|
|
44
46
|
"""
|
45
47
|
|
46
48
|
|
47
|
-
DeviceLike: TypeAlias =
|
49
|
+
DeviceLike: TypeAlias = int | str | tuple[str, int] | torch.device
|
48
50
|
"""
|
49
51
|
Type alias for a `Union` representing types that specify a torch.device.
|
50
52
|
|
dataeval/utils/_array.py
CHANGED
@@ -4,9 +4,10 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import logging
|
6
6
|
import warnings
|
7
|
+
from collections.abc import Iterable, Iterator
|
7
8
|
from importlib import import_module
|
8
9
|
from types import ModuleType
|
9
|
-
from typing import Any,
|
10
|
+
from typing import Any, Literal, TypeVar, overload
|
10
11
|
|
11
12
|
import numpy as np
|
12
13
|
import torch
|
@@ -79,7 +80,7 @@ def rescale_array(array: NDArray[_np_dtype]) -> NDArray[_np_dtype]: ...
|
|
79
80
|
def rescale_array(array: torch.Tensor) -> torch.Tensor: ...
|
80
81
|
def rescale_array(array: Array | NDArray[_np_dtype] | torch.Tensor) -> Array | NDArray[_np_dtype] | torch.Tensor:
|
81
82
|
"""Rescale an array to the range [0, 1]"""
|
82
|
-
if isinstance(array,
|
83
|
+
if isinstance(array, np.ndarray | torch.Tensor):
|
83
84
|
arr_min = array.min()
|
84
85
|
arr_max = array.max()
|
85
86
|
return (array - arr_min) / (arr_max - arr_min)
|
dataeval/utils/_bin.py
CHANGED
dataeval/utils/_method.py
CHANGED
@@ -0,0 +1,34 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from collections.abc import Callable, Iterable, Iterator
|
6
|
+
from multiprocessing import Pool
|
7
|
+
from typing import Any, TypeVar
|
8
|
+
|
9
|
+
_S = TypeVar("_S")
|
10
|
+
_T = TypeVar("_T")
|
11
|
+
|
12
|
+
|
13
|
+
class PoolWrapper:
|
14
|
+
"""
|
15
|
+
Wraps `multiprocessing.Pool` to allow for easy switching between
|
16
|
+
multiprocessing and single-threaded execution.
|
17
|
+
|
18
|
+
This helps with debugging and profiling, as well as usage with Jupyter notebooks
|
19
|
+
in VS Code, which does not support subprocess debugging.
|
20
|
+
"""
|
21
|
+
|
22
|
+
def __init__(self, processes: int | None) -> None:
|
23
|
+
self.pool = Pool(processes) if processes is None or processes > 1 else None
|
24
|
+
|
25
|
+
def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
|
26
|
+
return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
|
27
|
+
|
28
|
+
def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
|
29
|
+
return self
|
30
|
+
|
31
|
+
def __exit__(self, *args: Any) -> None:
|
32
|
+
if self.pool is not None:
|
33
|
+
self.pool.close()
|
34
|
+
self.pool.join()
|