dataeval 0.88.0__py3-none-any.whl → 0.89.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
@@ -31,24 +31,42 @@ def classifier_uncertainty(
31
31
  preds: Array,
32
32
  preds_type: Literal["probs", "logits"] = "probs",
33
33
  ) -> torch.Tensor:
34
- """
35
- Evaluate model_fn on x and transform predictions to prediction uncertainties.
34
+ """Convert model predictions to uncertainty scores using entropy.
35
+
36
+ Computes prediction uncertainty as the entropy of the predicted class
37
+ probability distribution. Higher entropy indicates greater model uncertainty,
38
+ with maximum uncertainty at uniform distributions and minimum at confident
39
+ single-class predictions.
36
40
 
37
41
  Parameters
38
42
  ----------
39
- x : Array
40
- Batch of instances.
41
- model_fn : Callable
42
- Function that evaluates a :term:`classification<Classification>` model on x in a single call (contains
43
- batching logic if necessary).
44
- preds_type : "probs" | "logits", default "probs"
45
- Type of prediction output by the model. Options are 'probs' (in [0,1]) or
46
- 'logits' (in [-inf,inf]).
43
+ preds : Array
44
+ Model predictions for a batch of instances. For "probs" type, should
45
+ contain class probabilities that sum to 1 across the last dimension.
46
+ For "logits" type, contains raw model outputs before softmax.
47
+ preds_type : "probs" or "logits", default "probs"
48
+ Type of prediction values. "probs" expects probabilities in [0,1] that
49
+ sum to 1. "logits" expects raw outputs in [-inf,inf] and applies softmax.
50
+ Default "probs" assumes model outputs normalized probabilities.
47
51
 
48
52
  Returns
49
53
  -------
50
- NDArray
51
- A scalar indication of uncertainty of the model on each instance in x.
54
+ torch.Tensor
55
+ Uncertainty scores for each instance with shape (n_samples, 1).
56
+ Values are always >= 0, with higher values indicating greater uncertainty.
57
+
58
+ Raises
59
+ ------
60
+ ValueError
61
+ If preds_type is "probs" but probabilities don't sum to 1 within tolerance.
62
+ NotImplementedError
63
+ If preds_type is not "probs" or "logits".
64
+
65
+ Notes
66
+ -----
67
+ Uncertainty is computed as Shannon entropy: -sum(p * log(p)) where p are
68
+ the predicted class probabilities. This provides a principled measure of
69
+ model confidence that is widely used in uncertainty quantification.
52
70
  """
53
71
  preds_np = as_numpy(preds)
54
72
  if preds_type == "probs":
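The Notes above define uncertainty as the Shannon entropy of the predicted class distribution. A minimal numpy sketch of that computation, mirroring the documented behavior (the helper name and numerical tolerances are illustrative, not the package's implementation):

import numpy as np

def entropy_uncertainty(preds: np.ndarray, preds_type: str = "probs") -> np.ndarray:
    """Illustrative entropy-based uncertainty for a batch of predictions."""
    if preds_type == "logits":
        # numerically stable softmax over the class dimension
        z = preds - preds.max(axis=-1, keepdims=True)
        probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    elif preds_type == "probs":
        if not np.allclose(preds.sum(axis=-1), 1.0, atol=1e-6):
            raise ValueError("probabilities must sum to 1 across the last dimension")
        probs = preds
    else:
        raise NotImplementedError("preds_type must be 'probs' or 'logits'")
    # Shannon entropy: -sum(p * log(p)); higher values mean greater uncertainty
    entropy = -np.sum(probs * np.log(probs + 1e-12), axis=-1)
    return entropy.reshape(-1, 1)  # shape (n_samples, 1), values >= 0

entropy_uncertainty(np.array([[0.5, 0.5], [0.99, 0.01]]))  # uniform row scores higher than confident row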
@@ -65,53 +83,98 @@ def classifier_uncertainty(
65
83
 
66
84
 
67
85
  class DriftUncertainty(BaseDrift):
68
- """
69
- Test for a change in the number of instances falling into regions on which \
70
- the model is uncertain.
86
+ """Drift detector using model prediction uncertainty.
71
87
 
72
- Performs a K-S test on prediction entropies.
88
+ Detects drift by monitoring changes in the distribution of model prediction
89
+ uncertainties (entropy) rather than input features directly. Uses
90
+ :term:`Kolmogorov-Smirnov (K-S) Test` to compare uncertainty distributions
91
+ between reference and test data.
92
+
93
+ This approach is particularly effective for detecting drift that affects model
94
+ confidence even when input features remain statistically similar, such as
95
+ out-of-domain samples or adversarial examples.
73
96
 
74
97
  Parameters
75
98
  ----------
76
- data : Array
77
- Data used as reference distribution.
78
- model : Callable
79
- :term:`Classification` model outputting class probabilities (or logits)
99
+ data : Embeddings or Array
100
+ Reference dataset used as baseline distribution for drift detection.
101
+ Should represent the expected "normal" data distribution.
80
102
  p_val : float, default 0.05
81
- :term:`P-Value` used for the significance of the test.
103
+ Significance threshold for statistical tests, between 0 and 1.
104
+ For FDR correction, this represents the acceptable false discovery rate.
105
+ Default 0.05 provides 95% confidence level for drift detection.
82
106
  update_strategy : UpdateStrategy or None, default None
83
- Reference data can optionally be updated using an UpdateStrategy class. Update
84
- using the last n instances seen by the detector with LastSeenUpdateStrategy
85
- or via reservoir sampling with ReservoirSamplingUpdateStrategy.
107
+ Strategy for updating reference data when new data arrives.
108
+ When None, reference data remains fixed throughout detection.
86
109
  correction : "bonferroni" or "fdr", default "bonferroni"
87
- Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
88
- Discovery Rate).
110
+ Multiple testing correction method for multivariate drift detection.
111
+ "bonferroni" provides conservative family-wise error control by
112
+ dividing significance threshold by number of features.
113
+ "fdr" uses Benjamini-Hochberg procedure for less conservative control.
114
+ Default "bonferroni" minimizes false positive drift detections.
89
115
  preds_type : "probs" or "logits", default "probs"
90
- Type of prediction output by the model. Options are 'probs' (in [0,1]) or
91
- 'logits' (in [-inf,inf]).
116
+ Format of model prediction outputs. "probs" expects normalized
117
+ probabilities summing to 1. "logits" expects raw model outputs
118
+ and applies softmax normalization internally.
119
+ Default "probs" assumes standard classification model outputs.
92
120
  batch_size : int, default 32
93
- Batch size used to evaluate model. Only relevant when backend has been
94
- specified for batch prediction.
121
+ Batch size for model inference during uncertainty computation.
122
+ Larger batches improve GPU utilization but require more memory.
123
+ Default 32 balances efficiency and memory usage.
95
124
  transforms : Transform, Sequence[Transform] or None, default None
96
- Transform(s) to apply to the data.
125
+ Data transformations applied before model inference. Should match
126
+ preprocessing used during model training for consistent predictions.
127
+ When None, uses raw input data without preprocessing.
97
128
  device : DeviceLike or None, default None
98
- Device type used. The default None tries to use the GPU and falls back on
99
- CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
129
+ Hardware device for computation. When None, automatically selects
130
+ DataEval's configured device, falling back to PyTorch's default.
131
+
132
+ Attributes
133
+ ----------
134
+ model : torch.nn.Module
135
+ Classification model used for uncertainty computation.
136
+ device : torch.device
137
+ Hardware device used for model inference.
138
+ batch_size : int
139
+ Batch size for model predictions.
140
+ preds_type : {"probs", "logits"}
141
+ Format of model prediction outputs.
100
142
 
101
143
  Example
102
144
  -------
103
145
  >>> model = ClassificationModel()
104
- >>> drift = DriftUncertainty(x_ref, model=model, batch_size=20)
146
+ >>> drift_detector = DriftUncertainty(x_ref, model=model, batch_size=16)
105
147
 
106
148
  Test incoming images for drift
107
149
 
108
- >>> drift.predict(x_ref.copy()).drifted
109
- False
150
+ >>> result = drift_detector.predict(x_test)
151
+ >>> print(f"Drift detected: {result.drifted}")
152
+ Drift detected: True
110
153
 
111
- Test incoming images for drift
154
+ >>> print(f"Mean uncertainty change: {result.distance:.4f}")
155
+ Mean uncertainty change: 0.8160
112
156
 
113
- >>> drift.predict(x_test).drifted
114
- True
157
+ With data preprocessing
158
+
159
+ >>> import torchvision.transforms.v2 as T
160
+ >>> transforms = T.Compose([T.ToDtype(torch.float32)])
161
+ >>> drift_detector = DriftUncertainty(x_ref, model=model, batch_size=16, transforms=transforms)
162
+
163
+ Notes
164
+ -----
165
+ Uncertainty-based drift detection is complementary to feature-based methods.
166
+ It can detect semantic drift (changes in data meaning) that may not be
167
+ apparent in raw feature statistics, making it valuable for monitoring
168
+ model performance in production environments.
169
+
170
+ The method assumes that model uncertainty is a reliable indicator of
171
+ data quality. This works best with well-calibrated models trained on
172
+ representative data. Poorly calibrated models may produce misleading
173
+ uncertainty estimates.
174
+
175
+ For optimal performance, ensure the model and transforms match those used
176
+ during training, and that the reference data represents the expected
177
+ operational distribution where the model performs reliably.
115
178
  """
116
179
 
117
180
  def __init__(
@@ -142,27 +205,38 @@ class DriftUncertainty(BaseDrift):
142
205
  )
143
206
 
144
207
  def _transform(self, x: torch.Tensor) -> torch.Tensor:
208
+ """Apply preprocessing transforms to input data."""
145
209
  for transform in self._transforms:
146
210
  x = transform(x)
147
211
  return x
148
212
 
149
213
  def _preprocess(self, x: Array) -> torch.Tensor:
214
+ """Convert input data to uncertainty scores via model predictions."""
150
215
  preds = predict_batch(x, self.model, self.device, self.batch_size, self._transform)
151
216
  return classifier_uncertainty(preds, self.preds_type)
152
217
 
153
218
  def predict(self, x: Array) -> DriftOutput:
154
- """
155
- Predict whether a batch of data has drifted from the reference data.
219
+ """Predict whether model uncertainty distribution has drifted.
220
+
221
+ Computes prediction uncertainties for the input data and tests
222
+ whether their distribution significantly differs from the reference
223
+ uncertainty distribution using Kolmogorov-Smirnov test.
156
224
 
157
225
  Parameters
158
226
  ----------
159
227
  x : Array
160
- Batch of instances.
228
+ Batch of instances to test for uncertainty drift.
161
229
 
162
230
  Returns
163
231
  -------
164
- DriftUnvariateOutput
165
- Dictionary containing the drift prediction, :term:`p-value<P-Value>`, and threshold
166
- statistics.
232
+ DriftOutput
233
+ Drift detection results including overall prediction, p-values,
234
+ test statistics, and feature-level analysis of uncertainty values.
235
+
236
+ Notes
237
+ -----
238
+ The returned DriftOutput treats uncertainty values as "features" for
239
+ consistency with the underlying KS test implementation, even though
240
+ uncertainty-based drift typically involves univariate analysis.
167
241
  """
168
242
  return self._detector.predict(self._preprocess(x).cpu().numpy())
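The DriftUncertainty docstring above frames detection as a two-sample Kolmogorov-Smirnov test on prediction entropies. A minimal sketch of that idea using scipy directly, on synthetic entropy values (illustrative only; the detector's internal K-S implementation and correction logic are not shown in this diff):

import numpy as np
from scipy.stats import ks_2samp

def uncertainty_drift(ref_entropy: np.ndarray, test_entropy: np.ndarray, p_val: float = 0.05) -> bool:
    """Flag drift when the two uncertainty distributions differ significantly."""
    statistic, p = ks_2samp(ref_entropy.ravel(), test_entropy.ravel())
    return bool(p < p_val)

rng = np.random.default_rng(0)
ref_entropy = rng.uniform(0.0, 0.3, size=500)   # confident reference predictions (low entropy)
test_entropy = rng.uniform(0.2, 0.7, size=500)  # noticeably less confident test predictions
print(uncertainty_drift(ref_entropy, test_entropy))  # True for these synthetic values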
@@ -18,8 +18,28 @@ from dataeval.outputs._base import Output
18
18
 
19
19
  @dataclass(frozen=True)
20
20
  class DriftBaseOutput(Output):
21
- """
22
- Base output class for Drift Detector classes
21
+ """Base output class for drift detector classes.
22
+
23
+ Provides common fields returned by all drift detection methods, containing
24
+ instance-level drift predictions and summary statistics. Subclasses extend
25
+ this with detector-specific additional fields.
26
+
27
+ Attributes
28
+ ----------
29
+ drifted : bool
30
+ Whether drift was detected in the analyzed data. True indicates
31
+ significant drift from reference distribution.
32
+ threshold : float
33
+ Significance threshold used for drift detection, typically between 0 and 1.
34
+ For multivariate methods, this is the corrected threshold after
35
+ Bonferroni or FDR correction.
36
+ p_val : float
37
+ Instance-level p-value from statistical test, between 0 and 1.
38
+ For univariate methods, this is the mean p-value across all features.
39
+ distance : float
40
+ Instance-level test statistic or distance metric, always >= 0.
41
+ For univariate methods, this is the mean distance across all features.
42
+ Higher values indicate greater deviation from reference distribution.
23
43
  """
24
44
 
25
45
  drifted: bool
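These base fields are shared by every drift output. A doctest-style sketch of reading them, in the same spirit as the examples above (`detector` and `x_test` are hypothetical placeholders for any DataEval drift detector and a data batch):

>>> result = detector.predict(x_test)
>>> if result.drifted:
...     print(f"p_val={result.p_val:.3f} threshold={result.threshold:.3f} distance={result.distance:.3f}")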
@@ -31,58 +51,76 @@ class DriftBaseOutput(Output):
31
51
  @dataclass(frozen=True)
32
52
  class DriftMMDOutput(DriftBaseOutput):
33
53
  """
34
- Output class for :class:`.DriftMMD` :term:`drift<Drift>` detector.
54
+ Output class for :class:`.DriftMMD` (Maximum Mean Discrepancy) drift detector.
55
+
56
+ Extends :class:`.DriftBaseOutput` with MMD-specific distance threshold information.
57
+ Used by MMD-based drift detectors that compare kernel embeddings between
58
+ reference and test distributions.
35
59
 
36
60
  Attributes
37
61
  ----------
38
62
  drifted : bool
39
- Drift prediction for the images
63
+ Whether drift was detected based on MMD permutation test.
40
64
  threshold : float
41
- :term:`P-Value` used for significance of the permutation test
65
+ P-value threshold used for significance of the permutation test.
42
66
  p_val : float
43
- P-value obtained from the permutation test
67
+ P-value obtained from the MMD permutation test, between 0 and 1.
44
68
  distance : float
45
- MMD^2 between the reference and test set
69
+ Squared Maximum Mean Discrepancy between reference and test set.
70
+ Always >= 0, with higher values indicating greater distributional difference.
46
71
  distance_threshold : float
47
- MMD^2 threshold above which drift is flagged
72
+ Squared Maximum Mean Discrepancy threshold above which drift is flagged, always >= 0.
73
+ Determined from permutation test at specified significance level.
74
+
75
+ Notes
76
+ -----
77
+ MMD uses kernel methods to compare distributions in reproducing kernel
78
+ Hilbert spaces, making it effective for high-dimensional data like images.
48
79
  """
49
80
 
50
- # drifted: bool
51
- # threshold: float
52
- # p_val: float
53
- # distance: float
54
81
  distance_threshold: float
55
82
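The Notes above summarize MMD as a kernel two-sample statistic. A tiny numpy sketch of a biased MMD^2 estimate with an RBF kernel, for intuition only (the DriftMMD detector's kernel choice, bandwidth selection, and permutation test are not shown in this diff):

import numpy as np

def rbf_kernel(a: np.ndarray, b: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    # Gaussian (RBF) kernel matrix between rows of a and b
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-d2 / (2 * sigma**2))

def mmd_squared(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    # Biased estimate of MMD^2 = E[k(x,x')] + E[k(y,y')] - 2 E[k(x,y)]
    return float(rbf_kernel(x, x, sigma).mean() + rbf_kernel(y, y, sigma).mean() - 2 * rbf_kernel(x, y, sigma).mean())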
 
56
83
 
57
84
  @dataclass(frozen=True)
58
85
  class DriftOutput(DriftBaseOutput):
59
- """
60
- Output class for :class:`.DriftCVM`, :class:`.DriftKS`, and :class:`.DriftUncertainty` drift detectors.
86
+ """Output class for univariate drift detectors.
87
+
88
+ Extends :class:`.DriftBaseOutput` with feature-level (per-pixel) drift information.
89
+ Used by Kolmogorov-Smirnov, Cramér-von Mises, and uncertainty-based
90
+ drift detectors that analyze each feature independently.
61
91
 
62
92
  Attributes
63
93
  ----------
64
94
  drifted : bool
65
- :term:`Drift` prediction for the images
95
+ Overall drift prediction after multivariate correction.
66
96
  threshold : float
67
- Threshold after multivariate correction if needed
97
+ Corrected threshold after Bonferroni or FDR correction for multiple testing.
68
98
  p_val : float
69
- Instance-level p-value
99
+ Mean p-value across all features, between 0 and 1.
100
+ For descriptive purposes only; individual feature p-values are used
101
+ for drift detection decisions. Can appear high even when drifted=True
102
+ if only a subset of features show drift.
70
103
  distance : float
71
- Instance-level distance
72
- feature_drift : NDArray
73
- Feature-level array of images detected to have drifted
104
+ Mean test statistic across all features, always >= 0.
105
+ feature_drift : NDArray[bool]
106
+ Boolean array indicating which features (pixels) show drift.
107
+ Shape matches the number of features in the input data.
74
108
  feature_threshold : float
75
- Feature-level threshold to determine drift
76
- p_vals : NDArray
77
- Feature-level p-values
78
- distances : NDArray
79
- Feature-level distances
109
+ Uncorrected p-value threshold used for individual feature testing.
110
+ Typically the original p_val before multivariate correction.
111
+ p_vals : NDArray[np.float32]
112
+ P-values for each feature, all values between 0 and 1.
113
+ Shape matches the number of features in the input data.
114
+ distances : NDArray[np.float32]
115
+ Test statistics for each feature, all values >= 0.
116
+ Shape matches the number of features in the input data.
117
+
118
+ Notes
119
+ -----
120
+ Feature-level analysis enables identification of specific pixels or regions
121
+ that contribute most to detected drift, useful for interpretability.
80
122
  """
81
123
 
82
- # drifted: bool
83
- # threshold: float
84
- # p_val: float
85
- # distance: float
86
124
  feature_drift: NDArray[np.bool_]
87
125
  feature_threshold: float
88
126
  p_vals: NDArray[np.float32]
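The feature-level arrays documented above make it possible to localize which features (pixels) drive a drift flag. A doctest-style sketch, assuming `result` is a DriftOutput from one of the univariate detectors and numpy is imported as np:

>>> drifted_idx = np.flatnonzero(result.feature_drift)     # indices of features flagged as drifted
>>> print(f"{drifted_idx.size} of {result.feature_drift.size} features drifted")
>>> worst = np.argsort(result.p_vals)[:10]                  # strongest evidence of drift first
>>> list(zip(worst, result.distances[worst]))               # feature index paired with its test statistic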
@@ -62,9 +62,12 @@ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any
62
62
  def plot_measure(
63
63
  name: str,
64
64
  steps: NDArray[Any],
65
- measure: NDArray[Any],
65
+ averaged_measure: NDArray[Any],
66
+ measures: NDArray[Any] | None,
66
67
  params: NDArray[Any],
67
68
  projection: NDArray[Any],
69
+ error_bars: bool,
70
+ asymptote: bool,
68
71
  ) -> Figure:
69
72
  import matplotlib.pyplot
70
73
 
@@ -73,21 +76,57 @@ def plot_measure(
73
76
  fig.tight_layout()
74
77
 
75
78
  ax = fig.add_subplot(111)
76
-
77
79
  ax.set_title(f"{name} Sufficiency")
78
80
  ax.set_ylabel(f"{name}")
79
81
  ax.set_xlabel("Steps")
80
- # Plot measure over each step
81
- ax.scatter(steps, measure, label=f"Model Results ({name})", s=15, c="black")
82
+ # Plot asymptote
83
+ if asymptote:
84
+ bound = 1 - params[2]
85
+ ax.axhline(y=bound, color="r", label=f"Asymptote: {bound:.4g}", zorder=1)
86
+ # Calculate error bars
87
+ # Plot measure over each step with associated error
88
+ if error_bars:
89
+ if measures is None:
90
+ warnings.warn(
91
+ "Error bars cannot be plotted without full, unaveraged data",
92
+ UserWarning,
93
+ )
94
+ else:
95
+ error = np.std(measures, axis=0)
96
+ ax.errorbar(
97
+ steps,
98
+ averaged_measure,
99
+ yerr=error,
100
+ capsize=7,
101
+ capthick=1.5,
102
+ elinewidth=1.5,
103
+ fmt="o",
104
+ label=f"Model Results ({name})",
105
+ markersize=5,
106
+ color="black",
107
+ ecolor="orange",
108
+ zorder=3,
109
+ )
110
+ else:
111
+ ax.scatter(
112
+ steps,
113
+ averaged_measure,
114
+ label=f"Model Results ({name})",
115
+ zorder=3,
116
+ c="black",
117
+ )
82
118
  # Plot extrapolation
83
119
  ax.plot(
84
120
  projection,
85
121
  project_steps(params, projection),
86
122
  linestyle="dashed",
87
123
  label=f"Potential Model Results ({name})",
124
+ linewidth=2,
125
+ zorder=2,
88
126
  )
127
+ ax.set_xscale("log")
89
128
 
90
- ax.legend()
129
+ ax.legend(loc="best")
91
130
  return fig
92
131
 
93
132
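The error_bars branch above pairs a per-step standard deviation, computed across runs with np.std, with matplotlib's errorbar. A standalone sketch of the same pattern on synthetic data (the array shapes and the power-law curve are illustrative):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
steps = np.array([10, 100, 1000, 10000])
measures = 1 - steps ** -0.3 + rng.normal(0, 0.02, size=(5, 4))  # 5 runs x 4 step sizes
averaged_measure = measures.mean(axis=0)
error = np.std(measures, axis=0)            # per-step spread across runs, as in plot_measure

fig, ax = plt.subplots()
ax.errorbar(steps, averaged_measure, yerr=error, fmt="o", capsize=7, color="black", ecolor="orange")
ax.set_xscale("log")                        # matches the new log-scaled step axis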
 
@@ -116,7 +155,9 @@ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.int64]:
116
155
  "Number of samples could not be determined for target(s): "
117
156
  f"""{
118
157
  np.array2string(
119
- 1 - y_i[unachievable_targets], separator=", ", formatter={"float": lambda x: f"{x}"}
158
+ 1 - y_i[unachievable_targets],
159
+ separator=", ",
160
+ formatter={"float": lambda x: f"{x}"},
120
161
  )
121
162
  }""",
122
163
  UserWarning,
@@ -190,7 +231,9 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.
190
231
 
191
232
 
192
233
  def get_curve_params(
193
- averaged_measures: MutableMapping[str, NDArray[Any]], ranges: NDArray[Any], niter: int
234
+ averaged_measures: MutableMapping[str, NDArray[Any]],
235
+ ranges: NDArray[Any],
236
+ niter: int,
194
237
  ) -> Mapping[str, NDArray[np.float64]]:
195
238
  """Calculates and aggregates parameters for both single and multi-class metrics"""
196
239
  output = {}
@@ -286,11 +329,16 @@ class SufficiencyOutput(Output):
286
329
  output[name] = np.array(result)
287
330
  else:
288
331
  output[name] = project_steps(self.params[name], projection)
289
- proj = SufficiencyOutput(projection, measures=self.measures, averaged_measures=output, n_iter=self.n_iter)
332
+ proj = SufficiencyOutput(projection, {}, output, self.n_iter)
290
333
  proj._params = self._params
291
334
  return proj
292
335
 
293
- def plot(self, class_names: Sequence[str] | None = None) -> Sequence[Figure]:
336
+ def plot(
337
+ self,
338
+ class_names: Sequence[str] | None = None,
339
+ error_bars: bool = False,
340
+ asymptote: bool = False,
341
+ ) -> Sequence[Figure]:
294
342
  """
295
343
  Plotting function for data :term:`sufficiency<Sufficiency>` tasks.
296
344
 
@@ -298,6 +346,10 @@ class SufficiencyOutput(Output):
298
346
  ----------
299
347
  class_names : Sequence[str] | None, default None
300
348
  List of class names
349
+ error_bars : bool, default False
350
+ Whether to plot error bars showing the standard deviation of each measure across runs
351
+ asymptote : bool, default False
352
+ Whether to plot the horizontal asymptote of the fitted sufficiency curve
301
353
 
302
354
  Returns
303
355
  -------
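A hedged usage sketch of the new plotting flags, in the docstring's doctest style (`output` is assumed to be a SufficiencyOutput produced by a Sufficiency evaluation):

>>> figs = output.plot(error_bars=True, asymptote=True)

Per the plot_measure changes above, the error bars are the standard deviation of each measure across runs; when only averaged measures are available (for example on a projected output, which is constructed with empty measures), a UserWarning is issued and the points are drawn without bars.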
@@ -320,25 +372,36 @@ class SufficiencyOutput(Output):
320
372
 
321
373
  # Stores all plots
322
374
  plots = []
323
-
324
375
  # Create a plot for each measure on one figure
325
- for name, averaged_measures in self.averaged_measures.items():
326
- if averaged_measures.ndim > 1:
327
- if class_names is not None and len(averaged_measures) != len(class_names):
376
+ for name, measures in self.averaged_measures.items():
377
+ if measures.ndim > 1:
378
+ if class_names is not None and len(measures) != len(class_names):
328
379
  raise IndexError("Class name count does not align with measures")
329
- for i, measure in enumerate(averaged_measures):
380
+ for i, values in enumerate(measures):
330
381
  class_name = str(i) if class_names is None else class_names[i]
331
382
  fig = plot_measure(
332
383
  f"{name}_{class_name}",
333
384
  self.steps,
334
- measure,
385
+ values,
386
+ self.measures[name][:, :, i] if len(self.measures) else None,
335
387
  self.params[name][i],
336
388
  extrapolated,
389
+ error_bars,
390
+ asymptote,
337
391
  )
338
392
  plots.append(fig)
339
393
 
340
394
  else:
341
- fig = plot_measure(name, self.steps, averaged_measures, self.params[name], extrapolated)
395
+ fig = plot_measure(
396
+ name,
397
+ self.steps,
398
+ measures,
399
+ self.measures.get(name),
400
+ self.params[name],
401
+ extrapolated,
402
+ error_bars,
403
+ asymptote,
404
+ )
342
405
  plots.append(fig)
343
406
 
344
407
  return plots
@@ -376,7 +439,8 @@ class SufficiencyOutput(Output):
376
439
  projection[name] = np.zeros((len(measure), len(tarray)))
377
440
  for i in range(len(measure)):
378
441
  projection[name][i] = inv_project_steps(
379
- self.params[name][i], tarray[i] if tarray.ndim == measure.ndim else tarray
442
+ self.params[name][i],
443
+ tarray[i] if tarray.ndim == measure.ndim else tarray,
380
444
  )
381
445
  else:
382
446
  projection[name] = inv_project_steps(self.params[name], tarray)
dataeval/typing.py CHANGED
@@ -21,7 +21,7 @@ __all__ = [
21
21
  ]
22
22
 
23
23
 
24
- from collections.abc import Iterator, Mapping
24
+ from collections.abc import Iterator
25
25
  from typing import (
26
26
  Any,
27
27
  Generic,
@@ -94,6 +94,7 @@ class Array(Protocol):
94
94
 
95
95
  _T = TypeVar("_T")
96
96
  _T_co = TypeVar("_T_co", covariant=True)
97
+ _T_cn = TypeVar("_T_cn", contravariant=True)
97
98
 
98
99
 
99
100
  class DatasetMetadata(TypedDict, total=False):
@@ -128,6 +129,19 @@ class ModelMetadata(TypedDict, total=False):
128
129
  index2label: NotRequired[ReadOnly[dict[int, str]]]
129
130
 
130
131
 
132
+ class DatumMetadata(TypedDict, total=False):
133
+ """
134
+ Datum level metadata required for all `AnnotatedDataset` classes.
135
+
136
+ Attributes
137
+ ----------
138
+ id : Required[str]
139
+ A unique identifier for the datum
140
+ """
141
+
142
+ id: Required[ReadOnly[str]]
143
+
144
+
131
145
  @runtime_checkable
132
146
  class Dataset(Generic[_T_co], Protocol):
133
147
  """
@@ -173,7 +187,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
173
187
  # ========== IMAGE CLASSIFICATION DATASETS ==========
174
188
 
175
189
 
176
- ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, Mapping[str, Any]]
190
+ ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, DatumMetadata]
177
191
  """
178
192
  Type alias for an image classification datum tuple.
179
193
 
@@ -213,7 +227,7 @@ class ObjectDetectionTarget(Protocol):
213
227
  def scores(self) -> ArrayLike: ...
214
228
 
215
229
 
216
- ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, Mapping[str, Any]]
230
+ ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, DatumMetadata]
217
231
  """
218
232
  Type alias for an object detection datum tuple.
219
233
 
@@ -254,7 +268,7 @@ class SegmentationTarget(Protocol):
254
268
  def scores(self) -> ArrayLike: ...
255
269
 
256
270
 
257
- SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, Mapping[str, Any]]
271
+ SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, DatumMetadata]
258
272
  """
259
273
  Type alias for a segmentation datum tuple.
260
274
 
@@ -311,3 +325,8 @@ class Transform(Generic[_T], Protocol):
311
325
  """
312
326
 
313
327
  def __call__(self, data: _T, /) -> _T: ...
328
+
329
+
330
+ @runtime_checkable
331
+ class Action(Generic[_T_cn, _T_co], Protocol):
332
+ def __call__(self, evaluator: _T_cn) -> _T_co: ...
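The new Action protocol above is a runtime-checkable callable that takes an evaluator and returns a result. A hypothetical conforming implementation (the Evaluator and Report classes below are illustrative assumptions, not DataEval types, and the protocol is re-declared locally without its generics to keep the sketch short):

from typing import Protocol, runtime_checkable

@runtime_checkable
class Action(Protocol):
    # Simplified, non-generic re-declaration of the protocol above, for illustration only.
    def __call__(self, evaluator: "Evaluator") -> "Report": ...

class Evaluator:
    # Hypothetical evaluator
    name = "demo-evaluator"

class Report:
    # Hypothetical result type
    def __init__(self, text: str) -> None:
        self.text = text

def summarize(evaluator: Evaluator) -> Report:
    # A plain function satisfies the protocol structurally
    return Report(f"ran {evaluator.name}")

print(isinstance(summarize, Action))  # True: any callable matches a __call__-only protocol at runtime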
@@ -280,5 +280,4 @@ class Sufficiency(Generic[T]):
280
280
  )
281
281
 
282
282
  measures[name][run, iteration] = value
283
- # The mean for each measure must be calculated before being returned
284
- return SufficiencyOutput(ranges, measures=measures)
283
+ return SufficiencyOutput(ranges, measures)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dataeval
3
- Version: 0.88.0
3
+ Version: 0.89.0
4
4
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
5
5
  Project-URL: Homepage, https://dataeval.ai/
6
6
  Project-URL: Repository, https://github.com/aria-ml/dataeval/