dataeval 0.87.0__py3-none-any.whl → 0.88.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. dataeval/_log.py +1 -1
  2. dataeval/_version.py +2 -2
  3. dataeval/data/_embeddings.py +78 -35
  4. dataeval/data/_images.py +41 -8
  5. dataeval/data/_metadata.py +294 -41
  6. dataeval/data/_selection.py +22 -7
  7. dataeval/data/_split.py +2 -1
  8. dataeval/data/selections/_classfilter.py +4 -3
  9. dataeval/data/selections/_indices.py +2 -1
  10. dataeval/data/selections/_shuffle.py +3 -2
  11. dataeval/detectors/drift/_base.py +2 -1
  12. dataeval/detectors/drift/_mmd.py +2 -1
  13. dataeval/detectors/drift/_nml/_base.py +1 -1
  14. dataeval/detectors/drift/_nml/_chunk.py +2 -1
  15. dataeval/detectors/drift/_nml/_result.py +3 -2
  16. dataeval/detectors/drift/_nml/_thresholds.py +6 -5
  17. dataeval/detectors/drift/_uncertainty.py +2 -1
  18. dataeval/detectors/linters/duplicates.py +2 -1
  19. dataeval/detectors/linters/outliers.py +4 -3
  20. dataeval/detectors/ood/ae.py +1 -1
  21. dataeval/detectors/ood/base.py +2 -1
  22. dataeval/detectors/ood/mixin.py +2 -1
  23. dataeval/metadata/_utils.py +1 -1
  24. dataeval/metrics/bias/_balance.py +1 -1
  25. dataeval/metrics/stats/_base.py +3 -29
  26. dataeval/metrics/stats/_boxratiostats.py +2 -1
  27. dataeval/metrics/stats/_dimensionstats.py +2 -1
  28. dataeval/metrics/stats/_hashstats.py +2 -1
  29. dataeval/metrics/stats/_pixelstats.py +2 -1
  30. dataeval/metrics/stats/_visualstats.py +2 -1
  31. dataeval/outputs/_base.py +2 -3
  32. dataeval/outputs/_bias.py +2 -1
  33. dataeval/outputs/_estimators.py +1 -1
  34. dataeval/outputs/_linters.py +3 -3
  35. dataeval/outputs/_stats.py +3 -3
  36. dataeval/outputs/_utils.py +1 -1
  37. dataeval/outputs/_workflows.py +29 -24
  38. dataeval/typing.py +11 -9
  39. dataeval/utils/_array.py +3 -2
  40. dataeval/utils/_bin.py +2 -1
  41. dataeval/utils/_method.py +2 -3
  42. dataeval/utils/_multiprocessing.py +34 -0
  43. dataeval/utils/_plot.py +2 -1
  44. dataeval/utils/data/__init__.py +4 -5
  45. dataeval/utils/data/{metadata.py → _merge.py} +3 -2
  46. dataeval/utils/data/_validate.py +2 -1
  47. dataeval/utils/data/collate.py +2 -1
  48. dataeval/utils/torch/_internal.py +2 -1
  49. dataeval/utils/torch/trainer.py +1 -1
  50. dataeval/workflows/sufficiency.py +13 -9
  51. {dataeval-0.87.0.dist-info → dataeval-0.88.0.dist-info}/METADATA +4 -5
  52. dataeval-0.88.0.dist-info/RECORD +105 -0
  53. dataeval/utils/data/_dataset.py +0 -253
  54. dataeval-0.87.0.dist-info/RECORD +0 -105
  55. {dataeval-0.87.0.dist-info → dataeval-0.88.0.dist-info}/WHEEL +0 -0
  56. {dataeval-0.87.0.dist-info → dataeval-0.88.0.dist-info}/licenses/LICENSE +0 -0
dataeval/detectors/drift/_nml/_base.py CHANGED
@@ -9,8 +9,8 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from logging import Logger
-from typing import Sequence
 
 import pandas as pd
 from typing_extensions import Self
dataeval/detectors/drift/_nml/_chunk.py CHANGED
@@ -13,7 +13,8 @@ import copy
 import logging
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Generic, Literal, Sequence, TypeVar, cast
+from collections.abc import Sequence
+from typing import Any, Generic, Literal, TypeVar, cast
 
 import pandas as pd
 from pandas import Index, Period
dataeval/detectors/drift/_nml/_result.py CHANGED
@@ -11,7 +11,8 @@ from __future__ import annotations
 
 import copy
 from abc import ABC, abstractmethod
-from typing import NamedTuple, Sequence
+from collections.abc import Sequence
+from typing import NamedTuple
 
 import pandas as pd
 from typing_extensions import Self
@@ -52,7 +53,7 @@ class AbstractResult(GenericOutput[pd.DataFrame]):
 
     def filter(self, period: str = "all", metrics: str | Sequence[str] | None = None) -> Self:
         """Returns filtered result metric data."""
-        if metrics and not isinstance(metrics, (str, Sequence)):
+        if metrics and not isinstance(metrics, str | Sequence):
             raise ValueError("metrics value provided is not a valid metric or sequence of metrics")
         if isinstance(metrics, str):
             metrics = [metrics]
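
Note: a recurring change in this release swaps tuple arguments to isinstance for PEP 604 unions, which isinstance accepts natively on Python 3.10+ (as types.UnionType); the two spellings are equivalent at runtime. A minimal illustration (not from the package):

    from collections.abc import Sequence

    assert isinstance("abc", str | Sequence)   # same as isinstance("abc", (str, Sequence))
    assert isinstance([1, 2], str | Sequence)  # list is registered as a Sequence
    assert not isinstance(42, str | Sequence)  # int is neither
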
dataeval/detectors/drift/_nml/_thresholds.py CHANGED
@@ -9,7 +9,8 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, ClassVar
+from collections.abc import Callable
+from typing import Any, ClassVar
 
 import numpy as np
 
@@ -169,10 +170,10 @@ class ConstantThreshold(Threshold, threshold_type="constant"):
 
     @staticmethod
    def _validate_inputs(lower: float | int | None = None, upper: float | int | None = None) -> None:
-        if lower is not None and not isinstance(lower, (float, int)) or isinstance(lower, bool):
+        if lower is not None and not isinstance(lower, float | int) or isinstance(lower, bool):
             raise ValueError(f"expected type of 'lower' to be 'float', 'int' or None but got '{type(lower).__name__}'")
 
-        if upper is not None and not isinstance(upper, (float, int)) or isinstance(upper, bool):
+        if upper is not None and not isinstance(upper, float | int) or isinstance(upper, bool):
             raise ValueError(f"expected type of 'upper' to be 'float', 'int' or None but got '{type(upper).__name__}'")
 
         # explicit None check is required due to special interpretation of the value 0.0 as False
@@ -244,7 +245,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
     ) -> None:
         if (
             std_lower_multiplier is not None
-            and not isinstance(std_lower_multiplier, (float, int))
+            and not isinstance(std_lower_multiplier, float | int)
             or isinstance(std_lower_multiplier, bool)
         ):
             raise ValueError(
@@ -257,7 +258,7 @@ class StandardDeviationThreshold(Threshold, threshold_type="standard_deviation")
 
         if (
             std_upper_multiplier is not None
-            and not isinstance(std_upper_multiplier, (float, int))
+            and not isinstance(std_upper_multiplier, float | int)
             or isinstance(std_upper_multiplier, bool)
         ):
             raise ValueError(
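
Note: the isinstance(..., bool) guards above are kept because bool is a subclass of int in Python, so the float | int check alone would accept True and False. An illustrative sketch of the validation pattern (hypothetical function, not from the package):

    def check(value: float | int | None) -> None:
        # `and` binds tighter than `or`, mirroring _validate_inputs
        if value is not None and not isinstance(value, float | int) or isinstance(value, bool):
            raise ValueError(f"expected 'float', 'int' or None but got '{type(value).__name__}'")

    check(1.5)   # ok
    check(None)  # ok
    check(True)  # raises: isinstance(True, int) is True, but bools are rejected
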
dataeval/detectors/drift/_uncertainty.py CHANGED
@@ -10,7 +10,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Literal, Sequence, cast
+from collections.abc import Sequence
+from typing import Literal, cast
 
 import numpy as np
 import torch
dataeval/detectors/linters/duplicates.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Sequence, overload
+from collections.abc import Sequence
+from typing import Any, overload
 
 from dataeval.data._images import Images
 from dataeval.metrics.stats import hashstats
dataeval/detectors/linters/outliers.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Literal, Sequence, overload
+from collections.abc import Sequence
+from typing import Any, Literal, overload
 
 import numpy as np
 from numpy.typing import NDArray
@@ -201,7 +202,7 @@ class Outliers:
         >>> results.issues[1]
         {}
         """
-        if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
+        if isinstance(stats, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput):
             return OutliersOutput(self._get_outliers(stats.data()))
 
         if not isinstance(stats, Sequence):
@@ -212,7 +213,7 @@ class Outliers:
         stats_map: dict[type, list[int]] = {}
         for i, stats_output in enumerate(stats):
             if not isinstance(
-                stats_output, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)
+                stats_output, ImageStatsOutput | DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
             ):
                 raise TypeError(
                     "Invalid stats output type; only use output from dimensionstats, pixelstats or visualstats."
dataeval/detectors/ood/ae.py CHANGED
@@ -12,7 +12,7 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Callable
+from collections.abc import Callable
 
 import numpy as np
 import torch
dataeval/detectors/ood/base.py CHANGED
@@ -11,7 +11,8 @@ from __future__ import annotations
 __all__ = []
 
 from abc import ABC, abstractmethod
-from typing import Any, Callable, cast
+from collections.abc import Callable
+from typing import Any, cast
 
 import numpy as np
 import torch
dataeval/detectors/ood/mixin.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 __all__ = []
 
 from abc import ABC, abstractmethod
-from typing import Callable, Generic, Literal, TypeVar
+from collections.abc import Callable
+from typing import Generic, Literal, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
dataeval/metadata/_utils.py CHANGED
@@ -1,6 +1,6 @@
 __all__ = []
 
-from typing import Sequence
+from collections.abc import Sequence
 
 from numpy.typing import NDArray
 
dataeval/metrics/bias/_balance.py CHANGED
@@ -16,7 +16,7 @@ from dataeval.utils._bin import get_counts
 
 
 def _validate_num_neighbors(num_neighbors: int) -> int:
-    if not isinstance(num_neighbors, (int, float)):
+    if not isinstance(num_neighbors, int | float):
         raise TypeError(
             f"Variable {num_neighbors} is not real-valued numeric type."
             "num_neighbors should be an int, greater than 0 and less than"
dataeval/metrics/stats/_base.py CHANGED
@@ -6,11 +6,11 @@ import math
 import re
 import warnings
 from collections import ChainMap
+from collections.abc import Callable, Iterable, Iterator, Sequence
 from copy import deepcopy
 from dataclasses import dataclass
 from functools import partial
-from multiprocessing import Pool
-from typing import Any, Callable, Generic, Iterable, Iterator, Sequence, TypeVar
+from typing import Any, Generic, TypeVar
 
 import numpy as np
 from numpy.typing import NDArray
@@ -21,14 +21,12 @@ from dataeval.outputs._stats import BASE_ATTRS, BaseStatsOutput, SourceIndex
 from dataeval.typing import Array, ArrayLike, Dataset, ObjectDetectionTarget
 from dataeval.utils._array import as_numpy, to_numpy
 from dataeval.utils._image import clip_and_pad, clip_box, is_valid_box, normalize_image_shape, rescale
+from dataeval.utils._multiprocessing import PoolWrapper
 
 DTYPE_REGEX = re.compile(r"NDArray\[np\.(.*?)\]")
 
 TStatsOutput = TypeVar("TStatsOutput", bound=BaseStatsOutput, covariant=True)
 
-_S = TypeVar("_S")
-_T = TypeVar("_T")
-
 
 @dataclass
 class BoundingBox:
@@ -67,30 +65,6 @@ class BoundingBox:
         return x0_int, y0_int, x1_int, y1_int
 
 
-class PoolWrapper:
-    """
-    Wraps `multiprocessing.Pool` to allow for easy switching between
-    multiprocessing and single-threaded execution.
-
-    This helps with debugging and profiling, as well as usage with Jupyter notebooks
-    in VS Code, which does not support subprocess debugging.
-    """
-
-    def __init__(self, processes: int | None) -> None:
-        self.pool = Pool(processes) if processes is None or processes > 1 else None
-
-    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
-        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
-
-    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
-        return self
-
-    def __exit__(self, *args: Any) -> None:
-        if self.pool is not None:
-            self.pool.close()
-            self.pool.join()
-
-
 class StatsProcessor(Generic[TStatsOutput]):
     output_class: type[TStatsOutput]
     cache_keys: set[str] = set()
dataeval/metrics/stats/_boxratiostats.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 __all__ = []
 
 import copy
-from typing import Any, Callable, Generic, TypeVar, cast
+from collections.abc import Callable
+from typing import Any, Generic, TypeVar, cast
 
 import numpy as np
 from numpy.typing import NDArray
dataeval/metrics/stats/_dimensionstats.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 
dataeval/metrics/stats/_hashstats.py CHANGED
@@ -4,7 +4,8 @@ import warnings
 
 __all__ = []
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 import xxhash as xxh
dataeval/metrics/stats/_pixelstats.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 from scipy.stats import entropy, kurtosis, skew
dataeval/metrics/stats/_visualstats.py CHANGED
@@ -2,7 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
-from typing import Any, Callable
+from collections.abc import Callable
+from typing import Any
 
 import numpy as np
 
dataeval/outputs/_base.py CHANGED
@@ -4,14 +4,13 @@ __all__ = []
 
 import inspect
 import logging
-from collections.abc import Collection, Mapping, Sequence
+from collections.abc import Callable, Collection, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from functools import partial, wraps
-from typing import Any, Callable, Generic, Iterator, TypeVar, overload
+from typing import Any, Generic, ParamSpec, TypeVar, overload
 
 import numpy as np
-from typing_extensions import ParamSpec
 
 from dataeval import __version__
 
dataeval/outputs/_bias.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
 __all__ = []
 
 import contextlib
+from collections.abc import Mapping, Sequence
 from dataclasses import asdict, dataclass
-from typing import Any, Mapping, Sequence, TypeVar
+from typing import Any, TypeVar
 
 import numpy as np
 import pandas as pd
dataeval/outputs/_estimators.py CHANGED
@@ -2,8 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Sequence
 
 import numpy as np
 from numpy.typing import NDArray
dataeval/outputs/_linters.py CHANGED
@@ -2,11 +2,11 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Generic, Mapping, Sequence, TypeVar, Union
+from typing import Generic, TypeAlias, TypeVar
 
 import pandas as pd
-from typing_extensions import TypeAlias
 
 from dataeval.outputs._base import Output
 from dataeval.outputs._stats import DimensionStatsOutput, LabelStatsOutput, PixelStatsOutput, VisualStatsOutput
@@ -16,7 +16,7 @@ DatasetDuplicateGroupMap: TypeAlias = Mapping[int, DuplicateGroup]
 TIndexCollection = TypeVar("TIndexCollection", DuplicateGroup, DatasetDuplicateGroupMap)
 
 IndexIssueMap: TypeAlias = Mapping[int, Mapping[str, float]]
-OutlierStatsOutput: TypeAlias = Union[DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput]
+OutlierStatsOutput: TypeAlias = DimensionStatsOutput | PixelStatsOutput | VisualStatsOutput
 TIndexIssueMap = TypeVar("TIndexIssueMap", IndexIssueMap, Sequence[IndexIssueMap])
 
 
dataeval/outputs/_stats.py CHANGED
@@ -2,13 +2,13 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
 
 import numpy as np
 import polars as pl
 from numpy.typing import NDArray
-from typing_extensions import TypeAlias
 
 from dataeval.outputs._base import Output
 from dataeval.utils._plot import channel_histogram_plot, histogram_plot
@@ -16,7 +16,7 @@ from dataeval.utils._plot import channel_histogram_plot, histogram_plot
 
 if TYPE_CHECKING:
     from matplotlib.figure import Figure
-OptionalRange: TypeAlias = Optional[Union[int, Iterable[int]]]
+OptionalRange: TypeAlias = int | Iterable[int] | None
 
 SOURCE_INDEX = "source_index"
 OBJECT_COUNT = "object_count"
dataeval/outputs/_utils.py CHANGED
@@ -2,8 +2,8 @@ from __future__ import annotations
 
 __all__ = []
 
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Sequence
 
 import numpy as np
 from numpy.typing import NDArray
dataeval/outputs/_workflows.py CHANGED
@@ -4,8 +4,9 @@ __all__ = []
 
 import contextlib
 import warnings
-from dataclasses import dataclass
-from typing import Any, Iterable, Mapping, Sequence, cast
+from collections.abc import Iterable, Mapping, MutableMapping, Sequence
+from dataclasses import dataclass, field
+from typing import Any, cast
 
 import numpy as np
 from numpy.typing import NDArray
@@ -76,10 +77,8 @@ def plot_measure(
     ax.set_title(f"{name} Sufficiency")
     ax.set_ylabel(f"{name}")
     ax.set_xlabel("Steps")
-
     # Plot measure over each step
     ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
-
     # Plot extrapolation
     ax.plot(
         projection,
@@ -145,7 +144,7 @@ def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np
     return np.ceil(steps)
 
 
-def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
+def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.float64]:
     """
     Retrieves the inverse power curve coefficients for the line of best fit.
     Global minimization is done via basin hopping. More info on this algorithm
@@ -191,11 +190,11 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any
 
 
 def get_curve_params(
-    measures: Mapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
-) -> Mapping[str, NDArray[Any]]:
+    averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
+) -> Mapping[str, NDArray[np.float64]]:
     """Calculates and aggregates parameters for both single and multi-class metrics"""
     output = {}
-    for name, measure in measures.items():
+    for name, measure in averaged_measures.items():
         measure = cast(np.ndarray, measure)
         if measure.ndim > 1:
             result = []
@@ -216,19 +215,25 @@ class SufficiencyOutput(Output):
     ----------
     steps : NDArray
         Array of sample sizes
-    measures : Dict[str, NDArray]
-        Average of values observed for each sample size step for each measure
+    measures : dict[str, NDArray]
+        3D array [runs, substep, classes] of values for all runs observed for each sample size step for each measure
+    averaged_measures : dict[str, NDArray]
+        Average of values for all runs observed for each sample size step for each measure
     n_iter : int, default 1000
         Number of iterations to perform in the basin-hopping curve-fit process
     """
 
     steps: NDArray[np.uint32]
-    measures: Mapping[str, NDArray[np.float64]]
+    measures: Mapping[str, NDArray[Any]]
+    averaged_measures: MutableMapping[str, NDArray[Any]] = field(default_factory=lambda: {})
    n_iter: int = 1000
 
     def __post_init__(self) -> None:
+        if len(self.averaged_measures) == 0:
+            for metric, values in self.measures.items():
+                self.averaged_measures[metric] = np.asarray(np.mean(values, axis=0)).T
         c = len(self.steps)
-        for m, v in self.measures.items():
+        for m, v in self.averaged_measures.items():
             c_v = v.shape[1] if v.ndim > 1 else len(v)
             if c != c_v:
                 raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
@@ -239,7 +244,7 @@ class SufficiencyOutput(Output):
         if self._params is None:
             self._params = {}
         if self.n_iter not in self._params:
-            self._params[self.n_iter] = get_curve_params(self.measures, self.steps, self.n_iter)
+            self._params[self.n_iter] = get_curve_params(self.averaged_measures, self.steps, self.n_iter)
         return self._params[self.n_iter]
 
     @set_metadata
@@ -272,16 +277,16 @@ class SufficiencyOutput(Output):
             raise ValueError("'projection' must consist of numerical values")
 
         output = {}
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
+        for name, averaged_measures in self.averaged_measures.items():
+            if averaged_measures.ndim > 1:
                 result = []
-                for i in range(len(measures)):
+                for i in range(len(averaged_measures)):
                     projected = project_steps(self.params[name][i], projection)
                     result.append(projected)
                 output[name] = np.array(result)
             else:
                 output[name] = project_steps(self.params[name], projection)
-        proj = SufficiencyOutput(projection, output, self.n_iter)
+        proj = SufficiencyOutput(projection, measures=self.measures, averaged_measures=output, n_iter=self.n_iter)
         proj._params = self._params
         return proj
 
@@ -317,11 +322,11 @@ class SufficiencyOutput(Output):
         plots = []
 
         # Create a plot for each measure on one figure
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
-                if class_names is not None and len(measures) != len(class_names):
+        for name, averaged_measures in self.averaged_measures.items():
+            if averaged_measures.ndim > 1:
+                if class_names is not None and len(averaged_measures) != len(class_names):
                     raise IndexError("Class name count does not align with measures")
-                for i, measure in enumerate(measures):
+                for i, measure in enumerate(averaged_measures):
                     class_name = str(i) if class_names is None else class_names[i]
                     fig = plot_measure(
                         f"{name}_{class_name}",
@@ -333,7 +338,7 @@ class SufficiencyOutput(Output):
                     plots.append(fig)
 
             else:
-                fig = plot_measure(name, self.steps, measures, self.params[name], extrapolated)
+                fig = plot_measure(name, self.steps, averaged_measures, self.params[name], extrapolated)
                 plots.append(fig)
 
         return plots
@@ -363,10 +368,10 @@ class SufficiencyOutput(Output):
 
         for name, target in targets.items():
             tarray = as_numpy(target)
-            if name not in self.measures:
+            if name not in self.averaged_measures:
                 continue
 
-            measure = self.measures[name]
+            measure = self.averaged_measures[name]
             if measure.ndim > 1:
                 projection[name] = np.zeros((len(measure), len(tarray)))
                 for i in range(len(measure)):
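
Note: SufficiencyOutput now keeps the raw per-run measures and derives averaged_measures in __post_init__ by averaging over the run axis; curve fitting, projection, and plotting all consume the averaged values. A standalone sketch of that reduction with hypothetical numbers (3 runs x 4 substeps):

    import numpy as np

    measures = {"accuracy": np.array([
        [0.50, 0.60, 0.70, 0.75],
        [0.52, 0.61, 0.69, 0.76],
        [0.48, 0.59, 0.71, 0.74],
    ])}
    # mirrors __post_init__: mean over runs (axis=0), then transpose so
    # multi-class measures land as [classes, substeps]
    averaged = {m: np.asarray(np.mean(v, axis=0)).T for m, v in measures.items()}
    print(averaged["accuracy"])  # [0.5 0.6 0.7 0.75], length matches len(steps)
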
dataeval/typing.py CHANGED
@@ -21,19 +21,21 @@ __all__ = [
 ]
 
 
-import sys
-from typing import Any, Generic, Iterator, Mapping, Protocol, TypedDict, TypeVar, Union, runtime_checkable
+from collections.abc import Iterator, Mapping
+from typing import (
+    Any,
+    Generic,
+    Protocol,
+    TypeAlias,
+    TypedDict,
+    TypeVar,
+    runtime_checkable,
+)
 
 import numpy.typing
 import torch
 from typing_extensions import NotRequired, ReadOnly, Required
 
-if sys.version_info >= (3, 10):
-    from typing import TypeAlias
-else:
-    from typing_extensions import TypeAlias
-
-
 ArrayLike: TypeAlias = numpy.typing.ArrayLike
 """
 Type alias for a `Union` representing objects that can be coerced into an array.
@@ -44,7 +46,7 @@ See Also
 """
 
 
-DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
+DeviceLike: TypeAlias = int | str | tuple[str, int] | torch.device
 """
 Type alias for a `Union` representing types that specify a torch.device.
 
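Note: with the Python 3.9 fallback removed, TypeAlias is imported straight from typing and the aliases use PEP 604 syntax. A hedged sketch of consuming DeviceLike (normalize_device is a hypothetical helper, not part of dataeval):

    import torch
    from dataeval.typing import DeviceLike

    def normalize_device(device: DeviceLike) -> torch.device:
        # cover each member of the alias
        if isinstance(device, torch.device):
            return device
        if isinstance(device, tuple):      # e.g. ("cuda", 0)
            return torch.device(*device)
        return torch.device(device)        # int ordinal or str such as "cpu"
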
dataeval/utils/_array.py CHANGED
@@ -4,9 +4,10 @@ __all__ = []
 
 import logging
 import warnings
+from collections.abc import Iterable, Iterator
 from importlib import import_module
 from types import ModuleType
-from typing import Any, Iterable, Iterator, Literal, TypeVar, overload
+from typing import Any, Literal, TypeVar, overload
 
 import numpy as np
 import torch
@@ -79,7 +80,7 @@ def rescale_array(array: NDArray[_np_dtype]) -> NDArray[_np_dtype]: ...
 def rescale_array(array: torch.Tensor) -> torch.Tensor: ...
 def rescale_array(array: Array | NDArray[_np_dtype] | torch.Tensor) -> Array | NDArray[_np_dtype] | torch.Tensor:
     """Rescale an array to the range [0, 1]"""
-    if isinstance(array, (np.ndarray, torch.Tensor)):
+    if isinstance(array, np.ndarray | torch.Tensor):
         arr_min = array.min()
         arr_max = array.max()
         return (array - arr_min) / (arr_max - arr_min)
dataeval/utils/_bin.py CHANGED
@@ -3,7 +3,8 @@ from __future__ import annotations
 __all__ = []
 
 import warnings
-from typing import Any, Iterable
+from collections.abc import Iterable
+from typing import Any
 
 import numpy as np
 from numpy.typing import NDArray
dataeval/utils/_method.py CHANGED
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from typing import Callable, TypeVar
-
-from typing_extensions import ParamSpec
+from collections.abc import Callable
+from typing import ParamSpec, TypeVar
 
 P = ParamSpec("P")
 R = TypeVar("R")
dataeval/utils/_multiprocessing.py ADDED
@@ -0,0 +1,34 @@
+from __future__ import annotations
+
+__all__ = []
+
+from collections.abc import Callable, Iterable, Iterator
+from multiprocessing import Pool
+from typing import Any, TypeVar
+
+_S = TypeVar("_S")
+_T = TypeVar("_T")
+
+
+class PoolWrapper:
+    """
+    Wraps `multiprocessing.Pool` to allow for easy switching between
+    multiprocessing and single-threaded execution.
+
+    This helps with debugging and profiling, as well as usage with Jupyter notebooks
+    in VS Code, which does not support subprocess debugging.
+    """
+
+    def __init__(self, processes: int | None) -> None:
+        self.pool = Pool(processes) if processes is None or processes > 1 else None
+
+    def imap(self, func: Callable[[_S], _T], iterable: Iterable[_S]) -> Iterator[_T]:
+        return map(func, iterable) if self.pool is None else self.pool.imap(func, iterable)
+
+    def __enter__(self, *args: Any, **kwargs: Any) -> PoolWrapper:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        if self.pool is not None:
+            self.pool.close()
+            self.pool.join()
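
Note: PoolWrapper moves here verbatim from dataeval/metrics/stats/_base.py (removed above). Because processes=1 falls back to a plain map(), the same call site can run single-threaded for debugging or pooled for throughput. A small usage sketch (square is an illustrative worker, not from the package):

    from dataeval.utils._multiprocessing import PoolWrapper

    def square(x: int) -> int:
        return x * x

    # processes=1 -> in-process map() (debugger friendly);
    # processes=None or > 1 -> multiprocessing.Pool.imap
    with PoolWrapper(processes=1) as pool:
        print(list(pool.imap(square, range(5))))  # [0, 1, 4, 9, 16]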