dataeval 0.88.0__py3-none-any.whl → 0.89.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/_version.py +2 -2
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_metadata.py +2 -1
- dataeval/detectors/drift/_base.py +152 -27
- dataeval/detectors/drift/_cvm.py +44 -25
- dataeval/detectors/drift/_ks.py +56 -28
- dataeval/detectors/drift/_mmd.py +44 -18
- dataeval/detectors/drift/_uncertainty.py +119 -45
- dataeval/outputs/_drift.py +67 -29
- dataeval/outputs/_workflows.py +81 -17
- dataeval/typing.py +23 -4
- dataeval/workflows/sufficiency.py +1 -2
- {dataeval-0.88.0.dist-info → dataeval-0.89.0.dist-info}/METADATA +1 -1
- {dataeval-0.88.0.dist-info → dataeval-0.89.0.dist-info}/RECORD +16 -16
- {dataeval-0.88.0.dist-info → dataeval-0.89.0.dist-info}/WHEEL +0 -0
- {dataeval-0.88.0.dist-info → dataeval-0.89.0.dist-info}/licenses/LICENSE +0 -0
dataeval/detectors/drift/_uncertainty.py
CHANGED
@@ -31,24 +31,42 @@ def classifier_uncertainty(
     preds: Array,
     preds_type: Literal["probs", "logits"] = "probs",
 ) -> torch.Tensor:
-    """
-
+    """Convert model predictions to uncertainty scores using entropy.
+
+    Computes prediction uncertainty as the entropy of the predicted class
+    probability distribution. Higher entropy indicates greater model uncertainty,
+    with maximum uncertainty at uniform distributions and minimum at confident
+    single-class predictions.
 
     Parameters
     ----------
-
-
-
-
-
-
-
-
+    preds : Array
+        Model predictions for a batch of instances. For "probs" type, should
+        contain class probabilities that sum to 1 across the last dimension.
+        For "logits" type, contains raw model outputs before softmax.
+    preds_type : "probs" or "logits", default "probs"
+        Type of prediction values. "probs" expects probabilities in [0,1] that
+        sum to 1. "logits" expects raw outputs in [-inf,inf] and applies softmax.
+        Default "probs" assumes model outputs normalized probabilities.
 
     Returns
     -------
-
-
+    torch.Tensor
+        Uncertainty scores for each instance with shape (n_samples, 1).
+        Values are always >= 0, with higher values indicating greater uncertainty.
+
+    Raises
+    ------
+    ValueError
+        If preds_type is "probs" but probabilities don't sum to 1 within tolerance.
+    NotImplementedError
+        If preds_type is not "probs" or "logits".
+
+    Notes
+    -----
+    Uncertainty is computed as Shannon entropy: -sum(p * log(p)) where p are
+    the predicted class probabilities. This provides a principled measure of
+    model confidence that is widely used in uncertainty quantification.
     """
     preds_np = as_numpy(preds)
     if preds_type == "probs":
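
The reworked docstring spells out the entropy computation. As an illustration only (a minimal sketch in plain NumPy, not DataEval's internal code path, which also validates that probabilities sum to 1), the described behavior is:

    import numpy as np

    def entropy_uncertainty(preds: np.ndarray, preds_type: str = "probs") -> np.ndarray:
        """Per-sample Shannon entropy -sum(p * log(p)) of the class distribution."""
        if preds_type == "logits":
            # numerically stable softmax over the class dimension
            z = preds - preds.max(axis=-1, keepdims=True)
            preds = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
        elif preds_type != "probs":
            raise NotImplementedError(f"Unknown preds_type: {preds_type}")
        p = np.clip(preds, 1e-12, 1.0)  # avoid log(0) on confident predictions
        return (-(p * np.log(p)).sum(axis=-1)).reshape(-1, 1)

A uniform distribution over k classes yields the maximum value log(k), while a one-hot prediction yields 0, matching the "Returns" description above.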
@@ -65,53 +83,98 @@ def classifier_uncertainty(
 
 
 class DriftUncertainty(BaseDrift):
-    """
-    Test for a change in the number of instances falling into regions on which \
-    the model is uncertain.
+    """Drift detector using model prediction uncertainty.
 
-
+    Detects drift by monitoring changes in the distribution of model prediction
+    uncertainties (entropy) rather than input features directly. Uses
+    :term:`Kolmogorov-Smirnov (K-S) Test` to compare uncertainty distributions
+    between reference and test data.
+
+    This approach is particularly effective for detecting drift that affects model
+    confidence even when input features remain statistically similar, such as
+    out-of-domain samples or adversarial examples.
 
     Parameters
     ----------
-    data : Array
-
-
-        :term:`Classification` model outputting class probabilities (or logits)
+    data : Embeddings or Array
+        Reference dataset used as baseline distribution for drift detection.
+        Should represent the expected "normal" data distribution.
     p_val : float, default 0.05
-
+        Significance threshold for statistical tests, between 0 and 1.
+        For FDR correction, this represents the acceptable false discovery rate.
+        Default 0.05 provides 95% confidence level for drift detection.
     update_strategy : UpdateStrategy or None, default None
-
-
-        or via reservoir sampling with ReservoirSamplingUpdateStrategy.
+        Strategy for updating reference data when new data arrives.
+        When None, reference data remains fixed throughout detection.
     correction : "bonferroni" or "fdr", default "bonferroni"
-
-
+        Multiple testing correction method for multivariate drift detection.
+        "bonferroni" provides conservative family-wise error control by
+        dividing significance threshold by number of features.
+        "fdr" uses Benjamini-Hochberg procedure for less conservative control.
+        Default "bonferroni" minimizes false positive drift detections.
     preds_type : "probs" or "logits", default "probs"
-
-
+        Format of model prediction outputs. "probs" expects normalized
+        probabilities summing to 1. "logits" expects raw model outputs
+        and applies softmax normalization internally.
+        Default "probs" assumes standard classification model outputs.
     batch_size : int, default 32
-        Batch size
-
+        Batch size for model inference during uncertainty computation.
+        Larger batches improve GPU utilization but require more memory.
+        Default 32 balances efficiency and memory usage.
     transforms : Transform, Sequence[Transform] or None, default None
-
+        Data transformations applied before model inference. Should match
+        preprocessing used during model training for consistent predictions.
+        When None, uses raw input data without preprocessing.
     device : DeviceLike or None, default None
-
-
+        Hardware device for computation. When None, automatically selects
+        DataEval's configured device, falling back to PyTorch's default.
+
+    Attributes
+    ----------
+    model : torch.nn.Module
+        Classification model used for uncertainty computation.
+    device : torch.device
+        Hardware device used for model inference.
+    batch_size : int
+        Batch size for model predictions.
+    preds_type : {"probs", "logits"}
+        Format of model prediction outputs.
 
     Example
     -------
     >>> model = ClassificationModel()
-    >>>
+    >>> drift_detector = DriftUncertainty(x_ref, model=model, batch_size=16)
 
     Verify reference images have not drifted
 
-    >>>
-
+    >>> result = drift_detector.predict(x_test)
+    >>> print(f"Drift detected: {result.drifted}")
+    Drift detected: True
 
-
+    >>> print(f"Mean uncertainty change: {result.distance:.4f}")
+    Mean uncertainty change: 0.8160
 
-
-
+    With data preprocessing
+
+    >>> import torchvision.transforms.v2 as T
+    >>> transforms = T.Compose([T.ToDtype(torch.float32)])
+    >>> drift_detector = DriftUncertainty(x_ref, model=model, batch_size=16, transforms=transforms)
+
+    Notes
+    -----
+    Uncertainty-based drift detection is complementary to feature-based methods.
+    It can detect semantic drift (changes in data meaning) that may not be
+    apparent in raw feature statistics, making it valuable for monitoring
+    model performance in production environments.
+
+    The method assumes that model uncertainty is a reliable indicator of
+    data quality. This works best with well-calibrated models trained on
+    representative data. Poorly calibrated models may produce misleading
+    uncertainty estimates.
+
+    For optimal performance, ensure the model and transforms match those used
+    during training, and that the reference data represents the expected
+    operational distribution where the model performs reliably.
     """
 
     def __init__(
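
The docstring above says the drift decision comes from a K-S test on the two uncertainty distributions. The mechanism can be illustrated directly with SciPy (a sketch of the idea, not the detector's actual code path; the beta-distributed entropies are synthetic stand-ins for real model outputs):

    import numpy as np
    from scipy.stats import ks_2samp

    rng = np.random.default_rng(0)
    ref_uncertainty = rng.beta(2, 8, size=500)   # mostly confident reference predictions
    test_uncertainty = rng.beta(5, 5, size=500)  # noticeably less confident test predictions

    # Two-sample Kolmogorov-Smirnov test between the entropy distributions
    stat, p_value = ks_2samp(ref_uncertainty, test_uncertainty)
    print(f"KS statistic: {stat:.3f}, p-value: {p_value:.2e}, drifted: {p_value < 0.05}")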
@@ -142,27 +205,38 @@ class DriftUncertainty(BaseDrift):
         )
 
     def _transform(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply preprocessing transforms to input data."""
         for transform in self._transforms:
             x = transform(x)
         return x
 
     def _preprocess(self, x: Array) -> torch.Tensor:
+        """Convert input data to uncertainty scores via model predictions."""
         preds = predict_batch(x, self.model, self.device, self.batch_size, self._transform)
         return classifier_uncertainty(preds, self.preds_type)
 
     def predict(self, x: Array) -> DriftOutput:
-        """
-
+        """Predict whether model uncertainty distribution has drifted.
+
+        Computes prediction uncertainties for the input data and tests
+        whether their distribution significantly differs from the reference
+        uncertainty distribution using Kolmogorov-Smirnov test.
 
         Parameters
         ----------
         x : Array
-            Batch of instances.
+            Batch of instances to test for uncertainty drift.
 
         Returns
        -------
-
-
-            statistics.
+        DriftOutput
+            Drift detection results including overall prediction, p-values,
+            test statistics, and feature-level analysis of uncertainty values.
+
+        Notes
+        -----
+        The returned DriftOutput treats uncertainty values as "features" for
+        consistency with the underlying KS test implementation, even though
+        uncertainty-based drift typically involves univariate analysis.
         """
         return self._detector.predict(self._preprocess(x).cpu().numpy())
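
Assuming the drift_detector and x_test from the class docstring example above, the fields documented for predict can be read off the result (a usage sketch):

    result = drift_detector.predict(x_test)
    print(result.drifted)   # bool: overall drift decision
    print(result.p_val)     # mean p-value of the KS test on uncertainties
    print(result.distance)  # mean KS statistic; larger means more drift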
dataeval/outputs/_drift.py
CHANGED
@@ -18,8 +18,28 @@ from dataeval.outputs._base import Output
 
 @dataclass(frozen=True)
 class DriftBaseOutput(Output):
-    """
-
+    """Base output class for drift detector classes.
+
+    Provides common fields returned by all drift detection methods, containing
+    instance-level drift predictions and summary statistics. Subclasses extend
+    this with detector-specific additional fields.
+
+    Attributes
+    ----------
+    drifted : bool
+        Whether drift was detected in the analyzed data. True indicates
+        significant drift from reference distribution.
+    threshold : float
+        Significance threshold used for drift detection, typically between 0 and 1.
+        For multivariate methods, this is the corrected threshold after
+        Bonferroni or FDR correction.
+    p_val : float
+        Instance-level p-value from statistical test, between 0 and 1.
+        For univariate methods, this is the mean p-value across all features.
+    distance : float
+        Instance-level test statistic or distance metric, always >= 0.
+        For univariate methods, this is the mean distance across all features.
+        Higher values indicate greater deviation from reference distribution.
     """
 
     drifted: bool
@@ -31,58 +51,76 @@ class DriftBaseOutput(Output):
 @dataclass(frozen=True)
 class DriftMMDOutput(DriftBaseOutput):
     """
-    Output class for :class:`.DriftMMD`
+    Output class for :class:`.DriftMMD` (Maximum Mean Discrepancy) drift detector.
+
+    Extends :class:`.DriftBaseOutput` with MMD-specific distance threshold information.
+    Used by MMD-based drift detectors that compare kernel embeddings between
+    reference and test distributions.
 
     Attributes
     ----------
     drifted : bool
-
+        Whether drift was detected based on MMD permutation test.
     threshold : float
-
+        P-value threshold used for significance of the permutation test.
     p_val : float
-        P-value obtained from the permutation test
+        P-value obtained from the MMD permutation test, between 0 and 1.
     distance : float
-
+        Squared Maximum Mean Discrepancy between reference and test set.
+        Always >= 0, with higher values indicating greater distributional difference.
     distance_threshold : float
-
+        Squared Maximum Mean Discrepancy threshold above which drift is flagged, always >= 0.
+        Determined from permutation test at specified significance level.
+
+    Notes
+    -----
+    MMD uses kernel methods to compare distributions in reproducing kernel
+    Hilbert spaces, making it effective for high-dimensional data like images.
     """
 
-    # drifted: bool
-    # threshold: float
-    # p_val: float
-    # distance: float
     distance_threshold: float
 
 
 @dataclass(frozen=True)
 class DriftOutput(DriftBaseOutput):
-    """
-
+    """Output class for univariate drift detectors.
+
+    Extends :class:`.DriftBaseOutput` with feature-level (per-pixel) drift information.
+    Used by Kolmogorov-Smirnov, Cramér-von Mises, and uncertainty-based
+    drift detectors that analyze each feature independently.
 
     Attributes
     ----------
     drifted : bool
-
+        Overall drift prediction after multivariate correction.
     threshold : float
-
+        Corrected threshold after Bonferroni or FDR correction for multiple testing.
     p_val : float
-
+        Mean p-value across all features, between 0 and 1.
+        For descriptive purposes only; individual feature p-values are used
+        for drift detection decisions. Can appear high even when drifted=True
+        if only a subset of features show drift.
     distance : float
-
-    feature_drift : NDArray
-
+        Mean test statistic across all features, always >= 0.
+    feature_drift : NDArray[bool]
+        Boolean array indicating which features (pixels) show drift.
+        Shape matches the number of features in the input data.
     feature_threshold : float
-
-
-
-
-
+        Uncorrected p-value threshold used for individual feature testing.
+        Typically the original p_val before multivariate correction.
+    p_vals : NDArray[np.float32]
+        P-values for each feature, all values between 0 and 1.
+        Shape matches the number of features in the input data.
+    distances : NDArray[np.float32]
+        Test statistics for each feature, all values >= 0.
+        Shape matches the number of features in the input data.
+
+    Notes
+    -----
+    Feature-level analysis enables identification of specific pixels or regions
+    that contribute most to detected drift, useful for interpretability.
     """
 
-    # drifted: bool
-    # threshold: float
-    # p_val: float
-    # distance: float
    feature_drift: NDArray[np.bool_]
     feature_threshold: float
     p_vals: NDArray[np.float32]
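
Because DriftOutput carries per-feature arrays, the interpretability note above can be acted on directly. A sketch, assuming output is a DriftOutput returned by any univariate detector (e.g. DriftKS):

    import numpy as np

    # output: DriftOutput, e.g. from DriftKS(x_ref).predict(x_test)
    drifted_features = np.flatnonzero(output.feature_drift)  # indices flagged as drifted
    print(f"{drifted_features.size} features drifted")
    print(f"smallest per-feature p-value: {output.p_vals.min():.3g}")
    print(f"largest per-feature distance: {output.distances.max():.3g}")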
dataeval/outputs/_workflows.py
CHANGED
@@ -62,9 +62,12 @@ def project_steps(params: NDArray[Any], projection: NDArray[Any]) -> NDArray[Any
 def plot_measure(
     name: str,
     steps: NDArray[Any],
-
+    averaged_measure: NDArray[Any],
+    measures: NDArray[Any] | None,
     params: NDArray[Any],
     projection: NDArray[Any],
+    error_bars: bool,
+    asymptote: bool,
 ) -> Figure:
     import matplotlib.pyplot
 
@@ -73,21 +76,57 @@ def plot_measure(
     fig.tight_layout()
 
     ax = fig.add_subplot(111)
-
     ax.set_title(f"{name} Sufficiency")
     ax.set_ylabel(f"{name}")
     ax.set_xlabel("Steps")
-    # Plot
-
+    # Plot asymptote
+    if asymptote:
+        bound = 1 - params[2]
+        ax.axhline(y=bound, color="r", label=f"Asymptote: {bound:.4g}", zorder=1)
+    # Calculate error bars
+    # Plot measure over each step with associated error
+    if error_bars:
+        if measures is None:
+            warnings.warn(
+                "Error bars cannot be plotted without full, unaveraged data",
+                UserWarning,
+            )
+        else:
+            error = np.std(measures, axis=0)
+            ax.errorbar(
+                steps,
+                averaged_measure,
+                yerr=error,
+                capsize=7,
+                capthick=1.5,
+                elinewidth=1.5,
+                fmt="o",
+                label=f"Model Results ({name})",
+                markersize=5,
+                color="black",
+                ecolor="orange",
+                zorder=3,
+            )
+    else:
+        ax.scatter(
+            steps,
+            averaged_measure,
+            label=f"Model Results ({name})",
+            zorder=3,
+            c="black",
+        )
     # Plot extrapolation
     ax.plot(
         projection,
         project_steps(params, projection),
         linestyle="dashed",
         label=f"Potential Model Results ({name})",
+        linewidth=2,
+        zorder=2,
     )
+    ax.set_xscale("log")
 
-    ax.legend()
+    ax.legend(loc="best")
     return fig
 
 
@@ -116,7 +155,9 @@ def f_inv_out(y_i: NDArray[Any], x: NDArray[Any]) -> NDArray[np.int64]:
             "Number of samples could not be determined for target(s): "
             f"""{
                 np.array2string(
-                    1 - y_i[unachievable_targets],
+                    1 - y_i[unachievable_targets],
+                    separator=", ",
+                    formatter={"float": lambda x: f"{x}"},
                 )
             }""",
             UserWarning,
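
The effect of the added array2string arguments is easy to demonstrate in isolation (plain NumPy, independent of DataEval):

    import numpy as np

    a = np.array([0.25, 0.125])
    print(np.array2string(a))  # "[0.25  0.125]" -- default space separator
    print(np.array2string(a, separator=", ", formatter={"float": lambda x: f"{x}"}))
    # "[0.25, 0.125]" -- comma separated, each float rendered via str()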
@@ -190,7 +231,9 @@ def calc_params(p_i: NDArray[Any], n_i: NDArray[Any], niter: int) -> NDArray[np.
 
 
 def get_curve_params(
-    averaged_measures: MutableMapping[str, NDArray[Any]],
+    averaged_measures: MutableMapping[str, NDArray[Any]],
+    ranges: NDArray[Any],
+    niter: int,
 ) -> Mapping[str, NDArray[np.float64]]:
     """Calculates and aggregates parameters for both single and multi-class metrics"""
     output = {}
@@ -286,11 +329,16 @@ class SufficiencyOutput(Output):
                 output[name] = np.array(result)
             else:
                 output[name] = project_steps(self.params[name], projection)
-        proj = SufficiencyOutput(projection,
+        proj = SufficiencyOutput(projection, {}, output, self.n_iter)
         proj._params = self._params
         return proj
 
-    def plot(
+    def plot(
+        self,
+        class_names: Sequence[str] | None = None,
+        error_bars: bool = False,
+        asymptote: bool = False,
+    ) -> Sequence[Figure]:
         """
         Plotting function for data :term:`sufficience<Sufficiency>` tasks.
 
@@ -298,6 +346,10 @@ class SufficiencyOutput(Output):
         ----------
         class_names : Sequence[str] | None, default None
             List of class names
+        error_bars : bool, default False
+            True if error bars should be plotted, False if not
+        asymptote : bool, default False
+            True if asymptote should be plotted, False if not
 
         Returns
         -------
@@ -320,25 +372,36 @@ class SufficiencyOutput(Output):
 
         # Stores all plots
         plots = []
-
         # Create a plot for each measure on one figure
-        for name,
-            if
-                if class_names is not None and len(
+        for name, measures in self.averaged_measures.items():
+            if measures.ndim > 1:
+                if class_names is not None and len(measures) != len(class_names):
                     raise IndexError("Class name count does not align with measures")
-                for i,
+                for i, values in enumerate(measures):
                     class_name = str(i) if class_names is None else class_names[i]
                     fig = plot_measure(
                         f"{name}_{class_name}",
                         self.steps,
-
+                        values,
+                        self.measures[name][:, :, i] if len(self.measures) else None,
                         self.params[name][i],
                         extrapolated,
+                        error_bars,
+                        asymptote,
                     )
                     plots.append(fig)
 
             else:
-                fig = plot_measure(
+                fig = plot_measure(
+                    name,
+                    self.steps,
+                    measures,
+                    self.measures.get(name),
+                    self.params[name],
+                    extrapolated,
+                    error_bars,
+                    asymptote,
+                )
                 plots.append(fig)
 
         return plots
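
Putting the new options together, plotting with error bars and the fitted asymptote might look like this (a usage sketch; sufficiency_output stands in for a real SufficiencyOutput, e.g. from Sufficiency.evaluate()):

    # sufficiency_output: SufficiencyOutput with unaveraged measures available
    figs = sufficiency_output.plot(error_bars=True, asymptote=True)
    for i, fig in enumerate(figs):
        fig.savefig(f"sufficiency_{i}.png")

Per the plot_measure diff above, passing error_bars=True without the full, unaveraged measures triggers a UserWarning instead of drawing error bars.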
@@ -376,7 +439,8 @@ class SufficiencyOutput(Output):
             projection[name] = np.zeros((len(measure), len(tarray)))
             for i in range(len(measure)):
                 projection[name][i] = inv_project_steps(
-                    self.params[name][i],
+                    self.params[name][i],
+                    tarray[i] if tarray.ndim == measure.ndim else tarray,
                 )
         else:
             projection[name] = inv_project_steps(self.params[name], tarray)
dataeval/typing.py
CHANGED
@@ -21,7 +21,7 @@ __all__ = [
 ]
 
 
-from collections.abc import Iterator
+from collections.abc import Iterator
 from typing import (
     Any,
     Generic,
@@ -94,6 +94,7 @@ class Array(Protocol):
 
 _T = TypeVar("_T")
 _T_co = TypeVar("_T_co", covariant=True)
+_T_cn = TypeVar("_T_cn", contravariant=True)
 
 
 class DatasetMetadata(TypedDict, total=False):
@@ -128,6 +129,19 @@ class ModelMetadata(TypedDict, total=False):
     index2label: NotRequired[ReadOnly[dict[int, str]]]
 
 
+class DatumMetadata(TypedDict, total=False):
+    """
+    Datum level metadata required for all `AnnotatedDataset` classes.
+
+    Attributes
+    ----------
+    id : Required[str]
+        A unique identifier for the datum
+    """
+
+    id: Required[ReadOnly[str]]
+
+
 @runtime_checkable
 class Dataset(Generic[_T_co], Protocol):
     """
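
Since DatumMetadata is a TypedDict with total=False plus a Required id, any mapping with at least an "id" string satisfies it. A sketch, assuming the dataeval.typing import path added in this release:

    from dataeval.typing import DatumMetadata

    meta: DatumMetadata = {"id": "img_0001"}  # minimal: only "id" is required
    # additional metadata keys may accompany the required identifier at runtime
    richer = {"id": "img_0002", "source": "camera_3"}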
@@ -173,7 +187,7 @@ class AnnotatedDataset(Dataset[_T_co], Generic[_T_co], Protocol):
 # ========== IMAGE CLASSIFICATION DATASETS ==========
 
 
-ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike,
+ImageClassificationDatum: TypeAlias = tuple[ArrayLike, ArrayLike, DatumMetadata]
 """
 Type alias for an image classification datum tuple.
 
@@ -213,7 +227,7 @@ class ObjectDetectionTarget(Protocol):
     def scores(self) -> ArrayLike: ...
 
 
-ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget,
+ObjectDetectionDatum: TypeAlias = tuple[ArrayLike, ObjectDetectionTarget, DatumMetadata]
 """
 Type alias for an object detection datum tuple.
 
@@ -254,7 +268,7 @@ class SegmentationTarget(Protocol):
     def scores(self) -> ArrayLike: ...
 
 
-SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget,
+SegmentationDatum: TypeAlias = tuple[ArrayLike, SegmentationTarget, DatumMetadata]
 """
 Type alias for an image classification datum tuple.
 
@@ -311,3 +325,8 @@ class Transform(Generic[_T], Protocol):
     """
 
     def __call__(self, data: _T, /) -> _T: ...
+
+
+@runtime_checkable
+class Action(Generic[_T_cn, _T_co], Protocol):
+    def __call__(self, evaluator: _T_cn) -> _T_co: ...
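
The new Action protocol is a runtime-checkable callable from an evaluator type to a result type, so any object with a matching __call__ conforms structurally. A hypothetical sketch:

    from dataeval.typing import Action

    class MeanAction:
        """Toy action: consumes a list of floats, returns their mean."""

        def __call__(self, evaluator: list[float]) -> float:
            return sum(evaluator) / len(evaluator)

    action = MeanAction()
    assert isinstance(action, Action)  # structural check via @runtime_checkable
    print(action([1.0, 2.0, 3.0]))     # 2.0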
{dataeval-0.88.0.dist-info → dataeval-0.89.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dataeval
-Version: 0.88.0
+Version: 0.89.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Project-URL: Homepage, https://dataeval.ai/
 Project-URL: Repository, https://github.com/aria-ml/dataeval/