dataeval 0.87.0__py3-none-any.whl → 0.88.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/_log.py +1 -1
- dataeval/_version.py +2 -2
- dataeval/data/_embeddings.py +78 -35
- dataeval/data/_images.py +41 -8
- dataeval/data/_metadata.py +294 -41
- dataeval/data/_selection.py +22 -7
- dataeval/data/_split.py +2 -1
- dataeval/data/selections/_classfilter.py +4 -3
- dataeval/data/selections/_indices.py +2 -1
- dataeval/data/selections/_shuffle.py +3 -2
- dataeval/detectors/drift/_base.py +2 -1
- dataeval/detectors/drift/_mmd.py +2 -1
- dataeval/detectors/drift/_nml/_base.py +1 -1
- dataeval/detectors/drift/_nml/_chunk.py +2 -1
- dataeval/detectors/drift/_nml/_result.py +3 -2
- dataeval/detectors/drift/_nml/_thresholds.py +6 -5
- dataeval/detectors/drift/_uncertainty.py +2 -1
- dataeval/detectors/linters/duplicates.py +2 -1
- dataeval/detectors/linters/outliers.py +4 -3
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/detectors/ood/base.py +2 -1
- dataeval/detectors/ood/mixin.py +2 -1
- dataeval/metadata/_utils.py +1 -1
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/stats/_base.py +3 -29
- dataeval/metrics/stats/_boxratiostats.py +2 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -1
- dataeval/metrics/stats/_hashstats.py +2 -1
- dataeval/metrics/stats/_pixelstats.py +2 -1
- dataeval/metrics/stats/_visualstats.py +2 -1
- dataeval/outputs/_base.py +2 -3
- dataeval/outputs/_bias.py +2 -1
- dataeval/outputs/_estimators.py +1 -1
- dataeval/outputs/_linters.py +3 -3
- dataeval/outputs/_stats.py +3 -3
- dataeval/outputs/_utils.py +1 -1
- dataeval/outputs/_workflows.py +85 -30
- dataeval/typing.py +11 -9
- dataeval/utils/_array.py +3 -2
- dataeval/utils/_bin.py +2 -1
- dataeval/utils/_method.py +2 -3
- dataeval/utils/_multiprocessing.py +34 -0
- dataeval/utils/_plot.py +2 -1
- dataeval/utils/data/__init__.py +4 -5
- dataeval/utils/data/{metadata.py → _merge.py} +3 -2
- dataeval/utils/data/_validate.py +2 -1
- dataeval/utils/data/collate.py +2 -1
- dataeval/utils/torch/_internal.py +2 -1
- dataeval/utils/torch/trainer.py +1 -1
- dataeval/workflows/sufficiency.py +12 -9
- {dataeval-0.87.0.dist-info → dataeval-0.88.1.dist-info}/METADATA +4 -5
- dataeval-0.88.1.dist-info/RECORD +105 -0
- dataeval/utils/data/_dataset.py +0 -253
- dataeval-0.87.0.dist-info/RECORD +0 -105
- {dataeval-0.87.0.dist-info → dataeval-0.88.1.dist-info}/WHEEL +0 -0
- {dataeval-0.87.0.dist-info → dataeval-0.88.1.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,8 @@ import copy
|
|
13
13
|
import logging
|
14
14
|
import warnings
|
15
15
|
from abc import ABC, abstractmethod
|
16
|
-
from
|
16
|
+
from collections.abc import Sequence
|
17
|
+
from typing import Any, Generic, Literal, TypeVar, cast
|
17
18
|
|
18
19
|
import pandas as pd
|
19
20
|
from pandas import Index, Period
|
@@ -11,7 +11,8 @@ from __future__ import annotations
|
|
11
11
|
|
12
12
|
import copy
|
13
13
|
from abc import ABC, abstractmethod
|
14
|
-
from
|
14
|
+
from collections.abc import Sequence
|
15
|
+
from typing import NamedTuple
|
15
16
|
|
16
17
|
import pandas as pd
|
17
18
|
from typing_extensions import Self
|
@@ -52,7 +53,7 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
|
|
52
53
|
|
53
54
|
def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
|
54
55
|
"""Returns filtered result metric data."""
|
55
|
-
if metrics and not isinstance(metrics,
|
56
|
+
if metrics and not isinstance(metrics, str | Sequence):
|
56
57
|
raise ValueError("metrics value provided is not a valid metric or sequence of metrics")
|
57
58
|
if isinstance(metrics, str):
|
58
59
|
metrics = [metrics]
|
@@ -9,7 +9,8 @@ from __future__ import annotations
|
|
9
9
|
|
10
10
|
import logging
|
11
11
|
from abc import ABC, abstractmethod
|
12
|
-
from
|
12
|
+
from collections.abc import Callable
|
13
|
+
from typing import Any, ClassVar
|
13
14
|
|
14
15
|
import numpy as np
|
15
16
|
|
@@ -169,10 +170,10 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
|
|
169
170
|
|
170
171
|
@staticmethod
|
171
172
|
def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
|
172
|
-
if lower is not None and not isinstance(lower,
|
173
|
+
if lower is not None and not isinstance(lower, float | int) or isinstance(lower, bool):
|
173
174
|
raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
|
174
175
|
|
175
|
-
if upper is not None and not isinstance(upper,
|
176
|
+
if upper is not None and not isinstance(upper, float | int) or isinstance(upper, bool):
|
176
177
|
raise ValueError(f"expected type of 'upper' to be 'float', 'int' or None but got '{type(upper).__name__}'")
|
177
178
|
|
178
179
|
# explicit None check is required due to special interpretation of the value 0.0 as False
|
@@ -244,7 +245,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
|
|
244
245
|
) -> None:
|
245
246
|
if (
|
246
247
|
std_lower_multiplier is not None
|
247
|
-
and not isinstance(std_lower_multiplier,
|
248
|
+
and not isinstance(std_lower_multiplier, float | int)
|
248
249
|
or isinstance(std_lower_multiplier, bool)
|
249
250
|
):
|
250
251
|
raise ValueError(
|
@@ -257,7 +258,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
|
|
257
258
|
|
258
259
|
if (
|
259
260
|
std_upper_multiplier is not None
|
260
|
-
and not isinstance(std_upper_multiplier,
|
261
|
+
and not isinstance(std_upper_multiplier, float | int)
|
261
262
|
or isinstance(std_upper_multiplier, bool)
|
262
263
|
):
|
263
264
|
raise ValueError(
|
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
-
from
|
5
|
+
from collections.abc import Sequence
|
6
|
+
from typing import Any, Literal, overload
|
6
7
|
|
7
8
|
import numpy as np
|
8
9
|
from numpy.typing import NDArray
|
@@ -201,7 +202,7 @@ class Outliers:
|
|
201
202
|
>>> results.issues[1]
|
202
203
|
{}
|
203
204
|
"""
|
204
|
-
if isinstance(stats,
|
205
|
+
if isinstance(stats, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput):
|
205
206
|
return OutliersOutput(self._get_outliers(stats.data()))
|
206
207
|
|
207
208
|
if not isinstance(stats, Sequence):
|
@@ -212,7 +213,7 @@ class Outliers:
|
|
212
213
|
stats_map: dict[type, list[int]] = {}
|
213
214
|
for i, stats_output in enumerate(stats):
|
214
215
|
if not isinstance(
|
215
|
-
stats_output,
|
216
|
+
stats_output, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
|
216
217
|
):
|
217
218
|
raise TypeError(
|
218
219
|
"Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
|
dataeval/detectors/ood/ae.py
CHANGED
dataeval/detectors/ood/base.py
CHANGED
dataeval/detectors/ood/mixin.py
CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
from abc import ABC, abstractmethod
|
6
|
-
from
|
6
|
+
from collections.abc import Callable
|
7
|
+
from typing import Generic, Literal, TypeVar
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy.typing import NDArray
|
dataeval/metadata/_utils.py
CHANGED
@@ -16,7 +16,7 @@ from dataeval.utils._bin import get_counts
|
|
16
16
|
|
17
17
|
|
18
18
|
def _validate_num_neighbors(num_neighbors: int) -> int:
|
19
|
-
if not isinstance(num_neighbors,
|
19
|
+
if not isinstance(num_neighbors, int | float):
|
20
20
|
raise TypeError(
|
21
21
|
f"Variable {num_neighbors} is not real-valued numeric type."
|
22
22
|
"num_neighbors should be an int, greater than 0 and less than"
|
dataeval/metrics/stats/_base.py
CHANGED
@@ -6,11 +6,11 @@ import math
|
|
6
6
|
import re
|
7
7
|
import warnings
|
8
8
|
from collections import ChainMap
|
9
|
+
from collections.abc import Callable, Iterable, Iterator, Sequence
|
9
10
|
from copy import deepcopy
|
10
11
|
from dataclasses import dataclass
|
11
12
|
from functools import partial
|
12
|
-
from
|
13
|
-
from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
|
13
|
+
from typing import Any, Generic, TypeVar
|
14
14
|
|
15
15
|
import numpy as np
|
16
16
|
from numpy.typing import NDArray
|
@@ -21,14 +21,12 @@ from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
|
|
21
21
|
from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
|
22
22
|
from dataeval.utils._array import as_numpy, to_numpy
|
23
23
|
from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
|
24
|
+
from dataeval.utils._multiprocessing import PoolWrapper
|
24
25
|
|
25
26
|
DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
|
26
27
|
|
27
28
|
TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
|
28
29
|
|
29
|
-
_S = TypeVar("_S")
|
30
|
-
_T = TypeVar("_T")
|
31
|
-
|
32
30
|
|
33
31
|
@dataclass
|
34
32
|
class BoundingBox:
|
@@ -67,30 +65,6 @@ class BoundingBox:
|
|
67
65
|
return x0_int, y0_int, x1_int, y1_int
|
68
66
|
|
69
67
|
|
70
|
-
class PoolWrapper:
|
71
|
-
"""
|
72
|
-
Wraps `multiprocessing.Pool` to allow for easy switching between
|
73
|
-
multiprocessing and single-threaded execution.
|
74
|
-
|
75
|
-
This helps with debugging and profiling, as well as usage with Jupyter notebooks
|
76
|
-
in VS Code, which does not support subprocess debugging.
|
77
|
-
"""
|
78
|
-
|
79
|
-
def __init__(self, processes: int | None) -> None:
|
80
|
-
self.pool = Pool(processes) if processes is None or processes > 1 else None
|
81
|
-
|
82
|
-
def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
|
83
|
-
return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
|
84
|
-
|
85
|
-
def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
|
86
|
-
return self
|
87
|
-
|
88
|
-
def __exit__(self, *args: Any) -> None:
|
89
|
-
if self.pool is not None:
|
90
|
-
self.pool.close()
|
91
|
-
self.pool.join()
|
92
|
-
|
93
|
-
|
94
68
|
class StatsProcessor(Generic[TStatsOutput]):
|
95
69
|
output_class: type[TStatsOutput]
|
96
70
|
cache_keys: set[str] = set()
|
@@ -3,7 +3,8 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import copy
|
6
|
-
from
|
6
|
+
from collections.abc import Callable
|
7
|
+
from typing import Any, Generic, TypeVar, cast
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
from numpy.typing import NDArray
|
dataeval/outputs/_base.py
CHANGED
@@ -4,14 +4,13 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import inspect
|
6
6
|
import logging
|
7
|
-
from collections.abc import Collection, Mapping, Sequence
|
7
|
+
from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
|
8
8
|
from dataclasses import dataclass
|
9
9
|
from datetime import datetime, timezone
|
10
10
|
from functools import partial, wraps
|
11
|
-
from typing import Any,
|
11
|
+
from typing import Any, Generic, ParamSpec, TypeVar, overload
|
12
12
|
|
13
13
|
import numpy as np
|
14
|
-
from typing_extensions import ParamSpec
|
15
14
|
|
16
15
|
from dataeval import __version__
|
17
16
|
|
dataeval/outputs/_bias.py
CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
5
|
import contextlib
|
6
|
+
from collections.abc import Mapping, Sequence
|
6
7
|
from dataclasses import asdict, dataclass
|
7
|
-
from typing import Any,
|
8
|
+
from typing import Any, TypeVar
|
8
9
|
|
9
10
|
import numpy as np
|
10
11
|
import pandas as pd
|
dataeval/outputs/_estimators.py
CHANGED
dataeval/outputs/_linters.py
CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
from collections.abc import Mapping, Sequence
|
5
6
|
from dataclasses import dataclass
|
6
|
-
from typing import Generic,
|
7
|
+
from typing import Generic, TypeAlias, TypeVar
|
7
8
|
|
8
9
|
import pandas as pd
|
9
|
-
from typing_extensions import TypeAlias
|
10
10
|
|
11
11
|
from dataeval.outputs._base import Output
|
12
12
|
from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
|
@@ -16,7 +16,7 @@ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
|
|
16
16
|
TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
|
17
17
|
|
18
18
|
IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
|
19
|
-
OutlierStatsOutput: TypeAlias =
|
19
|
+
OutlierStatsOutput: TypeAlias = DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
|
20
20
|
TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
|
21
21
|
|
22
22
|
|
dataeval/outputs/_stats.py
CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
__all__ = []
|
4
4
|
|
5
|
+
from collections.abc import Iterable, Mapping, Sequence
|
5
6
|
from dataclasses import dataclass
|
6
|
-
from typing import TYPE_CHECKING, Any,
|
7
|
+
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
|
7
8
|
|
8
9
|
import numpy as np
|
9
10
|
import polars as pl
|
10
11
|
from numpy.typing import NDArray
|
11
|
-
from typing_extensions import TypeAlias
|
12
12
|
|
13
13
|
from dataeval.outputs._base import Output
|
14
14
|
from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
@@ -16,7 +16,7 @@ from dataeval.utils._plot import channel_histogram_plot, histogram_plot
|
|
16
16
|
if TYPE_CHECKING:
|
17
17
|
from matplotlib.figure import Figure
|
18
18
|
|
19
|
-
OptionalRange: TypeAlias =
|
19
|
+
OptionalRange: TypeAlias = int | Iterable[int] | None
|
20
20
|
|
21
21
|
SOURCE_INDEX = "source_index"
|
22
22
|
OBJECT_COUNT = "object_count"
|
dataeval/outputs/_utils.py
CHANGED
dataeval/outputs/_workflows.py
CHANGED
@@ -4,8 +4,9 @@ __all__ = []
|
|
4
4
|
|
5
5
|
import contextlib
|
6
6
|
import warnings
|
7
|
-
from
|
8
|
-
from
|
7
|
+
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
|
8
|
+
from dataclasses import dataclass, field
|
9
|
+
from typing import Any, cast
|
9
10
|
|
10
11
|
import numpy as np
|
11
12
|
from numpy.typing import NDArray
|
@@ -61,9 +62,12 @@ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any
|
|
61
62
|
def plot_measure(
|
62
63
|
name: str,
|
63
64
|
steps: NDArray[Any],
|
64
|
-
|
65
|
+
averaged_measure: NDArray[Any],
|
66
|
+
measures: NDArray[Any] | None,
|
65
67
|
params: NDArray[Any],
|
66
68
|
projection: NDArray[Any],
|
69
|
+
error_bars: bool,
|
70
|
+
asymptote: bool,
|
67
71
|
) -> Figure:
|
68
72
|
import matplotlib.pyplot
|
69
73
|
|
@@ -72,23 +76,51 @@ def plot_measure(
|
|
72
76
|
fig.tight_layout()
|
73
77
|
|
74
78
|
ax = fig.add_subplot(111)
|
75
|
-
|
76
79
|
ax.set_title(f"{name} Sufficiency")
|
77
80
|
ax.set_ylabel(f"{name}")
|
78
81
|
ax.set_xlabel("Steps")
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
82
|
+
# Plot asymptote
|
83
|
+
if asymptote:
|
84
|
+
bound = 1 - params[2]
|
85
|
+
ax.axhline(y=bound, color="r", label=f"Asymptote: {bound:.4g}", zorder=1)
|
86
|
+
# Calculate error bars
|
87
|
+
# Plot measure over each step with associated error
|
88
|
+
if error_bars:
|
89
|
+
if measures is None:
|
90
|
+
warnings.warn(
|
91
|
+
"Error bars cannot be plotted without full, unaveraged data",
|
92
|
+
UserWarning,
|
93
|
+
)
|
94
|
+
else:
|
95
|
+
error = np.std(measures, axis=0)
|
96
|
+
ax.errorbar(
|
97
|
+
steps,
|
98
|
+
averaged_measure,
|
99
|
+
yerr=error,
|
100
|
+
capsize=7,
|
101
|
+
capthick=1.5,
|
102
|
+
elinewidth=1.5,
|
103
|
+
fmt="o",
|
104
|
+
label=f"Model Results ({name})",
|
105
|
+
markersize=5,
|
106
|
+
color="black",
|
107
|
+
ecolor="orange",
|
108
|
+
zorder=3,
|
109
|
+
)
|
110
|
+
else:
|
111
|
+
ax.scatter(steps, averaged_measure, label=f"Model Results ({name})", zorder=3, c="black")
|
83
112
|
# Plot extrapolation
|
84
113
|
ax.plot(
|
85
114
|
projection,
|
86
115
|
project_steps(params, projection),
|
87
116
|
linestyle="dashed",
|
88
117
|
label=f"Potential Model Results ({name})",
|
118
|
+
linewidth=2,
|
119
|
+
zorder=2,
|
89
120
|
)
|
121
|
+
ax.set_xscale("log")
|
90
122
|
|
91
|
-
ax.legend()
|
123
|
+
ax.legend(loc="best")
|
92
124
|
return fig
|
93
125
|
|
94
126
|
|
@@ -145,7 +177,7 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np
|
|
145
177
|
return np.ceil(steps)
|
146
178
|
|
147
179
|
|
148
|
-
def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[
|
180
|
+
def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.float64]:
|
149
181
|
"""
|
150
182
|
Retrieves the inverse power curve coefficients for the line of best fit.
|
151
183
|
Global minimization is done via basin hopping. More info on this algorithm
|
@@ -191,11 +223,11 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
|
|
191
223
|
|
192
224
|
|
193
225
|
def get_curve_params(
|
194
|
-
|
195
|
-
) -> Mapping[str, NDArray[
|
226
|
+
averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
|
227
|
+
) -> Mapping[str, NDArray[np.float64]]:
|
196
228
|
"""Calculates and aggregates parameters for both single and multi-class metrics"""
|
197
229
|
output = {}
|
198
|
-
for name, measure in
|
230
|
+
for name, measure in averaged_measures.items():
|
199
231
|
measure = cast(np.ndarray, measure)
|
200
232
|
if measure.ndim > 1:
|
201
233
|
result = []
|
@@ -216,19 +248,25 @@ class SufficiencyOutput(Output):
|
|
216
248
|
----------
|
217
249
|
steps : NDArray
|
218
250
|
Array of sample sizes
|
219
|
-
measures :
|
220
|
-
|
251
|
+
measures : dict[str, NDArray]
|
252
|
+
3D array [runs, substep, classes] of values for all runs observed for each sample size step for each measure
|
253
|
+
averaged_measures : dict[str, NDArray]
|
254
|
+
Average of values for all runs observed for each sample size step for each measure
|
221
255
|
n_iter : int, default 1000
|
222
256
|
Number of iterations to perform in the basin-hopping curve-fit process
|
223
257
|
"""
|
224
258
|
|
225
259
|
steps: NDArray[np.uint32]
|
226
|
-
measures: Mapping[str, NDArray[
|
260
|
+
measures: Mapping[str, NDArray[Any]]
|
261
|
+
averaged_measures: MutableMapping[str, NDArray[Any]] = field(default_factory=lambda: {})
|
227
262
|
n_iter: int = 1000
|
228
263
|
|
229
264
|
def __post_init__(self) -> None:
|
265
|
+
if len(self.averaged_measures) == 0:
|
266
|
+
for metric, values in self.measures.items():
|
267
|
+
self.averaged_measures[metric] = np.asarray(np.mean(values, axis=0)).T
|
230
268
|
c = len(self.steps)
|
231
|
-
for m, v in self.
|
269
|
+
for m, v in self.averaged_measures.items():
|
232
270
|
c_v = v.shape[1] if v.ndim > 1 else len(v)
|
233
271
|
if c != c_v:
|
234
272
|
raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
|
@@ -239,7 +277,7 @@ class SufficiencyOutput(Output):
|
|
239
277
|
if self._params is None:
|
240
278
|
self._params = {}
|
241
279
|
if self.n_iter not in self._params:
|
242
|
-
self._params[self.n_iter] = get_curve_params(self.
|
280
|
+
self._params[self.n_iter] = get_curve_params(self.averaged_measures, self.steps, self.n_iter)
|
243
281
|
return self._params[self.n_iter]
|
244
282
|
|
245
283
|
@set_metadata
|
@@ -272,20 +310,22 @@ class SufficiencyOutput(Output):
|
|
272
310
|
raise ValueError("'projection' must consist of numerical values")
|
273
311
|
|
274
312
|
output = {}
|
275
|
-
for name,
|
276
|
-
if
|
313
|
+
for name, averaged_measures in self.averaged_measures.items():
|
314
|
+
if averaged_measures.ndim > 1:
|
277
315
|
result = []
|
278
|
-
for i in range(len(
|
316
|
+
for i in range(len(averaged_measures)):
|
279
317
|
projected = project_steps(self.params[name][i], projection)
|
280
318
|
result.append(projected)
|
281
319
|
output[name] = np.array(result)
|
282
320
|
else:
|
283
321
|
output[name] = project_steps(self.params[name], projection)
|
284
|
-
proj = SufficiencyOutput(projection, output, self.n_iter)
|
322
|
+
proj = SufficiencyOutput(projection, {}, output, self.n_iter)
|
285
323
|
proj._params = self._params
|
286
324
|
return proj
|
287
325
|
|
288
|
-
def plot(
|
326
|
+
def plot(
|
327
|
+
self, class_names: Sequence[str] | None = None, error_bars: bool = False, asymptote: bool = False
|
328
|
+
) -> Sequence[Figure]:
|
289
329
|
"""
|
290
330
|
Plotting function for data :term:`sufficience<Sufficiency>` tasks.
|
291
331
|
|
@@ -293,6 +333,10 @@ class SufficiencyOutput(Output):
|
|
293
333
|
----------
|
294
334
|
class_names : Sequence[str] | None, default None
|
295
335
|
List of class names
|
336
|
+
error_bars : bool, default False
|
337
|
+
True if error bars should be plotted, False if not
|
338
|
+
asymptote : bool, default False
|
339
|
+
True if asymptote should be plotted, False if not
|
296
340
|
|
297
341
|
Returns
|
298
342
|
-------
|
@@ -315,25 +359,36 @@ class SufficiencyOutput(Output):
|
|
315
359
|
|
316
360
|
# Stores all plots
|
317
361
|
plots = []
|
318
|
-
|
319
362
|
# Create a plot for each measure on one figure
|
320
|
-
for name, measures in self.
|
363
|
+
for name, measures in self.averaged_measures.items():
|
321
364
|
if measures.ndim > 1:
|
322
365
|
if class_names is not None and len(measures) != len(class_names):
|
323
366
|
raise IndexError("Class name count does not align with measures")
|
324
|
-
for i,
|
367
|
+
for i, values in enumerate(measures):
|
325
368
|
class_name = str(i) if class_names is None else class_names[i]
|
326
369
|
fig = plot_measure(
|
327
370
|
f"{name}_{class_name}",
|
328
371
|
self.steps,
|
329
|
-
|
372
|
+
values,
|
373
|
+
self.measures[name][:, :, i] if len(self.measures) else None,
|
330
374
|
self.params[name][i],
|
331
375
|
extrapolated,
|
376
|
+
error_bars,
|
377
|
+
asymptote,
|
332
378
|
)
|
333
379
|
plots.append(fig)
|
334
380
|
|
335
381
|
else:
|
336
|
-
fig = plot_measure(
|
382
|
+
fig = plot_measure(
|
383
|
+
name,
|
384
|
+
self.steps,
|
385
|
+
measures,
|
386
|
+
self.measures.get(name),
|
387
|
+
self.params[name],
|
388
|
+
extrapolated,
|
389
|
+
error_bars,
|
390
|
+
asymptote,
|
391
|
+
)
|
337
392
|
plots.append(fig)
|
338
393
|
|
339
394
|
return plots
|
@@ -363,10 +418,10 @@ class SufficiencyOutput(Output):
|
|
363
418
|
|
364
419
|
for name, target in targets.items():
|
365
420
|
tarray = as_numpy(target)
|
366
|
-
if name not in self.
|
421
|
+
if name not in self.averaged_measures:
|
367
422
|
continue
|
368
423
|
|
369
|
-
measure = self.
|
424
|
+
measure = self.averaged_measures[name]
|
370
425
|
if measure.ndim > 1:
|
371
426
|
projection[name] = np.zeros((len(measure), len(tarray)))
|
372
427
|
for i in range(len(measure)):
|