dataeval 0.84.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/data/__init__.py +19 -0
- dataeval/data/_embeddings.py +345 -0
- dataeval/{utils/data → data}/_images.py +2 -2
- dataeval/{utils/data → data}/_metadata.py +8 -7
- dataeval/{utils/data → data}/_selection.py +22 -9
- dataeval/{utils/data → data}/_split.py +1 -1
- dataeval/data/selections/__init__.py +19 -0
- dataeval/data/selections/_classbalance.py +37 -0
- dataeval/data/selections/_classfilter.py +109 -0
- dataeval/{utils/data → data}/selections/_indices.py +1 -1
- dataeval/{utils/data → data}/selections/_limit.py +1 -1
- dataeval/{utils/data → data}/selections/_prioritize.py +3 -3
- dataeval/{utils/data → data}/selections/_reverse.py +1 -1
- dataeval/{utils/data → data}/selections/_shuffle.py +3 -3
- dataeval/detectors/drift/__init__.py +2 -2
- dataeval/detectors/drift/_base.py +55 -203
- dataeval/detectors/drift/_cvm.py +19 -30
- dataeval/detectors/drift/_ks.py +18 -30
- dataeval/detectors/drift/_mmd.py +189 -53
- dataeval/detectors/drift/_uncertainty.py +52 -56
- dataeval/detectors/drift/updates.py +13 -12
- dataeval/detectors/linters/duplicates.py +6 -4
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +1 -1
- dataeval/metadata/_distance.py +1 -1
- dataeval/metadata/_ood.py +4 -4
- dataeval/metrics/bias/_balance.py +1 -1
- dataeval/metrics/bias/_diversity.py +1 -1
- dataeval/metrics/bias/_parity.py +1 -1
- dataeval/metrics/stats/_base.py +7 -7
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_labelstats.py +2 -2
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/_bias.py +1 -1
- dataeval/typing.py +53 -19
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +18 -7
- dataeval/utils/data/__init__.py +5 -20
- dataeval/utils/data/_dataset.py +6 -4
- dataeval/utils/data/collate.py +2 -0
- dataeval/utils/datasets/__init__.py +17 -0
- dataeval/utils/{data/datasets → datasets}/_base.py +10 -7
- dataeval/utils/{data/datasets → datasets}/_cifar10.py +11 -11
- dataeval/utils/{data/datasets → datasets}/_milco.py +44 -16
- dataeval/utils/{data/datasets → datasets}/_mnist.py +11 -7
- dataeval/utils/{data/datasets → datasets}/_ships.py +10 -6
- dataeval/utils/{data/datasets → datasets}/_voc.py +43 -22
- dataeval/utils/torch/_internal.py +12 -35
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/METADATA +2 -3
- dataeval-1.0.0.dist-info/RECORD +107 -0
- dataeval/detectors/drift/_torch.py +0 -222
- dataeval/utils/data/_embeddings.py +0 -186
- dataeval/utils/data/datasets/__init__.py +0 -17
- dataeval/utils/data/selections/__init__.py +0 -17
- dataeval/utils/data/selections/_classfilter.py +0 -59
- dataeval-0.84.0.dist-info/RECORD +0 -106
- /dataeval/{utils/data → data}/_targets.py +0 -0
- /dataeval/utils/{metadata.py → data/metadata.py} +0 -0
- /dataeval/utils/{data/datasets → datasets}/_fileio.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_mixin.py +0 -0
- /dataeval/utils/{data/datasets → datasets}/_types.py +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.84.0.dist-info → dataeval-1.0.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,109 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any, Generic, Iterable, Sequence, Sized, TypeVar, cast
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from numpy.typing import NDArray
|
9
|
+
|
10
|
+
from dataeval.data._selection import Select, Selection, SelectionStage, Subselection
|
11
|
+
from dataeval.typing import Array, ObjectDetectionDatum, ObjectDetectionTarget, SegmentationDatum, SegmentationTarget
|
12
|
+
from dataeval.utils._array import as_numpy
|
13
|
+
from dataeval.utils.data.metadata import flatten
|
14
|
+
|
15
|
+
|
16
|
+
class ClassFilter(Selection[Any]):
    """
    Filter the dataset by class.

    Images whose target contains none of the requested classes are dropped from
    the selection. For object detection and segmentation targets that contain a
    mix of requested and other labels, a per-datum subselection is registered so
    that the other detections can be masked out.

    Parameters
    ----------
    classes : Sequence[int]
        The classes to filter by.
    filter_detections : bool, default True
        Whether to filter detections from targets for object detection and segmentation datasets.
    """

    stage = SelectionStage.FILTER

    def __init__(self, classes: Sequence[int], filter_detections: bool = True) -> None:
        self.classes = classes
        self.filter_detections = filter_detections

    def __call__(self, dataset: Select[Any]) -> None:
        """
        Apply the class filter to ``dataset`` in place.

        Parameters
        ----------
        dataset : Select[Any]
            The selection wrapper whose ``_selection`` and ``_subselections`` are updated.

        Raises
        ------
        TypeError
            If a target is neither an Array nor an object detection / segmentation target.
        """
        # An empty class list means "no filtering".
        if not self.classes:
            return

        # Build the lookup set once instead of scanning the sequence per datum.
        wanted = set(self.classes)

        selection = []
        subselection = set()
        for idx in dataset._selection:
            target = dataset._dataset[idx][1]
            if isinstance(target, Array):
                # Classification target: the label is the index of the largest entry
                label = int(np.argmax(as_numpy(target)))
                if label in wanted:
                    selection.append(idx)
            elif isinstance(target, (ObjectDetectionTarget, SegmentationTarget)):
                raw_labels = target.labels if isinstance(target.labels, Iterable) else [target.labels]
                # Normalise to plain Python scalars before building the set. Elements of
                # tensor-like containers may hash by identity, which would make the set
                # comparisons against integer classes always miss. This mirrors the
                # as_numpy() normalisation done in ClassFilterSubSelection.
                labels = set(np.atleast_1d(as_numpy(raw_labels)).ravel().tolist())
                # Include the image if any of its labels is requested
                if labels & wanted:
                    selection.append(idx)
                    # Only register a subselection when there is something to mask out
                    if self.filter_detections and labels - wanted:
                        subselection.add(idx)
            else:
                raise TypeError(f"ClassFilter does not support targets of type {type(target)}.")

        dataset._selection = selection
        dataset._subselections.append((ClassFilterSubSelection(self.classes), subselection))
|
64
|
+
|
65
|
+
|
66
|
+
# Unconstrained placeholder used by `_try_mask_object` so the return type matches the input type.
_T = TypeVar("_T")
# Datum type used in the signature of `ClassFilterSubSelection.__call__`.
_TDatum = TypeVar("_TDatum", ObjectDetectionDatum, SegmentationDatum)
# Target type wrapped by `ClassFilterTarget`.
_TTarget = TypeVar("_TTarget", ObjectDetectionTarget, SegmentationTarget)
|
69
|
+
|
70
|
+
|
71
|
+
def _try_mask_object(obj: _T, mask: NDArray[np.bool_]) -> _T:
|
72
|
+
if isinstance(obj, Sized) and not isinstance(obj, (str, bytes, bytearray)) and len(obj) == len(mask):
|
73
|
+
if isinstance(obj, Array):
|
74
|
+
return obj[mask]
|
75
|
+
elif isinstance(obj, Sequence):
|
76
|
+
return cast(_T, [item for i, item in enumerate(obj) if mask[i]])
|
77
|
+
return obj
|
78
|
+
|
79
|
+
|
80
|
+
class ClassFilterTarget(Generic[_TTarget]):
    """
    Proxy around a target that masks its attributes on access.

    Attribute lookups are forwarded to the wrapped target and the result is
    passed through ``_try_mask_object`` with the stored mask. Lookups of the
    proxy's own bookkeeping fields and of dunder names are not forwarded.

    Parameters
    ----------
    target : _TTarget
        The target to wrap.
    mask : NDArray[np.bool_]
        Boolean mask applied to attributes read from ``target``.
    """

    def __init__(self, target: _TTarget, mask: NDArray[np.bool_]) -> None:
        # Copy the wrapped target's instance attributes onto the proxy.
        self.__dict__.update(target.__dict__)
        labels = target.labels
        # Number of labels on the unmasked target; non-sized labels count as 0 or 1.
        self._length = len(labels) if isinstance(labels, Sized) else int(bool(labels))
        self._mask = mask
        self._target = target

    def __getattribute__(self, name: str) -> Any:
        is_dunder = name.startswith("__") and name.endswith("__")
        if is_dunder or name in ("_length", "_mask", "_target"):
            # Own fields and special names resolve normally to avoid infinite recursion.
            return super().__getattribute__(name)

        # Everything else is read from the wrapped target and masked if applicable.
        return _try_mask_object(getattr(self._target, name), self._mask)
|
93
|
+
|
94
|
+
|
95
|
+
class ClassFilterSubSelection(Subselection[Any]):
    """
    Subselection that masks a datum down to the detections of the given classes.

    Parameters
    ----------
    classes : Sequence[int]
        The classes to keep.
    """

    def __init__(self, classes: Sequence[int]) -> None:
        self.classes = classes

    def __call__(self, datum: _TDatum) -> _TDatum:
        """
        Return ``datum`` with its target and metadata masked to ``self.classes``.

        The image is passed through unchanged. The target is wrapped in a
        ``ClassFilterTarget``, and each flattened metadata value is passed
        through ``_try_mask_object``.
        """
        image, target, metadata = datum

        # True where the target label is one of the requested classes
        keep = np.isin(as_numpy(target.labels), self.classes)

        masked_metadata = {}
        for key, value in flatten(metadata)[0].items():
            masked_metadata[key] = _try_mask_object(value, keep)

        return cast(_TDatum, (image, ClassFilterTarget(target, keep), masked_metadata))
|
@@ -14,8 +14,8 @@ from sklearn.cluster import KMeans
|
|
14
14
|
from sklearn.metrics import pairwise_distances
|
15
15
|
|
16
16
|
from dataeval.config import EPSILON, DeviceLike, get_seed
|
17
|
-
from dataeval.
|
18
|
-
from dataeval.
|
17
|
+
from dataeval.data import Embeddings, Select
|
18
|
+
from dataeval.data._selection import Selection, SelectionStage
|
19
19
|
|
20
20
|
_logger = logging.getLogger(__name__)
|
21
21
|
|
@@ -272,7 +272,7 @@ class Prioritize(Selection[Any]):
|
|
272
272
|
return _KMeansComplexitySorter(samples, self._c)
|
273
273
|
|
274
274
|
def _to_normalized_ndarray(self, embeddings: Embeddings, selection: list[int] | None = None) -> NDArray[Any]:
|
275
|
-
emb: NDArray[Any] = embeddings.
|
275
|
+
emb: NDArray[Any] = embeddings.to_numpy(selection)
|
276
276
|
emb /= max(np.max(np.linalg.norm(emb, axis=1)), EPSILON)
|
277
277
|
return emb
|
278
278
|
|
@@ -8,9 +8,9 @@ import numpy as np
|
|
8
8
|
from numpy.random import BitGenerator, Generator, SeedSequence
|
9
9
|
from numpy.typing import NDArray
|
10
10
|
|
11
|
-
from dataeval.
|
11
|
+
from dataeval.data._selection import Select, Selection, SelectionStage
|
12
|
+
from dataeval.typing import Array
|
12
13
|
from dataeval.utils._array import as_numpy
|
13
|
-
from dataeval.utils.data._selection import Select, Selection, SelectionStage
|
14
14
|
|
15
15
|
|
16
16
|
class Shuffle(Selection[Any]):
|
@@ -30,7 +30,7 @@ class Shuffle(Selection[Any]):
|
|
30
30
|
seed: int | NDArray[Any] | SeedSequence | BitGenerator | Generator | None
|
31
31
|
stage = SelectionStage.ORDER
|
32
32
|
|
33
|
-
def __init__(self, seed: int |
|
33
|
+
def __init__(self, seed: int | Sequence[int] | Array | SeedSequence | BitGenerator | Generator | None = None):
|
34
34
|
self.seed = as_numpy(seed) if isinstance(seed, (Sequence, Array)) else seed
|
35
35
|
|
36
36
|
def __call__(self, dataset: Select[Any]) -> None:
|
@@ -9,14 +9,14 @@ __all__ = [
|
|
9
9
|
"DriftMMDOutput",
|
10
10
|
"DriftOutput",
|
11
11
|
"DriftUncertainty",
|
12
|
-
"
|
12
|
+
"UpdateStrategy",
|
13
13
|
"updates",
|
14
14
|
]
|
15
15
|
|
16
16
|
from dataeval.detectors.drift import updates
|
17
|
+
from dataeval.detectors.drift._base import UpdateStrategy
|
17
18
|
from dataeval.detectors.drift._cvm import DriftCVM
|
18
19
|
from dataeval.detectors.drift._ks import DriftKS
|
19
20
|
from dataeval.detectors.drift._mmd import DriftMMD
|
20
|
-
from dataeval.detectors.drift._torch import preprocess_drift
|
21
21
|
from dataeval.detectors.drift._uncertainty import DriftUncertainty
|
22
22
|
from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
|
@@ -13,15 +13,16 @@ __all__ = []
|
|
13
13
|
import math
|
14
14
|
from abc import abstractmethod
|
15
15
|
from functools import wraps
|
16
|
-
from typing import
|
16
|
+
from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
|
17
17
|
|
18
18
|
import numpy as np
|
19
19
|
from numpy.typing import NDArray
|
20
20
|
|
21
|
+
from dataeval.data import Embeddings
|
21
22
|
from dataeval.outputs import DriftOutput
|
22
23
|
from dataeval.outputs._base import set_metadata
|
23
|
-
from dataeval.typing import Array
|
24
|
-
from dataeval.utils._array import as_numpy,
|
24
|
+
from dataeval.typing import Array
|
25
|
+
from dataeval.utils._array import as_numpy, flatten
|
25
26
|
|
26
27
|
R = TypeVar("R")
|
27
28
|
|
@@ -32,220 +33,88 @@ class UpdateStrategy(Protocol):
|
|
32
33
|
Protocol for reference dataset update strategy for drift detectors
|
33
34
|
"""
|
34
35
|
|
35
|
-
def __call__(self, x_ref: NDArray[
|
36
|
+
def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int) -> NDArray[np.float32]: ...
|
36
37
|
|
37
38
|
|
38
|
-
def
|
39
|
+
def update_strategy(fn: Callable[..., R]) -> Callable[..., R]:
|
39
40
|
"""Decorator to update x_ref with x using selected update methodology"""
|
40
41
|
|
41
42
|
@wraps(fn)
|
42
|
-
def _(self,
|
43
|
-
output = fn(self,
|
43
|
+
def _(self: BaseDrift, data: Embeddings | Array, *args, **kwargs) -> R:
|
44
|
+
output = fn(self, data, *args, **kwargs)
|
44
45
|
|
45
46
|
# update reference dataset
|
46
|
-
if self.
|
47
|
-
self._x_ref = self.
|
47
|
+
if self.update_strategy is not None:
|
48
|
+
self._x_ref = self.update_strategy(self.x_ref, self._encode(data), self.n)
|
49
|
+
self.n += len(data)
|
48
50
|
|
49
|
-
# used for reservoir sampling
|
50
|
-
self.n += len(x)
|
51
|
-
return output
|
52
|
-
|
53
|
-
return _
|
54
|
-
|
55
|
-
|
56
|
-
def preprocess_x(fn: Callable[..., R]) -> Callable[..., R]:
|
57
|
-
"""Decorator to run preprocess_fn on x before calling wrapped function"""
|
58
|
-
|
59
|
-
@wraps(fn)
|
60
|
-
def _(self, x, *args, **kwargs) -> R:
|
61
|
-
if self._x_refcount == 0:
|
62
|
-
self._x = self._preprocess(x)
|
63
|
-
self._x_refcount += 1
|
64
|
-
output = fn(self, self._x, *args, **kwargs)
|
65
|
-
self._x_refcount -= 1
|
66
|
-
if self._x_refcount == 0:
|
67
|
-
del self._x
|
68
51
|
return output
|
69
52
|
|
70
53
|
return _
|
71
54
|
|
72
55
|
|
73
56
|
class BaseDrift:
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
the reference data (`x_ref`), performing statistical correction (e.g., Bonferroni, FDR),
|
79
|
-
and updating the reference data if needed.
|
80
|
-
|
81
|
-
Parameters
|
82
|
-
----------
|
83
|
-
x_ref : ArrayLike
|
84
|
-
The reference dataset used for drift detection. This is the baseline data against
|
85
|
-
which new data points will be compared.
|
86
|
-
p_val : float, optional
|
87
|
-
The significance level for detecting drift, by default 0.05.
|
88
|
-
x_ref_preprocessed : bool, optional
|
89
|
-
Flag indicating whether the reference data has already been preprocessed, by default False.
|
90
|
-
update_x_ref : UpdateStrategy, optional
|
91
|
-
A strategy object specifying how the reference data should be updated when drift is detected,
|
92
|
-
by default None.
|
93
|
-
preprocess_fn : Callable[[ArrayLike], ArrayLike], optional
|
94
|
-
A function to preprocess the data before drift detection, by default None.
|
95
|
-
correction : {'bonferroni', 'fdr'}, optional
|
96
|
-
Statistical correction method applied to p-values, by default "bonferroni".
|
97
|
-
|
98
|
-
Attributes
|
99
|
-
----------
|
100
|
-
_x_ref : ArrayLike
|
101
|
-
The reference dataset that is either raw or preprocessed.
|
102
|
-
p_val : float
|
103
|
-
The significance level for drift detection.
|
104
|
-
update_x_ref : UpdateStrategy or None
|
105
|
-
The strategy for updating the reference data if applicable.
|
106
|
-
preprocess_fn : Callable or None
|
107
|
-
Function used for preprocessing input data before drift detection.
|
108
|
-
correction : str
|
109
|
-
Statistical correction method applied to p-values.
|
110
|
-
n : int
|
111
|
-
The number of samples in the reference dataset (`x_ref`).
|
112
|
-
x_ref_preprocessed : bool
|
113
|
-
A flag that indicates whether the reference dataset has been preprocessed.
|
114
|
-
_x_refcount : int
|
115
|
-
Counter for how many times the reference data has been accessed after preprocessing.
|
116
|
-
|
117
|
-
Methods
|
118
|
-
-------
|
119
|
-
x_ref:
|
120
|
-
Property that returns the reference dataset, and applies preprocessing if not already done.
|
121
|
-
_preprocess(x):
|
122
|
-
Preprocesses the given data using the specified `preprocess_fn` if provided.
|
123
|
-
"""
|
57
|
+
p_val: float
|
58
|
+
update_strategy: UpdateStrategy | None
|
59
|
+
correction: Literal["bonferroni", "fdr"]
|
60
|
+
n: int
|
124
61
|
|
125
62
|
def __init__(
|
126
63
|
self,
|
127
|
-
|
64
|
+
data: Embeddings | Array,
|
128
65
|
p_val: float = 0.05,
|
129
|
-
|
130
|
-
update_x_ref: UpdateStrategy | None = None,
|
131
|
-
preprocess_fn: Callable[..., ArrayLike] | None = None,
|
66
|
+
update_strategy: UpdateStrategy | None = None,
|
132
67
|
correction: Literal["bonferroni", "fdr"] = "bonferroni",
|
133
68
|
) -> None:
|
134
69
|
# Type checking
|
135
|
-
if
|
136
|
-
raise ValueError("`
|
137
|
-
if update_x_ref is not None and not isinstance(update_x_ref, UpdateStrategy):
|
138
|
-
raise ValueError("`update_x_ref` is not a valid ReferenceUpdate class.")
|
70
|
+
if update_strategy is not None and not isinstance(update_strategy, UpdateStrategy):
|
71
|
+
raise ValueError("`update_strategy` is not a valid UpdateStrategy class.")
|
139
72
|
if correction not in ["bonferroni", "fdr"]:
|
140
73
|
raise ValueError("`correction` must be `bonferroni` or `fdr`.")
|
141
74
|
|
142
|
-
self.
|
143
|
-
self.x_ref_preprocessed: bool = x_ref_preprocessed
|
144
|
-
|
145
|
-
# Other attributes
|
75
|
+
self._data = data
|
146
76
|
self.p_val = p_val
|
147
|
-
self.
|
148
|
-
self.preprocess_fn = preprocess_fn
|
77
|
+
self.update_strategy = update_strategy
|
149
78
|
self.correction = correction
|
150
|
-
self.n
|
79
|
+
self.n = len(data)
|
151
80
|
|
152
|
-
|
153
|
-
self._x_refcount = 0
|
81
|
+
self._x_ref: NDArray[np.float32] | None = None
|
154
82
|
|
155
83
|
@property
|
156
|
-
def x_ref(self) ->
|
84
|
+
def x_ref(self) -> NDArray[np.float32]:
|
157
85
|
"""
|
158
|
-
Retrieve the reference data
|
86
|
+
Retrieve the reference data of the drift detector.
|
159
87
|
|
160
88
|
Returns
|
161
89
|
-------
|
162
|
-
|
163
|
-
The reference
|
90
|
+
NDArray[np.float32]
|
91
|
+
The reference data as a 32-bit floating point numpy array.
|
164
92
|
"""
|
165
|
-
if
|
166
|
-
self.
|
167
|
-
if self.preprocess_fn is not None:
|
168
|
-
self._x_ref = self.preprocess_fn(self._x_ref)
|
169
|
-
|
93
|
+
if self._x_ref is None:
|
94
|
+
self._x_ref = self._encode(self._data)
|
170
95
|
return self._x_ref
|
171
96
|
|
172
|
-
def
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
Returns
|
182
|
-
-------
|
183
|
-
ArrayLike
|
184
|
-
The preprocessed input data.
|
185
|
-
"""
|
186
|
-
if self.preprocess_fn is not None:
|
187
|
-
x = self.preprocess_fn(x)
|
188
|
-
return x
|
97
|
+
def _encode(self, data: Embeddings | Array) -> NDArray[np.float32]:
|
98
|
+
array = (
|
99
|
+
data.to_numpy().astype(np.float32)
|
100
|
+
if isinstance(data, Embeddings)
|
101
|
+
else self._data.new(data).to_numpy().astype(np.float32)
|
102
|
+
if isinstance(self._data, Embeddings)
|
103
|
+
else as_numpy(data).astype(np.float32)
|
104
|
+
)
|
105
|
+
return flatten(array)
|
189
106
|
|
190
107
|
|
191
108
|
class BaseDriftUnivariate(BaseDrift):
|
192
|
-
"""
|
193
|
-
Base class for :term:`drift<Drift>` detection methods using univariate statistical tests.
|
194
|
-
|
195
|
-
This class inherits from `BaseDrift` and serves as a generic component for detecting
|
196
|
-
distribution drift in univariate features. If the number of features `n_features` is greater
|
197
|
-
than 1, a multivariate correction method (e.g., Bonferroni or FDR) is applied to control
|
198
|
-
the :term:`false positive rate<False Positive Rate (FP)>`, ensuring it does not exceed the specified
|
199
|
-
:term:`p-value<P-Value>`.
|
200
|
-
|
201
|
-
Parameters
|
202
|
-
----------
|
203
|
-
x_ref : ArrayLike
|
204
|
-
Reference data used as the baseline to compare against when detecting drift.
|
205
|
-
p_val : float, default 0.05
|
206
|
-
Significance level used for detecting drift.
|
207
|
-
x_ref_preprocessed : bool, default False
|
208
|
-
Indicates whether the reference data has been preprocessed.
|
209
|
-
update_x_ref : UpdateStrategy | None, default None
|
210
|
-
Strategy for updating the reference data when drift is detected.
|
211
|
-
preprocess_fn : Callable[ArrayLike] | None, default None
|
212
|
-
Function used to preprocess input data before detecting drift.
|
213
|
-
correction : 'bonferroni' | 'fdr', default 'bonferroni'
|
214
|
-
Multivariate correction method applied to p-values.
|
215
|
-
n_features : int | None, default None
|
216
|
-
Number of features used in the univariate drift tests. If not provided, it will
|
217
|
-
be inferred from the data.
|
218
|
-
|
219
|
-
Attributes
|
220
|
-
----------
|
221
|
-
p_val : float
|
222
|
-
The significance level for drift detection.
|
223
|
-
correction : str
|
224
|
-
The method for controlling the :term:`False Discovery Rate (FDR)` or applying a Bonferroni correction.
|
225
|
-
update_x_ref : UpdateStrategy | None
|
226
|
-
Strategy for updating the reference data if applicable.
|
227
|
-
preprocess_fn : Callable | None
|
228
|
-
Function used for preprocessing input data before drift detection.
|
229
|
-
"""
|
230
|
-
|
231
109
|
def __init__(
|
232
110
|
self,
|
233
|
-
|
111
|
+
data: Embeddings | Array,
|
234
112
|
p_val: float = 0.05,
|
235
|
-
|
236
|
-
update_x_ref: UpdateStrategy | None = None,
|
237
|
-
preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
|
113
|
+
update_strategy: UpdateStrategy | None = None,
|
238
114
|
correction: Literal["bonferroni", "fdr"] = "bonferroni",
|
239
115
|
n_features: int | None = None,
|
240
116
|
) -> None:
|
241
|
-
super().__init__(
|
242
|
-
x_ref,
|
243
|
-
p_val,
|
244
|
-
x_ref_preprocessed,
|
245
|
-
update_x_ref,
|
246
|
-
preprocess_fn,
|
247
|
-
correction,
|
248
|
-
)
|
117
|
+
super().__init__(data, p_val, update_strategy, correction)
|
249
118
|
|
250
119
|
self._n_features = n_features
|
251
120
|
|
@@ -255,8 +124,7 @@ class BaseDriftUnivariate(BaseDrift):
|
|
255
124
|
Get the number of features in the reference data.
|
256
125
|
|
257
126
|
If the number of features is not provided during initialization, it will be inferred
|
258
|
-
from the reference data (``x_ref``).
|
259
|
-
of features will be inferred after applying the preprocessing function.
|
127
|
+
from the reference data (``x_ref``).
|
260
128
|
|
261
129
|
Returns
|
262
130
|
-------
|
@@ -264,48 +132,36 @@ class BaseDriftUnivariate(BaseDrift):
|
|
264
132
|
Number of features in the reference data.
|
265
133
|
"""
|
266
134
|
# lazy process n_features as needed
|
267
|
-
if
|
268
|
-
|
269
|
-
x_ref = (
|
270
|
-
self.x_ref
|
271
|
-
if self.preprocess_fn is None or self.x_ref_preprocessed
|
272
|
-
else self.preprocess_fn(self._x_ref[0:1])
|
273
|
-
)
|
274
|
-
# infer features from preprocessed reference data
|
275
|
-
shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
|
276
|
-
self._n_features = int(math.prod(shape[1:])) # Multiplies all channel sizes after first
|
135
|
+
if self._n_features is None:
|
136
|
+
self._n_features = int(math.prod(self._data[0].shape))
|
277
137
|
|
278
138
|
return self._n_features
|
279
139
|
|
280
|
-
|
281
|
-
def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
140
|
+
def score(self, data: Embeddings | Array) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
|
282
141
|
"""
|
283
142
|
Calculates p-values and test statistics per feature.
|
284
143
|
|
285
144
|
Parameters
|
286
145
|
----------
|
287
|
-
|
288
|
-
Batch of instances
|
146
|
+
data : Embeddings or Array
|
147
|
+
Batch of instances to score.
|
289
148
|
|
290
149
|
Returns
|
291
150
|
-------
|
292
151
|
tuple[NDArray, NDArray]
|
293
152
|
Feature level p-values and test statistics
|
294
153
|
"""
|
295
|
-
x_np =
|
296
|
-
x_np = x_np.reshape(x_np.shape[0], -1)
|
297
|
-
x_ref_np = as_numpy(self.x_ref)
|
298
|
-
x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
|
154
|
+
x_np = self._encode(data)
|
299
155
|
p_val = np.zeros(self.n_features, dtype=np.float32)
|
300
156
|
dist = np.zeros_like(p_val)
|
301
157
|
for f in range(self.n_features):
|
302
|
-
dist[f], p_val[f] = self._score_fn(
|
158
|
+
dist[f], p_val[f] = self._score_fn(self.x_ref[:, f], x_np[:, f])
|
303
159
|
return p_val, dist
|
304
160
|
|
305
161
|
@abstractmethod
|
306
162
|
def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
|
307
163
|
|
308
|
-
def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
|
164
|
+
def _apply_correction(self, p_vals: NDArray[np.float32]) -> tuple[bool, float]:
|
309
165
|
"""
|
310
166
|
Apply the specified correction method (Bonferroni or FDR) to the p-values.
|
311
167
|
|
@@ -343,20 +199,16 @@ class BaseDriftUnivariate(BaseDrift):
|
|
343
199
|
raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
|
344
200
|
|
345
201
|
@set_metadata
|
346
|
-
@
|
347
|
-
|
348
|
-
def predict(
|
349
|
-
self,
|
350
|
-
x: ArrayLike,
|
351
|
-
) -> DriftOutput:
|
202
|
+
@update_strategy
|
203
|
+
def predict(self, data: Embeddings | Array) -> DriftOutput:
|
352
204
|
"""
|
353
205
|
Predict whether a batch of data has drifted from the reference data and update
|
354
206
|
reference data using specified update strategy.
|
355
207
|
|
356
208
|
Parameters
|
357
209
|
----------
|
358
|
-
|
359
|
-
Batch of instances.
|
210
|
+
data : Embeddings or Array
|
211
|
+
Batch of instances to predict drift on.
|
360
212
|
|
361
213
|
Returns
|
362
214
|
-------
|
@@ -365,7 +217,7 @@ class BaseDriftUnivariate(BaseDrift):
|
|
365
217
|
p-values, threshold after multivariate correction if needed and test :term:`statistics<Statistics>`.
|
366
218
|
"""
|
367
219
|
# compute drift scores
|
368
|
-
p_vals, dist = self.score(
|
220
|
+
p_vals, dist = self.score(data)
|
369
221
|
|
370
222
|
feature_drift = (p_vals < self.p_val).astype(np.bool_)
|
371
223
|
drift_pred, threshold = self._apply_correction(p_vals)
|
dataeval/detectors/drift/_cvm.py
CHANGED
@@ -10,14 +10,15 @@ from __future__ import annotations
|
|
10
10
|
|
11
11
|
__all__ = []
|
12
12
|
|
13
|
-
from typing import
|
13
|
+
from typing import Literal
|
14
14
|
|
15
15
|
import numpy as np
|
16
16
|
from numpy.typing import NDArray
|
17
17
|
from scipy.stats import cramervonmises_2samp
|
18
18
|
|
19
|
+
from dataeval.data._embeddings import Embeddings
|
19
20
|
from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
|
20
|
-
from dataeval.typing import
|
21
|
+
from dataeval.typing import Array
|
21
22
|
|
22
23
|
|
23
24
|
class DriftCVM(BaseDriftUnivariate):
|
@@ -31,40 +32,32 @@ class DriftCVM(BaseDriftUnivariate):
|
|
31
32
|
|
32
33
|
Parameters
|
33
34
|
----------
|
34
|
-
|
35
|
+
data : Embeddings or Array
|
35
36
|
Data used as reference distribution.
|
36
|
-
p_val : float
|
37
|
+
p_val : float or None, default 0.05
|
37
38
|
:term:`p-value<P-Value>` used for significance of the statistical test for each feature.
|
38
39
|
If the FDR correction method is used, this corresponds to the acceptable
|
39
40
|
q-value.
|
40
|
-
|
41
|
-
Whether the given reference data ``x_ref`` has been preprocessed yet.
|
42
|
-
If ``True``, only the test data ``x`` will be preprocessed at prediction time.
|
43
|
-
If ``False``, the reference data will also be preprocessed.
|
44
|
-
update_x_ref : UpdateStrategy | None, default None
|
41
|
+
update_strategy : UpdateStrategy or None, default None
|
45
42
|
Reference data can optionally be updated using an UpdateStrategy class. Update
|
46
43
|
using the last n instances seen by the detector with LastSeenUpdateStrategy
|
47
44
|
or via reservoir sampling with ReservoirSamplingUpdateStrategy.
|
48
|
-
|
49
|
-
Function to preprocess the data before computing the data drift metrics.
|
50
|
-
Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
|
51
|
-
correction : "bonferroni" | "fdr", default "bonferroni"
|
45
|
+
correction : "bonferroni" or "fdr", default "bonferroni"
|
52
46
|
Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
|
53
47
|
Discovery Rate).
|
54
|
-
n_features : int
|
55
|
-
Number of features used in the
|
56
|
-
|
57
|
-
|
48
|
+
n_features : int or None, default None
|
49
|
+
Number of features used in the univariate drift tests. If not provided, it will
|
50
|
+
be inferred from the data.
|
51
|
+
|
58
52
|
|
59
53
|
Example
|
60
54
|
-------
|
61
|
-
>>> from
|
62
|
-
>>> from dataeval.detectors.drift import preprocess_drift
|
55
|
+
>>> from dataeval.data import Embeddings
|
63
56
|
|
64
|
-
Use
|
57
|
+
Use Embeddings to encode images before testing for drift
|
65
58
|
|
66
|
-
>>>
|
67
|
-
>>> drift = DriftCVM(
|
59
|
+
>>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
|
60
|
+
>>> drift = DriftCVM(train_emb)
|
68
61
|
|
69
62
|
Test incoming images for drift
|
70
63
|
|
@@ -74,20 +67,16 @@ class DriftCVM(BaseDriftUnivariate):
|
|
74
67
|
|
75
68
|
def __init__(
|
76
69
|
self,
|
77
|
-
|
70
|
+
data: Embeddings | Array,
|
78
71
|
p_val: float = 0.05,
|
79
|
-
|
80
|
-
update_x_ref: UpdateStrategy | None = None,
|
81
|
-
preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
|
72
|
+
update_strategy: UpdateStrategy | None = None,
|
82
73
|
correction: Literal["bonferroni", "fdr"] = "bonferroni",
|
83
74
|
n_features: int | None = None,
|
84
75
|
) -> None:
|
85
76
|
super().__init__(
|
86
|
-
|
77
|
+
data=data,
|
87
78
|
p_val=p_val,
|
88
|
-
|
89
|
-
update_x_ref=update_x_ref,
|
90
|
-
preprocess_fn=preprocess_fn,
|
79
|
+
update_strategy=update_strategy,
|
91
80
|
correction=correction,
|
92
81
|
n_features=n_features,
|
93
82
|
)
|