dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. dataeval/__init__.py +1 -1
  2. dataeval/config.py +68 -11
  3. dataeval/detectors/drift/__init__.py +2 -2
  4. dataeval/detectors/drift/_base.py +8 -64
  5. dataeval/detectors/drift/_mmd.py +12 -38
  6. dataeval/detectors/drift/_torch.py +7 -7
  7. dataeval/detectors/drift/_uncertainty.py +6 -5
  8. dataeval/detectors/drift/updates.py +20 -3
  9. dataeval/detectors/linters/__init__.py +3 -2
  10. dataeval/detectors/linters/duplicates.py +14 -46
  11. dataeval/detectors/linters/outliers.py +25 -159
  12. dataeval/detectors/ood/__init__.py +1 -1
  13. dataeval/detectors/ood/ae.py +6 -5
  14. dataeval/detectors/ood/base.py +2 -2
  15. dataeval/detectors/ood/metadata_ood_mi.py +4 -6
  16. dataeval/detectors/ood/mixin.py +3 -4
  17. dataeval/detectors/ood/vae.py +3 -2
  18. dataeval/metadata/__init__.py +2 -1
  19. dataeval/metadata/_distance.py +134 -0
  20. dataeval/metadata/_ood.py +30 -49
  21. dataeval/metadata/_utils.py +44 -0
  22. dataeval/metrics/bias/__init__.py +5 -4
  23. dataeval/metrics/bias/_balance.py +17 -149
  24. dataeval/metrics/bias/_coverage.py +4 -106
  25. dataeval/metrics/bias/_diversity.py +12 -107
  26. dataeval/metrics/bias/_parity.py +7 -71
  27. dataeval/metrics/estimators/__init__.py +5 -4
  28. dataeval/metrics/estimators/_ber.py +2 -20
  29. dataeval/metrics/estimators/_clusterer.py +1 -61
  30. dataeval/metrics/estimators/_divergence.py +2 -19
  31. dataeval/metrics/estimators/_uap.py +2 -16
  32. dataeval/metrics/stats/__init__.py +15 -12
  33. dataeval/metrics/stats/_base.py +41 -128
  34. dataeval/metrics/stats/_boxratiostats.py +13 -13
  35. dataeval/metrics/stats/_dimensionstats.py +17 -58
  36. dataeval/metrics/stats/_hashstats.py +19 -35
  37. dataeval/metrics/stats/_imagestats.py +94 -0
  38. dataeval/metrics/stats/_labelstats.py +42 -121
  39. dataeval/metrics/stats/_pixelstats.py +19 -51
  40. dataeval/metrics/stats/_visualstats.py +19 -51
  41. dataeval/outputs/__init__.py +57 -0
  42. dataeval/outputs/_base.py +182 -0
  43. dataeval/outputs/_bias.py +381 -0
  44. dataeval/outputs/_drift.py +83 -0
  45. dataeval/outputs/_estimators.py +114 -0
  46. dataeval/outputs/_linters.py +186 -0
  47. dataeval/outputs/_metadata.py +54 -0
  48. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  49. dataeval/outputs/_stats.py +393 -0
  50. dataeval/outputs/_utils.py +44 -0
  51. dataeval/outputs/_workflows.py +364 -0
  52. dataeval/typing.py +187 -7
  53. dataeval/utils/_method.py +1 -5
  54. dataeval/utils/_plot.py +2 -2
  55. dataeval/utils/data/__init__.py +5 -1
  56. dataeval/utils/data/_dataset.py +217 -0
  57. dataeval/utils/data/_embeddings.py +12 -14
  58. dataeval/utils/data/_images.py +30 -27
  59. dataeval/utils/data/_metadata.py +28 -11
  60. dataeval/utils/data/_selection.py +25 -22
  61. dataeval/utils/data/_split.py +5 -29
  62. dataeval/utils/data/_targets.py +14 -2
  63. dataeval/utils/data/datasets/_base.py +5 -5
  64. dataeval/utils/data/datasets/_cifar10.py +1 -1
  65. dataeval/utils/data/datasets/_milco.py +1 -1
  66. dataeval/utils/data/datasets/_mnist.py +1 -1
  67. dataeval/utils/data/datasets/_ships.py +1 -1
  68. dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
  69. dataeval/utils/data/datasets/_voc.py +1 -1
  70. dataeval/utils/data/selections/_classfilter.py +4 -5
  71. dataeval/utils/data/selections/_indices.py +2 -2
  72. dataeval/utils/data/selections/_limit.py +2 -2
  73. dataeval/utils/data/selections/_reverse.py +2 -2
  74. dataeval/utils/data/selections/_shuffle.py +2 -2
  75. dataeval/utils/torch/_internal.py +5 -5
  76. dataeval/utils/torch/trainer.py +8 -8
  77. dataeval/workflows/__init__.py +2 -1
  78. dataeval/workflows/sufficiency.py +6 -342
  79. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
  80. dataeval-0.82.1.dist-info/RECORD +105 -0
  81. dataeval/_output.py +0 -137
  82. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  83. dataeval/metrics/stats/_datasetstats.py +0 -198
  84. dataeval-0.81.0.dist-info/RECORD +0 -94
  85. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
  86. {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0

dataeval/utils/data/datasets/_cifar10.py
@@ -9,9 +9,9 @@ import numpy as np
 from numpy.typing import NDArray
 from PIL import Image
 
-from dataeval.utils.data._types import Transform
 from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
 from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
+from dataeval.utils.data.datasets._types import Transform
 
 CIFARClassStringMap = Literal["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
 TCIFARClassMap = TypeVar("TCIFARClassMap", CIFARClassStringMap, int, list[CIFARClassStringMap], list[int])

dataeval/utils/data/datasets/_milco.py
@@ -9,8 +9,8 @@ from typing import Any, Sequence
 
 from numpy.typing import NDArray
 
-from dataeval.utils.data._types import Transform
 from dataeval.utils.data.datasets._base import BaseODDataset, DataLocation
+from dataeval.utils.data.datasets._types import Transform
 
 
 class MILCO(BaseODDataset[NDArray[Any]], BaseDatasetNumpyMixin):

dataeval/utils/data/datasets/_mnist.py
@@ -8,9 +8,9 @@ from typing import Any, Literal, Sequence, TypeVar
 import numpy as np
 from numpy.typing import NDArray
 
-from dataeval.utils.data._types import Transform
 from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
 from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
+from dataeval.utils.data.datasets._types import Transform
 
 MNISTClassStringMap = Literal["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
 TMNISTClassMap = TypeVar("TMNISTClassMap", MNISTClassStringMap, int, list[MNISTClassStringMap], list[int])

dataeval/utils/data/datasets/_ships.py
@@ -8,9 +8,9 @@ from typing import Any, Sequence
 import numpy as np
 from numpy.typing import NDArray
 
-from dataeval.utils.data._types import Transform
 from dataeval.utils.data.datasets._base import BaseICDataset, DataLocation
 from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin
+from dataeval.utils.data.datasets._types import Transform
 
 
 class Ships(BaseICDataset[NDArray[Any]], BaseDatasetNumpyMixin):
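
The four dataset modules above (_cifar10, _milco, _mnist, _ships) all pick up the same one-line change: Transform moved from dataeval.utils.data._types to dataeval.utils.data.datasets._types (the rename itself appears just below). A minimal before/after sketch for any downstream code that imported the old private module directly:

```python
# 0.81.0 location (module removed in this release):
# from dataeval.utils.data._types import Transform

# 0.82.1 location, per the renamed module in this diff:
from dataeval.utils.data.datasets._types import Transform
```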

dataeval/utils/data/{_types.py → datasets/_types.py}
@@ -2,20 +2,11 @@ from __future__ import annotations
 
 __all__ = []
 
-import sys
 from dataclasses import dataclass
 from typing import Any, Generic, Protocol, TypedDict, TypeVar
 
-if sys.version_info >= (3, 11):
-    from typing import NotRequired, Required
-else:
-    from typing_extensions import NotRequired, Required
-
-from torch.utils.data import Dataset as _Dataset
-
-_TArray = TypeVar("_TArray")
-_TData = TypeVar("_TData", covariant=True)
-_TTarget = TypeVar("_TTarget", covariant=True)
+from torch.utils.data import Dataset
+from typing_extensions import NotRequired, Required
 
 
 class DatasetMetadata(TypedDict):
@@ -24,14 +15,17 @@ class DatasetMetadata(TypedDict):
     split: NotRequired[str]
 
 
-class Dataset(_Dataset[tuple[_TData, _TTarget, dict[str, Any]]]):
+_TDatum = TypeVar("_TDatum")
+_TArray = TypeVar("_TArray")
+
+
+class AnnotatedDataset(Dataset[_TDatum]):
     metadata: DatasetMetadata
 
-    def __getitem__(self, index: Any) -> tuple[_TData, _TTarget, dict[str, Any]]: ...
     def __len__(self) -> int: ...
 
 
-class ImageClassificationDataset(Dataset[_TArray, _TArray]): ...
+class ImageClassificationDataset(AnnotatedDataset[tuple[_TArray, _TArray, dict[str, Any]]]): ...
 
 
 @dataclass
@@ -41,7 +35,7 @@ class ObjectDetectionTarget(Generic[_TArray]):
     scores: _TArray
 
 
-class ObjectDetectionDataset(Dataset[_TArray, ObjectDetectionTarget[_TArray]]): ...
+class ObjectDetectionDataset(AnnotatedDataset[tuple[_TArray, ObjectDetectionTarget[_TArray], dict[str, Any]]]): ...
 
 
 @dataclass
@@ -51,7 +45,7 @@ class SegmentationTarget(Generic[_TArray]):
     scores: _TArray
 
 
-class SegmentationDataset(Dataset[_TArray, SegmentationTarget[_TArray]]): ...
+class SegmentationDataset(AnnotatedDataset[tuple[_TArray, SegmentationTarget[_TArray], dict[str, Any]]]): ...
 
 
 class Transform(Generic[_TArray], Protocol):
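
The dataset generics now take a single datum type: AnnotatedDataset[_TDatum] replaces the two-parameter Dataset[_TData, _TTarget], and each task-specific alias bakes the (data, target, metadata) tuple into that one parameter. A sketch of how the new parameterization reads, assuming the 0.82.1 private module layout shown in this diff:

```python
from typing import Any

from numpy.typing import NDArray

from dataeval.utils.data.datasets._types import ImageClassificationDataset


def first_image(dataset: ImageClassificationDataset[NDArray[Any]]) -> NDArray[Any]:
    # AnnotatedDataset[_TDatum] indexes straight to the datum, which for
    # image classification is a (data, target, metadata) tuple.
    image, target, metadata = dataset[0]
    return image
```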

dataeval/utils/data/datasets/_voc.py
@@ -9,7 +9,6 @@ import torch
 from defusedxml.ElementTree import parse
 from numpy.typing import NDArray
 
-from dataeval.utils.data._types import ObjectDetectionTarget, SegmentationTarget, Transform
 from dataeval.utils.data.datasets._base import (
     BaseDataset,
     BaseODDataset,
@@ -17,6 +16,7 @@ from dataeval.utils.data.datasets._base import (
     DataLocation,
 )
 from dataeval.utils.data.datasets._mixin import BaseDatasetNumpyMixin, BaseDatasetTorchMixin
+from dataeval.utils.data.datasets._types import ObjectDetectionTarget, SegmentationTarget, Transform
 
 _TArray = TypeVar("_TArray")
 _TTarget = TypeVar("_TTarget")

dataeval/utils/data/selections/_classfilter.py
@@ -6,15 +6,14 @@ from typing import Sequence, TypeVar
 
 import numpy as np
 
-from dataeval.typing import Array
+from dataeval.typing import Array, ImageClassificationDatum
 from dataeval.utils._array import as_numpy
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
-_TData = TypeVar("_TData")
-_TTarget = TypeVar("_TTarget", bound=Array)
+TImageClassificationDatum = TypeVar("TImageClassificationDatum", bound=ImageClassificationDatum, covariant=True)
 
 
-class ClassFilter(Selection[_TData, _TTarget]):
+class ClassFilter(Selection[TImageClassificationDatum]):
     """
     Filter and balance the dataset by class.
 
@@ -37,7 +36,7 @@ class ClassFilter(Selection[_TData, _TTarget]):
         self.classes = classes
         self.balance = balance
 
-    def __call__(self, dataset: Select[_TData, _TTarget]) -> None:
+    def __call__(self, dataset: Select[TImageClassificationDatum]) -> None:
         if self.classes is None and not self.balance:
             return
 

dataeval/utils/data/selections/_indices.py
@@ -7,7 +7,7 @@ from typing import Any, Sequence
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
 
-class Indices(Selection[Any, Any]):
+class Indices(Selection[Any]):
     """
     Selects specific indices from the dataset.
 
@@ -22,5 +22,5 @@
     def __init__(self, indices: Sequence[int]) -> None:
         self.indices = indices
 
-    def __call__(self, dataset: Select[Any, Any]) -> None:
+    def __call__(self, dataset: Select[Any]) -> None:
         dataset._selection = [index for index in self.indices if index in dataset._selection]

dataeval/utils/data/selections/_limit.py
@@ -7,7 +7,7 @@ from typing import Any
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
 
-class Limit(Selection[Any, Any]):
+class Limit(Selection[Any]):
     """
     Limit the size of the dataset.
 
@@ -22,5 +22,5 @@
     def __init__(self, size: int) -> None:
         self.size = size
 
-    def __call__(self, dataset: Select[Any, Any]) -> None:
+    def __call__(self, dataset: Select[Any]) -> None:
         dataset._size_limit = self.size

dataeval/utils/data/selections/_reverse.py
@@ -7,12 +7,12 @@ from typing import Any
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
 
-class Reverse(Selection[Any, Any]):
+class Reverse(Selection[Any]):
     """
     Reverse the selection order of the dataset.
     """
 
     stage = SelectionStage.ORDER
 
-    def __call__(self, dataset: Select[Any, Any]) -> None:
+    def __call__(self, dataset: Select[Any]) -> None:
         dataset._selection.reverse()

dataeval/utils/data/selections/_shuffle.py
@@ -9,7 +9,7 @@ import numpy as np
 from dataeval.utils.data._selection import Select, Selection, SelectionStage
 
 
-class Shuffle(Selection[Any, Any]):
+class Shuffle(Selection[Any]):
     """
     Shuffle the dataset using a seed.
 
@@ -24,6 +24,6 @@
     def __init__(self, seed: int):
         self.seed = seed
 
-    def __call__(self, dataset: Select[Any, Any]) -> None:
+    def __call__(self, dataset: Select[Any]) -> None:
         rng = np.random.default_rng(self.seed)
         rng.shuffle(dataset._selection)
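
Every built-in Selection drops its second type parameter: a selection is now generic over the single datum type rather than separate data and target types, mirroring the AnnotatedDataset change above. The __call__ contract is unchanged; selections still mutate the Select wrapper's working index list. A hypothetical custom selection under the new signature (EveryOther is illustrative, not part of dataeval; SelectionStage.ORDER is the only stage value visible in this diff):

```python
from typing import Any

from dataeval.utils.data._selection import Select, Selection, SelectionStage


class EveryOther(Selection[Any]):
    """Keep every other selected index -- illustrative only."""

    stage = SelectionStage.ORDER  # other stage names live in _selection.py

    def __call__(self, dataset: Select[Any]) -> None:
        # Mirrors Reverse/Shuffle above: mutate the working index list in place.
        dataset._selection = dataset._selection[::2]
```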

dataeval/utils/torch/_internal.py
@@ -11,13 +11,13 @@ from numpy.typing import NDArray
 from torch.utils.data import DataLoader, TensorDataset
 from tqdm import tqdm
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 
 
 def predict_batch(
     x: NDArray[Any] | torch.Tensor,
     model: Callable | torch.nn.Module | torch.nn.Sequential,
-    device: torch.device | None = None,
+    device: DeviceLike | None = None,
     batch_size: int = int(1e10),
     preprocess_fn: Callable[[torch.Tensor], torch.Tensor] | None = None,
     dtype: type[np.generic] | torch.dtype = np.float32,
@@ -31,9 +31,9 @@
        Batch of instances.
    model : Callable | nn.Module | nn.Sequential
        PyTorch model.
-    device : torch.device | None, default None
-        Device type used. The default None tries to use the GPU and falls back on CPU.
-        Can be specified by passing either torch.device('cuda') or torch.device('cpu').
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
    batch_size : int, default 1e10
        Batch size used during prediction.
    preprocess_fn : Callable | None, default None
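
predict_batch now accepts the broader DeviceLike alias from dataeval.config instead of requiring a torch.device. A sketch assuming only what this diff shows (the private dataeval.utils.torch._internal path and the signature above); the toy model and input are illustrative:

```python
import numpy as np
import torch

from dataeval.utils.torch._internal import predict_batch

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(10))
x = np.zeros((32, 3, 16, 16), dtype=np.float32)

# device=None defers to the DataEval default device (see config.py in this
# diff) or torch's default; an explicit torch.device still works as before.
preds = predict_batch(x, model, device=None, batch_size=8)
```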

dataeval/utils/torch/trainer.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+from dataeval.config import DeviceLike, get_device
+
 __all__ = ["AETrainer"]
 
 from typing import Any
@@ -25,9 +27,9 @@ class AETrainer:
    ----------
    model : nn.Module
        The model to be trained.
-    device : str or torch.device, default "auto"
-        The hardware device to use for training.
-        If "auto", the device will be set to "cuda" if available, otherwise "cpu".
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
    batch_size : int, default 8
        The number of images to process in a batch.
    """
@@ -35,13 +37,11 @@
     def __init__(
         self,
         model: nn.Module,
-        device: str | torch.device = "auto",
+        device: DeviceLike | None = None,
         batch_size: int = 8,
     ):
-        if device == "auto":
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.device: torch.device = torch.device(device)
-        self.model: nn.Module = model.to(device)
+        self.device: torch.device = get_device(device)
+        self.model: nn.Module = model.to(self.device)
         self.batch_size = batch_size
 
     def train(self, dataset: Dataset[Any], epochs: int = 25) -> list[float]:
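
AETrainer no longer hand-rolls its "auto" device logic; resolution is delegated to dataeval.config.get_device. A minimal construction sketch (nn.Identity stands in for a real autoencoder):

```python
import torch.nn as nn

from dataeval.utils.torch.trainer import AETrainer

# None resolves through get_device: the DataEval-configured default if one
# was set, otherwise torch's default device.
trainer = AETrainer(nn.Identity(), device=None, batch_size=8)
print(trainer.device)  # a concrete torch.device
```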

dataeval/workflows/__init__.py
@@ -4,4 +4,5 @@ Workflows perform a sequence of actions to analyze the dataset and make predictions.
 
 __all__ = ["Sufficiency", "SufficiencyOutput"]
 
-from dataeval.workflows.sufficiency import Sufficiency, SufficiencyOutput
+from dataeval.outputs._workflows import SufficiencyOutput
+from dataeval.workflows.sufficiency import Sufficiency
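
SufficiencyOutput now lives in the new dataeval.outputs package, but since __all__ is unchanged the public import path still works:

```python
# Public surface unchanged despite the internal move.
from dataeval.workflows import Sufficiency, SufficiencyOutput
```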

dataeval/workflows/sufficiency.py
@@ -2,261 +2,16 @@ from __future__ import annotations
 
 __all__ = []
 
-import contextlib
-import warnings
-from dataclasses import dataclass
-from typing import Any, Callable, Generic, Iterable, Mapping, Sequence, Sized, TypeVar, cast
+from typing import Any, Callable, Generic, Iterable, Mapping, Sequence, Sized, TypeVar
 
 import numpy as np
 import torch
 import torch.nn as nn
-from numpy.typing import NDArray
-from scipy.optimize import basinhopping
 from torch.utils.data import Dataset
 
-from dataeval._output import Output, set_metadata
+from dataeval.outputs import SufficiencyOutput
+from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
-from dataeval.utils._array import as_numpy
-
-with contextlib.suppress(ImportError):
-    from matplotlib.figure import Figure
-
-
-@dataclass(frozen=True)
-class SufficiencyOutput(Output):
-    """
-    Output class for :class:`.Sufficiency` workflow.
-
-    Attributes
-    ----------
-    steps : NDArray
-        Array of sample sizes
-    params : Dict[str, NDArray]
-        Inverse power curve coefficients for the line of best fit for each measure
-    measures : Dict[str, NDArray]
-        Average of values observed for each sample size step for each measure
-    """
-
-    steps: NDArray[np.uint32]
-    params: dict[str, NDArray[np.float64]]
-    measures: dict[str, NDArray[np.float64]]
-
-    def __post_init__(self) -> None:
-        c = len(self.steps)
-        if set(self.params) != set(self.measures):
-            raise ValueError("params and measures have a key mismatch")
-        for m, v in self.measures.items():
-            c_v = v.shape[1] if v.ndim > 1 else len(v)
-            if c != c_v:
-                raise ValueError(f"{m} does not contain the expected number ({c}) of data points.")
-
-    @set_metadata
-    def project(
-        self,
-        projection: int | Iterable[int],
-    ) -> SufficiencyOutput:
-        """Projects the measures for each value of X
-
-        Parameters
-        ----------
-        projection : int | Iterable[int]
-            Step or steps to project
-
-        Returns
-        -------
-        SufficiencyOutput
-            Dataclass containing the projected measures per projection
-
-        Raises
-        ------
-        ValueError
-            If the length of data points in the measures do not match
-            If `projection` is not numerical
-        """
-        projection = np.asarray(list(projection) if isinstance(projection, Iterable) else [projection])
-
-        if not np.issubdtype(projection.dtype, np.number):
-            raise ValueError("'projection' must consist of numerical values")
-
-        output = {}
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
-                result = []
-                for i in range(len(measures)):
-                    projected = project_steps(self.params[name][i], projection)
-                    result.append(projected)
-                output[name] = np.array(result)
-            else:
-                output[name] = project_steps(self.params[name], projection)
-        return SufficiencyOutput(projection, self.params, output)
-
-    def plot(self, class_names: Sequence[str] | None = None) -> list[Figure]:
-        """Plotting function for data :term:`sufficience<Sufficiency>` tasks
-
-        Parameters
-        ----------
-        class_names : Sequence[str] | None, default None
-            List of class names
-
-        Returns
-        -------
-        list[Figure]
-            List of Figures for each measure
-
-        Raises
-        ------
-        ValueError
-            If the length of data points in the measures do not match
-        """
-        # Extrapolation parameters
-        last_X = self.steps[-1]
-        geomshape = (0.01 * last_X, last_X * 4, len(self.steps))
-        extrapolated = np.geomspace(*geomshape).astype(np.int64)
-
-        # Stores all plots
-        plots = []
-
-        # Create a plot for each measure on one figure
-        for name, measures in self.measures.items():
-            if measures.ndim > 1:
-                if class_names is not None and len(measures) != len(class_names):
-                    raise IndexError("Class name count does not align with measures")
-                for i, measure in enumerate(measures):
-                    class_name = str(i) if class_names is None else class_names[i]
-                    fig = plot_measure(
-                        f"{name}_{class_name}",
-                        self.steps,
-                        measure,
-                        self.params[name][i],
-                        extrapolated,
-                    )
-                    plots.append(fig)
-
-            else:
-                fig = plot_measure(name, self.steps, measures, self.params[name], extrapolated)
-                plots.append(fig)
-
-        return plots
-
-    def inv_project(self, targets: Mapping[str, ArrayLike]) -> dict[str, NDArray[np.float64]]:
-        """
-        Calculate training samples needed to achieve target model metric values.
-
-        Parameters
-        ----------
-        targets : Mapping[str, ArrayLike]
-            Mapping of target metric scores (from 0.0 to 1.0) that we want
-            to achieve, where the key is the name of the metric.
-
-        Returns
-        -------
-        dict[str, NDArray]
-            List of the number of training samples needed to achieve each
-            corresponding entry in targets
-        """
-
-        projection = {}
-
-        for name, target in targets.items():
-            tarray = as_numpy(target)
-            if name not in self.measures:
-                continue
-
-            measure = self.measures[name]
-            if measure.ndim > 1:
-                projection[name] = np.zeros((len(measure), len(tarray)))
-                for i in range(len(measure)):
-                    projection[name][i] = inv_project_steps(
-                        self.params[name][i], tarray[i] if tarray.ndim == measure.ndim else tarray
-                    )
-            else:
-                projection[name] = inv_project_steps(self.params[name], tarray)
-
-        return projection
-
-
-def f_out(n_i: NDArray[Any], x: NDArray[Any]) -> NDArray[Any]:
-    """
-    Calculates the line of best fit based on its free parameters
-
-    Parameters
-    ----------
-    n_i : NDArray
-        Array of sample sizes
-    x : NDArray
-        Array of inverse power curve coefficients
-
-    Returns
-    -------
-    NDArray
-        Data points for the line of best fit
-    """
-    return x[0] * n_i ** (-x[1]) + x[2]
-
-
-def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.uint64]:
-    """
-    Inverse function for f_out()
-
-    Parameters
-    ----------
-    y_i : NDArray
-        Data points for the line of best fit
-    x : NDArray
-        Array of inverse power curve coefficients
-
-    Returns
-    -------
-    NDArray
-        Array of sample sizes
-    """
-    n_i = ((y_i - x[2]) / x[0]) ** (-1 / x[1])
-    return np.asarray(n_i, dtype=np.uint64)
-
-
-def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[Any]:
-    """
-    Retrieves the inverse power curve coefficients for the line of best fit.
-    Global minimization is done via basin hopping. More info on this algorithm
-    can be found here: https://arxiv.org/abs/cond-mat/9803344 .
-
-    Parameters
-    ----------
-    p_i : NDArray
-        Array of corresponding losses
-    n_i : NDArray
-        Array of sample sizes
-    niter : int
-        Number of iterations to perform in the basin-hopping
-        numerical process to curve-fit p_i
-
-    Returns
-    -------
-    NDArray
-        Array of parameters to recreate line of best fit
-    """
-
-    def is_valid(f_new, x_new, f_old, x_old):
-        return f_new != np.nan
-
-    def f(x):
-        try:
-            return np.sum(np.square(p_i - f_out(n_i, x)))
-        except RuntimeWarning:
-            return np.nan
-
-    with warnings.catch_warnings():
-        warnings.filterwarnings("error", category=RuntimeWarning)
-        res = basinhopping(
-            f,
-            np.array([0.5, 0.5, 0.1]),
-            niter=niter,
-            stepsize=1.0,
-            minimizer_kwargs={"method": "Powell"},
-            accept_test=is_valid,
-            niter_success=200,
-        )
-    return res.x
 
 
 def reset_parameters(model: nn.Module) -> nn.Module:
@@ -286,94 +41,6 @@ def validate_dataset_len(dataset: Dataset[Any]) -> int:
     return length
 
 
-def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any]:
-    """Projects the measures for each value of X
-
-    Parameters
-    ----------
-    params : NDArray
-        Inverse power curve coefficients used to calculate projection
-    projection : NDArray
-        Steps to extrapolate
-
-    Returns
-    -------
-    NDArray
-        Extrapolated measure values at each projection step
-
-    """
-    return 1 - f_out(projection, params)
-
-
-def inv_project_steps(params: NDArray[Any], targets: NDArray[Any]) -> NDArray[np.uint64]:
-    """Inverse function for project_steps()
-
-    Parameters
-    ----------
-    params : NDArray
-        Inverse power curve coefficients used to calculate projection
-    targets : NDArray
-        Desired measure values
-
-    Returns
-    -------
-    NDArray
-        Array of sample sizes, or 0 if overflow
-    """
-    steps = f_inv_out(1 - np.array(targets), params)
-    steps[np.isnan(steps)] = 0
-    return np.ceil(steps)
-
-
-def get_curve_params(measures: dict[str, NDArray[Any]], ranges: NDArray[Any], niter: int) -> dict[str, NDArray[Any]]:
-    """Calculates and aggregates parameters for both single and multi-class metrics"""
-    output = {}
-    for name, measure in measures.items():
-        measure = cast(np.ndarray, measure)
-        if measure.ndim > 1:
-            result = []
-            for value in measure:
-                result.append(calc_params(1 - value, ranges, niter))
-            output[name] = np.array(result)
-        else:
-            output[name] = calc_params(1 - measure, ranges, niter)
-    return output
-
-
-def plot_measure(
-    name: str,
-    steps: NDArray[Any],
-    measure: NDArray[Any],
-    params: NDArray[Any],
-    projection: NDArray[Any],
-) -> Figure:
-    import matplotlib.pyplot
-
-    fig = matplotlib.pyplot.figure()
-    fig = cast(Figure, fig)
-    fig.tight_layout()
-
-    ax = fig.add_subplot(111)
-
-    ax.set_title(f"{name} Sufficiency")
-    ax.set_ylabel(f"{name}")
-    ax.set_xlabel("Steps")
-
-    # Plot measure over each step
-    ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
-
-    # Plot extrapolation
-    ax.plot(
-        projection,
-        project_steps(params, projection),
-        linestyle="dashed",
-        label=f"Potential Model Results ({name})",
-    )
-
-    ax.legend()
-    return fig
-
-
 T = TypeVar("T")
 
 
@@ -490,7 +157,7 @@ class Sufficiency(Generic[T]):
         self._eval_kwargs = {} if value is None else value
 
     @set_metadata(state=["runs", "substeps"])
-    def evaluate(self, eval_at: int | Iterable[int] | None = None, niter: int = 1000) -> SufficiencyOutput:
+    def evaluate(self, eval_at: int | Iterable[int] | None = None) -> SufficiencyOutput:
         """
         Creates data indices, trains models, and returns plotting data
 
@@ -499,8 +166,6 @@
        eval_at : int | Iterable[int] | None, default None
            Specify this to collect accuracies over a specific set of dataset lengths, rather
            than letting :term:`sufficiency<Sufficiency>` internally create the lengths to evaluate at.
-        niter : int, default 1000
-            Iterations to perform when using the basin-hopping method to curve-fit measure(s).
 
        Returns
        -------
@@ -524,7 +189,7 @@
        ...     substeps=5,
        ... )
        >>> suff.evaluate()
-        SufficiencyOutput(steps=array([  1,   3,  10,  31, 100], dtype=uint32), params={'test': array([ 0., 42.,  0.])}, measures={'test': array([1., 1., 1., 1., 1.])})
+        SufficiencyOutput(steps=array([  1,   3,  10,  31, 100], dtype=uint32), measures={'test': array([1., 1., 1., 1., 1.])}, n_iter=1000)
        """  # noqa: E501
        if eval_at is not None:
            ranges = np.asarray(list(eval_at) if isinstance(eval_at, Iterable) else [eval_at])
@@ -569,5 +234,4 @@
 
         # The mean for each measure must be calculated before being returned
         measures = {k: (v / self.runs).T for k, v in measures.items()}
-        params_output = get_curve_params(measures, ranges, niter)
-        return SufficiencyOutput(ranges, params_output, measures)
+        return SufficiencyOutput(ranges, measures)
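
All of the curve-fitting machinery removed above (f_out, f_inv_out, calc_params, project_steps, inv_project_steps, get_curve_params, plot_measure) moves with SufficiencyOutput into the new dataeval/outputs/_workflows.py (+364 lines in the file list), and the n_iter field in the updated doctest repr suggests the basin-hopping iteration count now rides on the output object rather than on evaluate(). For reference, the inverse power curve those helpers fit and invert, as a standalone sketch with illustrative coefficients:

```python
import numpy as np

# f(n) = a * n**(-b) + c is the fitted loss curve; a measure is projected
# as 1 - f(n) (see the removed project_steps/f_out pair).
a, b, c = 0.9, 0.5, 0.05
steps = np.array([1, 3, 10, 31, 100], dtype=np.float64)
projected_measure = 1 - (a * steps**-b + c)

# Inverting for a target measure t mirrors f_inv_out/inv_project_steps:
# n = (((1 - t) - c) / a) ** (-1 / b), rounded up to whole samples.
target = 0.9
n_needed = np.ceil((((1 - target) - c) / a) ** (-1 / b))
```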