dataeval 0.81.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +8 -64
- dataeval/detectors/drift/_mmd.py +12 -38
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +6 -5
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -2
- dataeval/detectors/linters/duplicates.py +14 -46
- dataeval/detectors/linters/outliers.py +25 -159
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +6 -5
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +3 -4
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/metadata/__init__.py +2 -1
- dataeval/metadata/_distance.py +134 -0
- dataeval/metadata/_ood.py +30 -49
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/bias/__init__.py +5 -4
- dataeval/metrics/bias/_balance.py +17 -149
- dataeval/metrics/bias/_coverage.py +4 -106
- dataeval/metrics/bias/_diversity.py +12 -107
- dataeval/metrics/bias/_parity.py +7 -71
- dataeval/metrics/estimators/__init__.py +5 -4
- dataeval/metrics/estimators/_ber.py +2 -20
- dataeval/metrics/estimators/_clusterer.py +1 -61
- dataeval/metrics/estimators/_divergence.py +2 -19
- dataeval/metrics/estimators/_uap.py +2 -16
- dataeval/metrics/stats/__init__.py +15 -12
- dataeval/metrics/stats/_base.py +41 -128
- dataeval/metrics/stats/_boxratiostats.py +13 -13
- dataeval/metrics/stats/_dimensionstats.py +17 -58
- dataeval/metrics/stats/_hashstats.py +19 -35
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +42 -121
- dataeval/metrics/stats/_pixelstats.py +19 -51
- dataeval/metrics/stats/_visualstats.py +19 -51
- dataeval/outputs/__init__.py +57 -0
- dataeval/outputs/_base.py +182 -0
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +186 -0
- dataeval/outputs/_metadata.py +54 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +393 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +187 -7
- dataeval/utils/_method.py +1 -5
- dataeval/utils/_plot.py +2 -2
- dataeval/utils/data/__init__.py +5 -1
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +12 -14
- dataeval/utils/data/_images.py +30 -27
- dataeval/utils/data/_metadata.py +28 -11
- dataeval/utils/data/_selection.py +25 -22
- dataeval/utils/data/_split.py +5 -29
- dataeval/utils/data/_targets.py +14 -2
- dataeval/utils/data/datasets/_base.py +5 -5
- dataeval/utils/data/datasets/_cifar10.py +1 -1
- dataeval/utils/data/datasets/_milco.py +1 -1
- dataeval/utils/data/datasets/_mnist.py +1 -1
- dataeval/utils/data/datasets/_ships.py +1 -1
- dataeval/utils/data/{_types.py → datasets/_types.py} +10 -16
- dataeval/utils/data/datasets/_voc.py +1 -1
- dataeval/utils/data/selections/_classfilter.py +4 -5
- dataeval/utils/data/selections/_indices.py +2 -2
- dataeval/utils/data/selections/_limit.py +2 -2
- dataeval/utils/data/selections/_reverse.py +2 -2
- dataeval/utils/data/selections/_shuffle.py +2 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +6 -342
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +2 -2
- dataeval-0.82.1.dist-info/RECORD +105 -0
- dataeval/_output.py +0 -137
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/metrics/stats/_datasetstats.py +0 -198
- dataeval-0.81.0.dist-info/RECORD +0 -94
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.81.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
@@ -0,0 +1,381 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
import contextlib
|
6
|
+
from dataclasses import asdict, dataclass
|
7
|
+
from typing import Any, Literal, TypeVar, overload
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from numpy.typing import NDArray
|
11
|
+
|
12
|
+
with contextlib.suppress(ImportError):
|
13
|
+
import pandas as pd
|
14
|
+
from matplotlib.figure import Figure
|
15
|
+
|
16
|
+
from dataeval.outputs._base import Output
|
17
|
+
from dataeval.typing import ArrayLike
|
18
|
+
from dataeval.utils._array import to_numpy
|
19
|
+
from dataeval.utils._plot import heatmap
|
20
|
+
|
21
|
+
# Type variable constrained to either a scalar np.float64 or an NDArray[np.float64].
# NOTE(review): no definition visible in this section references TData - confirm usage elsewhere.
TData = TypeVar("TData", np.float64, NDArray[np.float64])
|
22
|
+
|
23
|
+
|
24
|
+
class ToDataFrameMixin:
    """
    Mixin that adds a pandas export to outputs exposing ``score`` and ``p_value``.

    The host class may optionally expose ``factor_names``; when present it is
    used as the index of the exported DataFrame.
    """

    score: Any
    p_value: Any

    def to_dataframe(self) -> pd.DataFrame:
        """
        Exports the parity output results to a pandas DataFrame.

        Scores and p-values are rounded to 2 decimal places.  Scalar results
        are exported as a single-row DataFrame.

        Returns
        -------
        pd.DataFrame
            DataFrame with "score" and "p-value" columns, indexed by
            ``factor_names`` when the output defines them and by a default
            integer index otherwise.

        Notes
        -----
        This method requires `pandas <https://pandas.pydata.org/>`_ to be installed.
        """
        import pandas as pd

        # Not every host class defines factor_names (e.g. a scalar label parity
        # result), so fall back to the default index instead of raising.
        index = getattr(self, "factor_names", None)

        # pandas rejects all-scalar column data without an index, so promote
        # scalars to 1-d arrays; arrays that are already 1-d are unaffected.
        return pd.DataFrame(
            index=index,  # type: ignore - list[str] is documented as acceptable index type
            data={
                "score": np.atleast_1d(self.score).round(2),
                "p-value": np.atleast_1d(self.p_value).round(2),
            },
        )
|
49
|
+
|
50
|
+
|
51
|
+
@dataclass(frozen=True)
class ParityOutput(ToDataFrameMixin, Output):
    """
    Result container returned by the :func:`.parity` :term:`bias<Bias>` metric.

    Attributes
    ----------
    score : NDArray[np.float64]
        Chi-squared score(s) of the test
    p_value : NDArray[np.float64]
        P-value(s) of the test
    factor_names : list[str]
        Name of every metadata factor
    insufficient_data : dict[str, dict[int, dict[str, int]]]
        Metadata factors that have fewer than 5 class occurrences per value
    """

    # Test statistics
    score: NDArray[np.float64]
    p_value: NDArray[np.float64]

    # Factor bookkeeping
    factor_names: list[str]
    insufficient_data: dict[str, dict[int, dict[str, int]]]
|
72
|
+
|
73
|
+
|
74
|
+
@dataclass(frozen=True)
class LabelParityOutput(ToDataFrameMixin, Output):
    """
    Result container returned by the :func:`.label_parity` :term:`bias<Bias>` metric.

    Attributes
    ----------
    score : np.float64
        Chi-squared score of the test
    p_value : np.float64
        P-value of the test
    """

    # Both values are scalars, unlike the per-factor arrays of ParityOutput
    score: np.float64
    p_value: np.float64
|
89
|
+
|
90
|
+
|
91
|
+
@dataclass(frozen=True)
class CoverageOutput(Output):
    """
    Output class for :func:`.coverage` :term:`bias<Bias>` metric.

    Attributes
    ----------
    uncovered_indices : NDArray[np.intp]
        Array of uncovered indices
    critical_value_radii : NDArray[np.float64]
        Array of critical value radii
    coverage_radius : float
        Radius for :term:`coverage<Coverage>`
    """

    uncovered_indices: NDArray[np.intp]
    critical_value_radii: NDArray[np.float64]
    coverage_radius: float

    def plot(self, images: ArrayLike, top_k: int = 6) -> Figure:
        """
        Plot the top k images together for visualization.

        The first ``top_k`` entries of ``uncovered_indices`` select the images
        to show.  Images are laid out in a grid that is 3 columns wide; grid
        cells without an image are left blank.

        Parameters
        ----------
        images : ArrayLike
            Original images (not embeddings) in (N, C, H, W) or (N, H, W) format
        top_k : int, default 6
            Number of images to plot (plotting assumes groups of 3)

        Returns
        -------
        matplotlib.figure.Figure

        Raises
        ------
        ValueError
            If the selected images are not 3- or 4-dimensional.

        Notes
        -----
        This method requires `matplotlib <https://matplotlib.org/>`_ to be installed.
        """

        import matplotlib.pyplot as plt

        # Determine which images to plot
        highest_uncovered_indices = self.uncovered_indices[:top_k]

        # Grab the images
        selected_images = to_numpy(images)[highest_uncovered_indices]

        ndim = selected_images.ndim
        if ndim == 4:
            # (N, C, H, W) -> (N, H, W, C) as expected by imshow
            selected_images = np.moveaxis(selected_images, 1, -1)
        elif ndim == 3:
            # (N, H, W) -> (N, H, W, 3) by repeating the single channel
            selected_images = np.repeat(selected_images[:, :, :, np.newaxis], 3, axis=-1)
        else:
            raise ValueError(
                f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {ndim}-dimensional set of images."
            )

        # Size the grid from the images actually selected: there may be fewer
        # uncovered indices than top_k.  Always keep at least one row because
        # plt.subplots rejects a row count of 0.
        num_images = len(selected_images)
        rows = max(1, int(np.ceil(num_images / 3)))

        # squeeze=False keeps axs 2-d even for a single row, so one loop
        # handles every grid size.
        fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows), squeeze=False)

        for position, ax in enumerate(axs.flat):
            if position < num_images:
                ax.imshow(selected_images[position])
            # Hide the frame/ticks on every cell, including unused ones
            ax.axis("off")

        fig.tight_layout()
        return fig
|
171
|
+
|
172
|
+
|
173
|
+
@dataclass(frozen=True)
class BalanceOutput(Output):
    """
    Result container returned by the :func:`.balance` :term:`bias<Bias>` metric.

    Attributes
    ----------
    balance : NDArray[np.float64]
        Estimate of mutual information between metadata factors and class label
    factors : NDArray[np.float64]
        Estimate of inter/intra-factor mutual information
    classwise : NDArray[np.float64]
        Estimate of mutual information between metadata factors and individual class labels
    factor_names : list[str]
        Names of each metadata factor
    class_names : list[str]
        List of the class labels present in the dataset
    """

    balance: NDArray[np.float64]
    factors: NDArray[np.float64]
    classwise: NDArray[np.float64]
    factor_names: list[str]
    class_names: list[str]

    @overload
    def _by_factor_type(
        self,
        attr: Literal["factor_names"],
        factor_type: Literal["discrete", "continuous", "both"],
    ) -> list[str]: ...

    @overload
    def _by_factor_type(
        self,
        attr: Literal["balance", "factors", "classwise"],
        factor_type: Literal["discrete", "continuous", "both"],
    ) -> NDArray[np.float64]: ...

    def _by_factor_type(
        self,
        attr: Literal["balance", "factors", "classwise", "factor_names"],
        factor_type: Literal["discrete", "continuous", "both"],
    ) -> NDArray[np.float64] | list[str]:
        """
        Return the requested attribute restricted to one factor type.

        Factor names containing the suffix of the *other* factor type
        ("-continuous" when selecting discrete, "-discrete" when selecting
        continuous) are filtered out.  For ``"both"`` the attribute is returned
        untouched.
        """
        # No filtering requested: hand back the raw attribute
        if factor_type == "both":
            return getattr(self, attr)

        # Names carrying this marker belong to the other factor type
        excluded_marker = "-continuous" if factor_type == "discrete" else "-discrete"
        keep = [excluded_marker not in name for name in self.factor_names]

        if attr == "factor_names":
            # Strip the suffix of the selected type from the surviving names
            suffix = f"-{factor_type}"
            return [name.replace(suffix, "") for name, kept in zip(self.factor_names, keep) if kept]

        selection = np.asarray(keep)
        if attr == "factors":
            # The first mask entry is skipped for the factor-by-factor matrix
            inner = selection[1:]
            return self.factors[inner][:, inner]
        if attr == "balance":
            return self.balance[selection]
        if attr == "classwise":
            return self.classwise[:, selection]

    def plot(
        self,
        row_labels: list[Any] | NDArray[Any] | None = None,
        col_labels: list[Any] | NDArray[Any] | None = None,
        plot_classwise: bool = False,
        factor_type: Literal["discrete", "continuous", "both"] = "discrete",
    ) -> Figure:
        """
        Draw the balance results as a heatmap.

        Parameters
        ----------
        row_labels : ArrayLike or None, default None
            List/Array containing the labels for rows in the heatmap
        col_labels : ArrayLike or None, default None
            List/Array containing the labels for columns in the heatmap
        plot_classwise : bool, default False
            Whether to plot per-class balance instead of global balance
        factor_type : "discrete", "continuous", or "both", default "discrete"
            Whether to plot discretized values, continuous values, or to include both

        Returns
        -------
        matplotlib.figure.Figure

        Notes
        -----
        This method requires `matplotlib <https://matplotlib.org/>`_ to be installed.
        """
        if plot_classwise:
            # Classes on the rows, (filtered) factors on the columns
            if row_labels is None:
                row_labels = self.class_names
            if col_labels is None:
                col_labels = self._by_factor_type("factor_names", factor_type)

            classwise_data = self._by_factor_type("classwise", factor_type)
            return heatmap(
                classwise_data,
                row_labels,
                col_labels,
                xlabel="Factors",
                ylabel="Class",
                cbarlabel="Normalized Mutual Information",
            )

        # Stack the class-vs-factor row on top of the factor-vs-factor matrix
        class_row = self._by_factor_type("balance", factor_type)[np.newaxis, 1:]
        factor_rows = self._by_factor_type("factors", factor_type)
        data = np.concatenate([class_row, factor_rows], axis=0)

        # np.triu zeroes everything below the diagonal; after the +1 shift those
        # zeroed cells are the ones that compare < 1 and get blanked out
        # (assumes the values themselves are non-negative - TODO confirm)
        hidden = np.triu(data + 1, k=0) < 1

        # Last row is last factor x last factor so it gets dropped
        heat_data = np.where(hidden, np.nan, data)[:-1]

        # Default axis labels come from the (filtered) factor names
        heat_labels = self._by_factor_type("factor_names", factor_type)
        if row_labels is None:
            row_labels = heat_labels[:-1]
        if col_labels is None:
            col_labels = heat_labels[1:]

        return heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
|
305
|
+
|
306
|
+
|
307
|
+
@dataclass(frozen=True)
class DiversityOutput(Output):
    """
    Result container returned by the :func:`.diversity` :term:`bias<Bias>` metric.

    Attributes
    ----------
    diversity_index : NDArray[np.double]
        :term:`Diversity` index for classes and factors
    classwise : NDArray[np.double]
        Classwise diversity index [n_class x n_factor]
    factor_names : list[str]
        Names of each metadata factor
    class_names : list[str]
        Class labels for each value in the dataset
    """

    diversity_index: NDArray[np.double]
    classwise: NDArray[np.double]
    factor_names: list[str]
    class_names: list[str]

    def plot(
        self,
        row_labels: ArrayLike | None = None,
        col_labels: ArrayLike | None = None,
        plot_classwise: bool = False,
    ) -> Figure:
        """
        Draw the diversity results as a classwise heatmap or an overall bar chart.

        Parameters
        ----------
        row_labels : ArrayLike or None, default None
            List/Array containing the labels for rows in the heatmap
            (only used when ``plot_classwise`` is True)
        col_labels : ArrayLike or None, default None
            List/Array containing the labels for columns in the heatmap
            (only used when ``plot_classwise`` is True)
        plot_classwise : bool, default False
            Whether to plot per-class diversity as a heatmap instead of the
            overall diversity index as a bar chart

        Returns
        -------
        matplotlib.figure.Figure

        Notes
        -----
        This method requires `matplotlib <https://matplotlib.org/>`_ to be installed.
        """
        if plot_classwise:
            # Classes on the rows, factors on the columns
            if row_labels is None:
                row_labels = self.class_names
            if col_labels is None:
                col_labels = self.factor_names

            # The colorbar label names the method recorded in the output metadata
            return heatmap(
                self.classwise,
                row_labels,
                col_labels,
                xlabel="Factors",
                ylabel="Class",
                cbarlabel=f"Normalized {asdict(self.meta())['arguments']['method'].title()} Index",
            )

        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(8, 8))

        # One bar per entry of diversity_index: "class" first, then each factor
        bar_labels = np.concatenate((["class"], self.factor_names))
        ax.bar(bar_labels, self.diversity_index)
        ax.set_xlabel("Factors")

        # Slant the tick labels so long factor names stay readable
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        fig.tight_layout()

        return fig
|
@@ -0,0 +1,83 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from numpy.typing import NDArray
|
9
|
+
|
10
|
+
from dataeval.outputs._base import Output
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass(frozen=True)
class DriftBaseOutput(Output):
    """
    Common fields shared by the outputs of the Drift Detector classes.

    Subclasses document the detector-specific meaning of each field.
    """

    # Overall drift decision and the threshold it was compared against
    drifted: bool
    threshold: float

    # Test results the decision is based on
    p_val: float
    distance: float
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass(frozen=True)
class DriftMMDOutput(DriftBaseOutput):
    """
    Result container for the :class:`.DriftMMD` :term:`drift<Drift>` detector.

    Attributes
    ----------
    drifted : bool
        Drift prediction for the images
    threshold : float
        :term:`P-Value` used for significance of the permutation test
    p_val : float
        P-value obtained from the permutation test
    distance : float
        MMD^2 between the reference and test set
    distance_threshold : float
        MMD^2 threshold above which drift is flagged
    """

    # drifted, threshold, p_val and distance are inherited from DriftBaseOutput
    distance_threshold: float
|
49
|
+
|
50
|
+
|
51
|
+
@dataclass(frozen=True)
class DriftOutput(DriftBaseOutput):
    """
    Result container for the :class:`.DriftCVM`, :class:`.DriftKS`, and
    :class:`.DriftUncertainty` drift detectors.

    Attributes
    ----------
    drifted : bool
        :term:`Drift` prediction for the images
    threshold : float
        Threshold after multivariate correction if needed
    p_val : float
        Instance-level p-value
    distance : float
        Instance-level distance
    feature_drift : NDArray[np.bool_]
        Feature-level array of images detected to have drifted
    feature_threshold : float
        Feature-level threshold to determine drift
    p_vals : NDArray[np.float32]
        Feature-level p-values
    distances : NDArray[np.float32]
        Feature-level distances
    """

    # drifted, threshold, p_val and distance are inherited from DriftBaseOutput

    # Feature-level drift decision and its threshold
    feature_drift: NDArray[np.bool_]
    feature_threshold: float

    # Feature-level test results
    p_vals: NDArray[np.float32]
    distances: NDArray[np.float32]
|
@@ -0,0 +1,114 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from numpy.typing import NDArray
|
9
|
+
|
10
|
+
from dataeval.outputs._base import Output
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass(frozen=True)
class BEROutput(Output):
    """
    Result container returned by the :func:`.ber` estimator metric.

    Attributes
    ----------
    ber : float
        The upper bounds of the :term:`Bayes error rate<Bayes Error Rate (BER)>`
    ber_lower : float
        The lower bounds of the Bayes Error Rate
    """

    # Upper bound first, then lower bound
    ber: float
    ber_lower: float
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass(frozen=True)
class ClustererOutput(Output):
    """
    Result container returned by :func:`.clusterer`.

    Attributes
    ----------
    clusters : NDArray[np.int_]
        Assigned clusters
    mst : NDArray[np.double]
        The minimum spanning tree of the data
    linkage_tree : NDArray[np.double]
        The linkage array of the data
    condensed_tree : NDArray[np.double]
        The condensed tree of the data
    membership_strengths : NDArray[np.double]
        The strength of the data point belonging to the assigned cluster
    """

    clusters: NDArray[np.int_]
    mst: NDArray[np.double]
    linkage_tree: NDArray[np.double]
    condensed_tree: NDArray[np.double]
    membership_strengths: NDArray[np.double]

    def find_outliers(self) -> NDArray[np.int_]:
        """
        Retrieves Outliers based on when the sample was added to the cluster
        and how far it was from the cluster when it was added

        Outliers are the samples whose assigned cluster is ``-1``.

        Returns
        -------
        NDArray[np.int_]
            A numpy array of the outlier indices
        """
        is_outlier = self.clusters == -1
        outlier_indices = np.nonzero(is_outlier)[0]
        return outlier_indices

    def find_duplicates(self) -> tuple[list[list[int]], list[list[int]]]:
        """
        Finds duplicate and near duplicate data based on cluster average distance

        Returns
        -------
        tuple[list[list[int]], list[list[int]]]
            The exact :term:`duplicates<Duplicates>` and near duplicates as lists of related indices
        """
        # Delay load numba compiled functions
        from dataeval.utils._clusterer import compare_links_to_cluster_std, sorted_union_find

        def _as_int_groups(groups: Any) -> list[list[int]]:
            # Convert every index to a builtin int
            return [list(map(int, group)) for group in groups]

        exact_links, near_links = compare_links_to_cluster_std(self.mst, self.clusters)

        # Merge linked indices into groups of related samples
        exact_groups = sorted_union_find(exact_links)
        near_groups = sorted_union_find(near_links)

        return _as_int_groups(exact_groups), _as_int_groups(near_groups)
|
84
|
+
|
85
|
+
|
86
|
+
@dataclass(frozen=True)
class DivergenceOutput(Output):
    """
    Result container returned by the :func:`.divergence` estimator metric.

    Attributes
    ----------
    divergence : float
        :term:`Divergence` value calculated between 2 datasets ranging between 0.0 and 1.0
    errors : int
        The number of differing edges between the datasets
    """

    # Divergence value followed by the edge count it was derived alongside
    divergence: float
    errors: int
|
101
|
+
|
102
|
+
|
103
|
+
@dataclass(frozen=True)
class UAPOutput(Output):
    """
    Result container returned by the :func:`.uap` estimator metric.

    Attributes
    ----------
    uap : float
        The empirical mean precision estimate
    """

    # Single scalar result
    uap: float
|