dataeval 0.87.0__py3-none-any.whl → 0.88.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. dataeval/_log.py +1 -1
  2. dataeval/_version.py +2 -2
  3. dataeval/data/_embeddings.py +78 -35
  4. dataeval/data/_images.py +41 -8
  5. dataeval/data/_metadata.py +294 -41
  6. dataeval/data/_selection.py +22 -7
  7. dataeval/data/_split.py +2 -1
  8. dataeval/data/selections/_classfilter.py +4 -3
  9. dataeval/data/selections/_indices.py +2 -1
  10. dataeval/data/selections/_shuffle.py +3 -2
  11. dataeval/detectors/drift/_base.py +2 -1
  12. dataeval/detectors/drift/_mmd.py +2 -1
  13. dataeval/detectors/drift/_nml/_base.py +1 -1
  14. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  15. dataeval/detectors/drift/_nml/_result.py +3 -2
  16. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  17. dataeval/detectors/drift/_uncertainty.py +2 -1
  18. dataeval/detectors/linters/duplicates.py +2 -1
  19. dataeval/detectors/linters/outliers.py +4 -3
  20. dataeval/detectors/ood/ae.py +1 -1
  21. dataeval/detectors/ood/base.py +2 -1
  22. dataeval/detectors/ood/mixin.py +2 -1
  23. dataeval/metadata/_utils.py +1 -1
  24. dataeval/metrics/bias/_balance.py +1 -1
  25. dataeval/metrics/stats/_base.py +3 -29
  26. dataeval/metrics/stats/_boxratiostats.py +2 -1
  27. dataeval/metrics/stats/_dimensionstats.py +2 -1
  28. dataeval/metrics/stats/_hashstats.py +2 -1
  29. dataeval/metrics/stats/_pixelstats.py +2 -1
  30. dataeval/metrics/stats/_visualstats.py +2 -1
  31. dataeval/outputs/_base.py +2 -3
  32. dataeval/outputs/_bias.py +2 -1
  33. dataeval/outputs/_estimators.py +1 -1
  34. dataeval/outputs/_linters.py +3 -3
  35. dataeval/outputs/_stats.py +3 -3
  36. dataeval/outputs/_utils.py +1 -1
  37. dataeval/outputs/_workflows.py +85 -30
  38. dataeval/typing.py +11 -9
  39. dataeval/utils/_array.py +3 -2
  40. dataeval/utils/_bin.py +2 -1
  41. dataeval/utils/_method.py +2 -3
  42. dataeval/utils/_multiprocessing.py +34 -0
  43. dataeval/utils/_plot.py +2 -1
  44. dataeval/utils/data/__init__.py +4 -5
  45. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  46. dataeval/utils/data/_validate.py +2 -1
  47. dataeval/utils/data/collate.py +2 -1
  48. dataeval/utils/torch/_internal.py +2 -1
  49. dataeval/utils/torch/trainer.py +1 -1
  50. dataeval/workflows/sufficiency.py +12 -9
  51. {dataeval-0.87.0.dist-info → dataeval-0.88.1.dist-info}/METADATA +4 -5
  52. dataeval-0.88.1.dist-info/RECORD +105 -0
  53. dataeval/utils/data/_dataset.py +0 -253
  54. dataeval-0.87.0.dist-info/RECORD +0 -105
  55. {dataeval-0.87.0.dist-info → dataeval-0.88.1.dist-info}/WHEEL +0 -0
  56. {dataeval-0.87.0.dist-info → dataeval-0.88.1.dist-info}/licenses/LICENSE +0 -0
@@ -9,8 +9,8 @@ from __future__ import annotations
9
9
 
10
10
  import logging
11
11
  from abc import ABC, abstractmethod
12
+ from collections.abc import Sequence
12
13
  from logging import Logger
13
- from typing import Sequence
14
14
 
15
15
  import pandas as pd
16
16
  from typing_extensions import Self
@@ -13,7 +13,8 @@ import copy
13
13
  import logging
14
14
  import warnings
15
15
  from abc import ABC, abstractmethod
16
- from typing import Any, Generic, Literal, Sequence, TypeVar, cast
16
+ from collections.abc import Sequence
17
+ from typing import Any, Generic, Literal, TypeVar, cast
17
18
 
18
19
  import pandas as pd
19
20
  from pandas import Index, Period
@@ -11,7 +11,8 @@ from __future__ import annotations
11
11
 
12
12
  import copy
13
13
  from abc import ABC, abstractmethod
14
- from typing import NamedTuple, Sequence
14
+ from collections.abc import Sequence
15
+ from typing import NamedTuple
15
16
 
16
17
  import pandas as pd
17
18
  from typing_extensions import Self
@@ -52,7 +53,7 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
52
53
 
53
54
  def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
54
55
  """Returns filtered result metric data."""
55
- if metrics and not isinstance(metrics, (str, Sequence)):
56
+ if metrics and not isinstance(metrics, str | Sequence):
56
57
  raise ValueError("metrics value provided is not a valid metric or sequence of metrics")
57
58
  if isinstance(metrics, str):
58
59
  metrics = [metrics]
@@ -9,7 +9,8 @@ from __future__ import annotations
9
9
 
10
10
  import logging
11
11
  from abc import ABC, abstractmethod
12
- from typing import Any, Callable, ClassVar
12
+ from collections.abc import Callable
13
+ from typing import Any, ClassVar
13
14
 
14
15
  import numpy as np
15
16
 
@@ -169,10 +170,10 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
169
170
 
170
171
  @staticmethod
171
172
  def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
172
- if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
173
+ if lower is not None and not isinstance(lower, float | int) or isinstance(lower, bool):
173
174
  raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
174
175
 
175
- if upper is not None and not isinstance(upper, (float, int)) or isinstance(upper, bool):
176
+ if upper is not None and not isinstance(upper, float | int) or isinstance(upper, bool):
176
177
  raise ValueError(f"expected type of 'upper' to be 'float', 'int' or None but got '{type(upper).__name__}'")
177
178
 
178
179
  # explicit None check is required due to special interpretation of the value 0.0 as False
@@ -244,7 +245,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
244
245
  ) -> None:
245
246
  if (
246
247
  std_lower_multiplier is not None
247
- and not isinstance(std_lower_multiplier, (float, int))
248
+ and not isinstance(std_lower_multiplier, float | int)
248
249
  or isinstance(std_lower_multiplier, bool)
249
250
  ):
250
251
  raise ValueError(
@@ -257,7 +258,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
257
258
 
258
259
  if (
259
260
  std_upper_multiplier is not None
260
- and not isinstance(std_upper_multiplier, (float, int))
261
+ and not isinstance(std_upper_multiplier, float | int)
261
262
  or isinstance(std_upper_multiplier, bool)
262
263
  ):
263
264
  raise ValueError(
@@ -10,7 +10,8 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from typing import Literal, Sequence, cast
13
+ from collections.abc import Sequence
14
+ from typing import Literal, cast
14
15
 
15
16
  import numpy as np
16
17
  import torch
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, Sequence, overload
5
+ from collections.abc import Sequence
6
+ from typing import Any, overload
6
7
 
7
8
  from dataeval.data._images import Images
8
9
  from dataeval.metrics.stats import hashstats
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, Literal, Sequence, overload
5
+ from collections.abc import Sequence
6
+ from typing import Any, Literal, overload
6
7
 
7
8
  import numpy as np
8
9
  from numpy.typing import NDArray
@@ -201,7 +202,7 @@ class Outliers:
201
202
  >>> results.issues[1]
202
203
  {}
203
204
  """
204
- if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
205
+ if isinstance(stats, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput):
205
206
  return OutliersOutput(self._get_outliers(stats.data()))
206
207
 
207
208
  if not isinstance(stats, Sequence):
@@ -212,7 +213,7 @@ class Outliers:
212
213
  stats_map: dict[type, list[int]] = {}
213
214
  for i, stats_output in enumerate(stats):
214
215
  if not isinstance(
215
- stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
216
+ stats_output, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
216
217
  ):
217
218
  raise TypeError(
218
219
  "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
@@ -12,7 +12,7 @@ from __future__ import annotations
12
12
 
13
13
  __all__ = []
14
14
 
15
- from typing import Callable
15
+ from collections.abc import Callable
16
16
 
17
17
  import numpy as np
18
18
  import torch
@@ -11,7 +11,8 @@ from __future__ import annotations
11
11
  __all__ = []
12
12
 
13
13
  from abc import ABC, abstractmethod
14
- from typing import Any, Callable, cast
14
+ from collections.abc import Callable
15
+ from typing import Any, cast
15
16
 
16
17
  import numpy as np
17
18
  import torch
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  from abc import ABC, abstractmethod
6
- from typing import Callable, Generic, Literal, TypeVar
6
+ from collections.abc import Callable
7
+ from typing import Generic, Literal, TypeVar
7
8
 
8
9
  import numpy as np
9
10
  from numpy.typing import NDArray
@@ -1,6 +1,6 @@
1
1
  __all__ = []
2
2
 
3
- from typing import Sequence
3
+ from collections.abc import Sequence
4
4
 
5
5
  from numpy.typing import NDArray
6
6
 
@@ -16,7 +16,7 @@ from dataeval.utils._bin import get_counts
16
16
 
17
17
 
18
18
  def _validate_num_neighbors(num_neighbors: int) -> int:
19
- if not isinstance(num_neighbors, (int, float)):
19
+ if not isinstance(num_neighbors, int | float):
20
20
  raise TypeError(
21
21
  f"Variable {num_neighbors} is not real-valued numeric type."
22
22
  "num_neighbors should be an int, greater than 0 and less than"
@@ -6,11 +6,11 @@ import math
6
6
  import re
7
7
  import warnings
8
8
  from collections import ChainMap
9
+ from collections.abc import Callable, Iterable, Iterator, Sequence
9
10
  from copy import deepcopy
10
11
  from dataclasses import dataclass
11
12
  from functools import partial
12
- from multiprocessing import Pool
13
- from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
13
+ from typing import Any, Generic, TypeVar
14
14
 
15
15
  import numpy as np
16
16
  from numpy.typing import NDArray
@@ -21,14 +21,12 @@ from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
21
21
  from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
22
22
  from dataeval.utils._array import as_numpy, to_numpy
23
23
  from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
24
+ from dataeval.utils._multiprocessing import PoolWrapper
24
25
 
25
26
  DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
26
27
 
27
28
  TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
28
29
 
29
- _S = TypeVar("_S")
30
- _T = TypeVar("_T")
31
-
32
30
 
33
31
  @dataclass
34
32
  class BoundingBox:
@@ -67,30 +65,6 @@ class BoundingBox:
67
65
  return x0_int, y0_int, x1_int, y1_int
68
66
 
69
67
 
70
- class PoolWrapper:
71
- """
72
- Wraps `multiprocessing.Pool` to allow for easy switching between
73
- multiprocessing and single-threaded execution.
74
-
75
- This helps with debugging and profiling, as well as usage with Jupyter notebooks
76
- in VS Code, which does not support subprocess debugging.
77
- """
78
-
79
- def __init__(self, processes: int | None) -> None:
80
- self.pool = Pool(processes) if processes is None or processes > 1 else None
81
-
82
- def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
83
- return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
84
-
85
- def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
86
- return self
87
-
88
- def __exit__(self, *args: Any) -> None:
89
- if self.pool is not None:
90
- self.pool.close()
91
- self.pool.join()
92
-
93
-
94
68
  class StatsProcessor(Generic[TStatsOutput]):
95
69
  output_class: type[TStatsOutput]
96
70
  cache_keys: set[str] = set()
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import copy
6
- from typing import Any, Callable, Generic, TypeVar, cast
6
+ from collections.abc import Callable
7
+ from typing import Any, Generic, TypeVar, cast
7
8
 
8
9
  import numpy as np
9
10
  from numpy.typing import NDArray
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, Callable
5
+ from collections.abc import Callable
6
+ from typing import Any
6
7
 
7
8
  import numpy as np
8
9
 
@@ -4,7 +4,8 @@ import warnings
4
4
 
5
5
  __all__ = []
6
6
 
7
- from typing import Any, Callable
7
+ from collections.abc import Callable
8
+ from typing import Any
8
9
 
9
10
  import numpy as np
10
11
  import xxhash as xxh
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, Callable
5
+ from collections.abc import Callable
6
+ from typing import Any
6
7
 
7
8
  import numpy as np
8
9
  from scipy.stats import entropy, kurtosis, skew
@@ -2,7 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
- from typing import Any, Callable
5
+ from collections.abc import Callable
6
+ from typing import Any
6
7
 
7
8
  import numpy as np
8
9
 
dataeval/outputs/_base.py CHANGED
@@ -4,14 +4,13 @@ __all__ = []
4
4
 
5
5
  import inspect
6
6
  import logging
7
- from collections.abc import Collection, Mapping, Sequence
7
+ from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
8
8
  from dataclasses import dataclass
9
9
  from datetime import datetime, timezone
10
10
  from functools import partial, wraps
11
- from typing import Any, Callable, Generic, Iterator, TypeVar, overload
11
+ from typing import Any, Generic, ParamSpec, TypeVar, overload
12
12
 
13
13
  import numpy as np
14
- from typing_extensions import ParamSpec
15
14
 
16
15
  from dataeval import __version__
17
16
 
dataeval/outputs/_bias.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
3
3
  __all__ = []
4
4
 
5
5
  import contextlib
6
+ from collections.abc import Mapping, Sequence
6
7
  from dataclasses import asdict, dataclass
7
- from typing import Any, Mapping, Sequence, TypeVar
8
+ from typing import Any, TypeVar
8
9
 
9
10
  import numpy as np
10
11
  import pandas as pd
@@ -2,8 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ from collections.abc import Sequence
5
6
  from dataclasses import dataclass
6
- from typing import Sequence
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
@@ -2,11 +2,11 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ from collections.abc import Mapping, Sequence
5
6
  from dataclasses import dataclass
6
- from typing import Generic, Mapping, Sequence, TypeVar, Union
7
+ from typing import Generic, TypeAlias, TypeVar
7
8
 
8
9
  import pandas as pd
9
- from typing_extensions import TypeAlias
10
10
 
11
11
  from dataeval.outputs._base import Output
12
12
  from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
@@ -16,7 +16,7 @@ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
16
16
  TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
17
17
 
18
18
  IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
19
- OutlierStatsOutput: TypeAlias = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
19
+ OutlierStatsOutput: TypeAlias = DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
20
20
  TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
21
21
 
22
22
 
@@ -2,13 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ from collections.abc import Iterable, Mapping, Sequence
5
6
  from dataclasses import dataclass
6
- from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Optional, Sequence, Union
7
+ from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
7
8
 
8
9
  import numpy as np
9
10
  import polars as pl
10
11
  from numpy.typing import NDArray
11
- from typing_extensions import TypeAlias
12
12
 
13
13
  from dataeval.outputs._base import Output
14
14
  from dataeval.utils._plot import channel_histogram_plot, histogram_plot
@@ -16,7 +16,7 @@ from dataeval.utils._plot import channel_histogram_plot, histogram_plot
16
16
  if TYPE_CHECKING:
17
17
  from matplotlib.figure import Figure
18
18
 
19
- OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
19
+ OptionalRange: TypeAlias = int | Iterable[int] | None
20
20
 
21
21
  SOURCE_INDEX = "source_index"
22
22
  OBJECT_COUNT = "object_count"
@@ -2,8 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  __all__ = []
4
4
 
5
+ from collections.abc import Sequence
5
6
  from dataclasses import dataclass
6
- from typing import Sequence
7
7
 
8
8
  import numpy as np
9
9
  from numpy.typing import NDArray
@@ -4,8 +4,9 @@ __all__ = []
4
4
 
5
5
  import contextlib
6
6
  import warnings
7
- from dataclasses import dataclass
8
- from typing import Any, Iterable, Mapping, Sequence, cast
7
+ from collections.abc import Iterable, Mapping, MutableMapping, Sequence
8
+ from dataclasses import dataclass, field
9
+ from typing import Any, cast
9
10
 
10
11
  import numpy as np
11
12
  from numpy.typing import NDArray
@@ -61,9 +62,12 @@ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any
61
62
  def plot_measure(
62
63
  name: str,
63
64
  steps: NDArray[Any],
64
- measure: NDArray[Any],
65
+ averaged_measure: NDArray[Any],
66
+ measures: NDArray[Any] | None,
65
67
  params: NDArray[Any],
66
68
  projection: NDArray[Any],
69
+ error_bars: bool,
70
+ asymptote: bool,
67
71
  ) -> Figure:
68
72
  import matplotlib.pyplot
69
73
 
@@ -72,23 +76,51 @@ def plot_measure(
72
76
  fig.tight_layout()
73
77
 
74
78
  ax = fig.add_subplot(111)
75
-
76
79
  ax.set_title(f"{name} Sufficiency")
77
80
  ax.set_ylabel(f"{name}")
78
81
  ax.set_xlabel("Steps")
79
-
80
- # Plot measure over each step
81
- ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
82
-
82
+ # Plot asymptote
83
+ if asymptote:
84
+ bound = 1 - params[2]
85
+ ax.axhline(y=bound, color="r", label=f"Asymptote: {bound:.4g}", zorder=1)
86
+ # Calculate error bars
87
+ # Plot measure over each step with associated error
88
+ if error_bars:
89
+ if measures is None:
90
+ warnings.warn(
91
+ "Error bars cannot be plotted without full, unaveraged data",
92
+ UserWarning,
93
+ )
94
+ else:
95
+ error = np.std(measures, axis=0)
96
+ ax.errorbar(
97
+ steps,
98
+ averaged_measure,
99
+ yerr=error,
100
+ capsize=7,
101
+ capthick=1.5,
102
+ elinewidth=1.5,
103
+ fmt="o",
104
+ label=f"Model Results ({name})",
105
+ markersize=5,
106
+ color="black",
107
+ ecolor="orange",
108
+ zorder=3,
109
+ )
110
+ else:
111
+ ax.scatter(steps, averaged_measure, label=f"Model Results ({name})", zorder=3, c="black")
83
112
  # Plot extrapolation
84
113
  ax.plot(
85
114
  projection,
86
115
  project_steps(params, projection),
87
116
  linestyle="dashed",
88
117
  label=f"Potential Model Results ({name})",
118
+ linewidth=2,
119
+ zorder=2,
89
120
  )
121
+ ax.set_xscale("log")
90
122
 
91
- ax.legend()
123
+ ax.legend(loc="best")
92
124
  return fig
93
125
 
94
126
 
@@ -145,7 +177,7 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np
145
177
  return np.ceil(steps)
146
178
 
147
179
 
148
- def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
180
+ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.float64]:
149
181
  """
150
182
  Retrieves the inverse power curve coefficients for the line of best fit.
151
183
  Global minimization is done via basin hopping. More info on this algorithm
@@ -191,11 +223,11 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
191
223
 
192
224
 
193
225
  def get_curve_params(
194
- measures: Mapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
195
- ) -> Mapping[str, NDArray[Any]]:
226
+ averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
227
+ ) -> Mapping[str, NDArray[np.float64]]:
196
228
  """Calculates and aggregates parameters for both single and multi-class metrics"""
197
229
  output = {}
198
- for name, measure in measures.items():
230
+ for name, measure in averaged_measures.items():
199
231
  measure = cast(np.ndarray, measure)
200
232
  if measure.ndim > 1:
201
233
  result = []
@@ -216,19 +248,25 @@ class SufficiencyOutput(Output):
216
248
  ----------
217
249
  steps : NDArray
218
250
  Array of sample sizes
219
- measures : Dict[str, NDArray]
220
- Average of values observed for each sample size step for each measure
251
+ measures : dict[str, NDArray]
252
+ 3D array [runs, substep, classes] of values for all runs observed for each sample size step for each measure
253
+ averaged_measures : dict[str, NDArray]
254
+ Average of values for all runs observed for each sample size step for each measure
221
255
  n_iter : int, default 1000
222
256
  Number of iterations to perform in the basin-hopping curve-fit process
223
257
  """
224
258
 
225
259
  steps: NDArray[np.uint32]
226
- measures: Mapping[str, NDArray[np.float64]]
260
+ measures: Mapping[str, NDArray[Any]]
261
+ averaged_measures: MutableMapping[str, NDArray[Any]] = field(default_factory=lambda: {})
227
262
  n_iter: int = 1000
228
263
 
229
264
  def __post_init__(self) -> None:
265
+ if len(self.averaged_measures) == 0:
266
+ for metric, values in self.measures.items():
267
+ self.averaged_measures[metric] = np.asarray(np.mean(values, axis=0)).T
230
268
  c = len(self.steps)
231
- for m, v in self.measures.items():
269
+ for m, v in self.averaged_measures.items():
232
270
  c_v = v.shape[1] if v.ndim > 1 else len(v)
233
271
  if c != c_v:
234
272
  raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
@@ -239,7 +277,7 @@ class SufficiencyOutput(Output):
239
277
  if self._params is None:
240
278
  self._params = {}
241
279
  if self.n_iter not in self._params:
242
- self._params[self.n_iter] = get_curve_params(self.measures, self.steps, self.n_iter)
280
+ self._params[self.n_iter] = get_curve_params(self.averaged_measures, self.steps, self.n_iter)
243
281
  return self._params[self.n_iter]
244
282
 
245
283
  @set_metadata
@@ -272,20 +310,22 @@ class SufficiencyOutput(Output):
272
310
  raise ValueError("'projection' must consist of numerical values")
273
311
 
274
312
  output = {}
275
- for name, measures in self.measures.items():
276
- if measures.ndim > 1:
313
+ for name, averaged_measures in self.averaged_measures.items():
314
+ if averaged_measures.ndim > 1:
277
315
  result = []
278
- for i in range(len(measures)):
316
+ for i in range(len(averaged_measures)):
279
317
  projected = project_steps(self.params[name][i], projection)
280
318
  result.append(projected)
281
319
  output[name] = np.array(result)
282
320
  else:
283
321
  output[name] = project_steps(self.params[name], projection)
284
- proj = SufficiencyOutput(projection, output, self.n_iter)
322
+ proj = SufficiencyOutput(projection, {}, output, self.n_iter)
285
323
  proj._params = self._params
286
324
  return proj
287
325
 
288
- def plot(self, class_names: Sequence[str] | None = None) -> Sequence[Figure]:
326
+ def plot(
327
+ self, class_names: Sequence[str] | None = None, error_bars: bool = False, asymptote: bool = False
328
+ ) -> Sequence[Figure]:
289
329
  """
290
330
  Plotting function for data :term:`sufficience<Sufficiency>` tasks.
291
331
 
@@ -293,6 +333,10 @@ class SufficiencyOutput(Output):
293
333
  ----------
294
334
  class_names : Sequence[str] | None, default None
295
335
  List of class names
336
+ error_bars : bool, default False
337
+ True if error bars should be plotted, False if not
338
+ asymptote : bool, default False
339
+ True if asymptote should be plotted, False if not
296
340
 
297
341
  Returns
298
342
  -------
@@ -315,25 +359,36 @@ class SufficiencyOutput(Output):
315
359
 
316
360
  # Stores all plots
317
361
  plots = []
318
-
319
362
  # Create a plot for each measure on one figure
320
- for name, measures in self.measures.items():
363
+ for name, measures in self.averaged_measures.items():
321
364
  if measures.ndim > 1:
322
365
  if class_names is not None and len(measures) != len(class_names):
323
366
  raise IndexError("Class name count does not align with measures")
324
- for i, measure in enumerate(measures):
367
+ for i, values in enumerate(measures):
325
368
  class_name = str(i) if class_names is None else class_names[i]
326
369
  fig = plot_measure(
327
370
  f"{name}_{class_name}",
328
371
  self.steps,
329
- measure,
372
+ values,
373
+ self.measures[name][:, :, i] if len(self.measures) else None,
330
374
  self.params[name][i],
331
375
  extrapolated,
376
+ error_bars,
377
+ asymptote,
332
378
  )
333
379
  plots.append(fig)
334
380
 
335
381
  else:
336
- fig = plot_measure(name, self.steps, measures, self.params[name], extrapolated)
382
+ fig = plot_measure(
383
+ name,
384
+ self.steps,
385
+ measures,
386
+ self.measures.get(name),
387
+ self.params[name],
388
+ extrapolated,
389
+ error_bars,
390
+ asymptote,
391
+ )
337
392
  plots.append(fig)
338
393
 
339
394
  return plots
@@ -363,10 +418,10 @@ class SufficiencyOutput(Output):
363
418
 
364
419
  for name, target in targets.items():
365
420
  tarray = as_numpy(target)
366
- if name not in self.measures:
421
+ if name not in self.averaged_measures:
367
422
  continue
368
423
 
369
- measure = self.measures[name]
424
+ measure = self.averaged_measures[name]
370
425
  if measure.ndim > 1:
371
426
  projection[name] = np.zeros((len(measure), len(tarray)))
372
427
  for i in range(len(measure)):