dataeval 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. dataeval/__init__.py +18 -0
  2. dataeval/_internal/detectors/__init__.py +0 -0
  3. dataeval/_internal/detectors/clusterer.py +469 -0
  4. dataeval/_internal/detectors/drift/__init__.py +0 -0
  5. dataeval/_internal/detectors/drift/base.py +265 -0
  6. dataeval/_internal/detectors/drift/cvm.py +97 -0
  7. dataeval/_internal/detectors/drift/ks.py +100 -0
  8. dataeval/_internal/detectors/drift/mmd.py +166 -0
  9. dataeval/_internal/detectors/drift/torch.py +310 -0
  10. dataeval/_internal/detectors/drift/uncertainty.py +149 -0
  11. dataeval/_internal/detectors/duplicates.py +49 -0
  12. dataeval/_internal/detectors/linter.py +78 -0
  13. dataeval/_internal/detectors/ood/__init__.py +0 -0
  14. dataeval/_internal/detectors/ood/ae.py +77 -0
  15. dataeval/_internal/detectors/ood/aegmm.py +69 -0
  16. dataeval/_internal/detectors/ood/base.py +199 -0
  17. dataeval/_internal/detectors/ood/llr.py +284 -0
  18. dataeval/_internal/detectors/ood/vae.py +86 -0
  19. dataeval/_internal/detectors/ood/vaegmm.py +79 -0
  20. dataeval/_internal/flags.py +47 -0
  21. dataeval/_internal/metrics/__init__.py +0 -0
  22. dataeval/_internal/metrics/base.py +92 -0
  23. dataeval/_internal/metrics/ber.py +124 -0
  24. dataeval/_internal/metrics/coverage.py +80 -0
  25. dataeval/_internal/metrics/divergence.py +94 -0
  26. dataeval/_internal/metrics/hash.py +79 -0
  27. dataeval/_internal/metrics/parity.py +180 -0
  28. dataeval/_internal/metrics/stats.py +332 -0
  29. dataeval/_internal/metrics/uap.py +45 -0
  30. dataeval/_internal/metrics/utils.py +158 -0
  31. dataeval/_internal/models/__init__.py +0 -0
  32. dataeval/_internal/models/pytorch/__init__.py +0 -0
  33. dataeval/_internal/models/pytorch/autoencoder.py +202 -0
  34. dataeval/_internal/models/pytorch/blocks.py +46 -0
  35. dataeval/_internal/models/pytorch/utils.py +67 -0
  36. dataeval/_internal/models/tensorflow/__init__.py +0 -0
  37. dataeval/_internal/models/tensorflow/autoencoder.py +317 -0
  38. dataeval/_internal/models/tensorflow/gmm.py +115 -0
  39. dataeval/_internal/models/tensorflow/losses.py +107 -0
  40. dataeval/_internal/models/tensorflow/pixelcnn.py +1106 -0
  41. dataeval/_internal/models/tensorflow/trainer.py +102 -0
  42. dataeval/_internal/models/tensorflow/utils.py +254 -0
  43. dataeval/_internal/workflows/sufficiency.py +555 -0
  44. dataeval/detectors/__init__.py +29 -0
  45. dataeval/flags/__init__.py +3 -0
  46. dataeval/metrics/__init__.py +7 -0
  47. dataeval/models/__init__.py +15 -0
  48. dataeval/models/tensorflow/__init__.py +6 -0
  49. dataeval/models/torch/__init__.py +8 -0
  50. dataeval/py.typed +0 -0
  51. dataeval/workflows/__init__.py +8 -0
  52. dataeval-0.61.0.dist-info/LICENSE.txt +21 -0
  53. dataeval-0.61.0.dist-info/METADATA +114 -0
  54. dataeval-0.61.0.dist-info/RECORD +55 -0
  55. dataeval-0.61.0.dist-info/WHEEL +4 -0
dataeval/_internal/detectors/drift/base.py
@@ -0,0 +1,265 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from abc import ABC, abstractmethod
+ from functools import wraps
+ from random import random
+ from typing import Callable, Dict, Literal, Optional, Tuple, Union
+
+ import numpy as np
+
+
+ def update_x_ref(fn):
+     @wraps(fn)
+     def _(self, x, *args, **kwargs):
+         output = fn(self, x, *args, **kwargs)
+
+         # update reference dataset
+         if self.update_x_ref is not None:
+             self._x_ref = self.update_x_ref(self.x_ref, x, self.n)
+
+         # used for reservoir sampling
+         self.n += len(x)
+         return output
+
+     return _
+
+
+ def preprocess_x(fn):
+     @wraps(fn)
+     def _(self, x, *args, **kwargs):
+         if self._x_refcount == 0:
+             self._x = self._preprocess(x)
+         self._x_refcount += 1
+         output = fn(self, self._x, *args, **kwargs)
+         self._x_refcount -= 1
+         if self._x_refcount == 0:
+             del self._x
+         return output
+
+     return _
+
+
+ class UpdateStrategy(ABC):
+     def __init__(self, n: int):
+         self.n = n
+
+     @abstractmethod
+     def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+         """Abstract implementation of update strategy"""
+
+
+ class LastSeenUpdate(UpdateStrategy):
+     """
+     Updates reference dataset for drift detector using last seen method.
+
+     Parameters
+     ----------
+     n : int
+         Update with last n instances seen by the detector.
+     """
+
+     def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+         x_updated = np.concatenate([x_ref, x], axis=0)
+         return x_updated[-self.n :]
+
+
+ class ReservoirSamplingUpdate(UpdateStrategy):
+     """
+     Updates reference dataset for drift detector using reservoir sampling method.
+
+     Parameters
+     ----------
+     n : int
+         Update with reservoir sampling of size n.
+     """
+
+     def __call__(self, x_ref: np.ndarray, x: np.ndarray, count: int) -> np.ndarray:
+         if x.shape[0] + count <= self.n:
+             return np.concatenate([x_ref, x], axis=0)
+
+         n_ref = x_ref.shape[0]
+         output_size = min(self.n, n_ref + x.shape[0])
+         shape = (output_size,) + x.shape[1:]
+         x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
+         x_reservoir[:n_ref] = x_ref
+         for item in x:
+             count += 1
+             if n_ref < self.n:
+                 x_reservoir[n_ref, :] = item
+                 n_ref += 1
+             else:
+                 r = int(random() * count)
+                 if r < self.n:
+                     x_reservoir[r, :] = item
+         return x_reservoir
+
+
+ class BaseDrift:
+     """Generic drift detector component handling preprocessing of data and correction"""
+
+     def __init__(
+         self,
+         x_ref: np.ndarray,
+         p_val: float = 0.05,
+         x_ref_preprocessed: bool = False,
+         update_x_ref: Optional[UpdateStrategy] = None,
+         preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+         correction: Literal["bonferroni", "fdr"] = "bonferroni",
+     ) -> None:
+         # Type checking
+         if preprocess_fn is not None and not isinstance(preprocess_fn, Callable):
+             raise ValueError("`preprocess_fn` is not a valid Callable.")
+         if update_x_ref is not None and not isinstance(update_x_ref, UpdateStrategy):
+             raise ValueError("`update_x_ref` is not a valid UpdateStrategy class.")
+         if correction not in ["bonferroni", "fdr"]:
+             raise ValueError("`correction` must be `bonferroni` or `fdr`.")
+
+         self._x_ref = x_ref
+         self.x_ref_preprocessed = x_ref_preprocessed
+
+         # Other attributes
+         self.p_val = p_val
+         self.update_x_ref = update_x_ref
+         self.preprocess_fn = preprocess_fn
+         self.correction = correction
+         self.n = len(x_ref)
+
+         # Ref counter for preprocessed x
+         self._x_refcount = 0
+
+     @property
+     def x_ref(self) -> np.ndarray:
+         if not self.x_ref_preprocessed:
+             self.x_ref_preprocessed = True
+             if self.preprocess_fn is not None:
+                 self._x_ref = self.preprocess_fn(self._x_ref)
+
+         return self._x_ref
+
+     def _preprocess(self, x: np.ndarray) -> np.ndarray:
+         """Data preprocessing before computing the drift scores."""
+         if self.preprocess_fn is not None:
+             x = self.preprocess_fn(x)
+         return x
+
+
+ class BaseUnivariateDrift(BaseDrift):
+     """
+     Generic drift detector component which serves as a base class for methods using
+     univariate tests. If n_features > 1, a multivariate correction is applied such
+     that the false positive rate is upper bounded by the specified p-value, with
+     equality in the case of independent features.
+     """
+
+     def __init__(
+         self,
+         x_ref: np.ndarray,
+         p_val: float = 0.05,
+         x_ref_preprocessed: bool = False,
+         update_x_ref: Optional[UpdateStrategy] = None,
+         preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+         correction: Literal["bonferroni", "fdr"] = "bonferroni",
+         n_features: Optional[int] = None,
+     ) -> None:
+         super().__init__(
+             x_ref,
+             p_val,
+             x_ref_preprocessed,
+             update_x_ref,
+             preprocess_fn,
+             correction,
+         )
+
+         self._n_features = n_features
+
+     @property
+     def n_features(self) -> int:
+         # lazy process n_features as needed
+         if not isinstance(self._n_features, int):
+             # compute number of features for the univariate tests
+             if not isinstance(self.preprocess_fn, Callable) or self.x_ref_preprocessed:
+                 # infer features from preprocessed reference data
+                 self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
+             else:
+                 # infer number of features after applying preprocessing step
+                 x = self.preprocess_fn(self.x_ref[0:1])
+                 self._n_features = x.reshape(x.shape[0], -1).shape[-1]
+
+         return self._n_features
+
+     @preprocess_x
+     @abstractmethod
+     def score(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         """Abstract method to calculate feature score after preprocessing"""
+
+     def _apply_correction(self, p_vals: np.ndarray) -> Tuple[int, float]:
+         if self.correction == "bonferroni":
+             threshold = self.p_val / self.n_features
+             drift_pred = int((p_vals < threshold).any())
+             return drift_pred, threshold
+         elif self.correction == "fdr":
+             n = p_vals.shape[0]
+             i = np.arange(n) + 1
+             p_sorted = np.sort(p_vals)
+             q_threshold = self.p_val * i / n
+             below_threshold = p_sorted < q_threshold
+             try:
+                 idx_threshold = int(np.where(below_threshold)[0].max())
+             except ValueError:  # sorted p-values not below thresholds
+                 return int(below_threshold.any()), q_threshold.min()
+             return int(below_threshold.any()), q_threshold[idx_threshold]
+         else:
+             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
+
+     @preprocess_x
+     @update_x_ref
+     def predict(
+         self,
+         x: np.ndarray,
+         drift_type: Literal["batch", "feature"] = "batch",
+     ) -> Dict[str, Union[int, float, np.ndarray]]:
+         """
228
+ Predict whether a batch of data has drifted from the reference data and update
229
+ reference data using specified update strategy.
230
+
231
+ Parameters
232
+ ----------
233
+ x : np.ndarray
234
+ Batch of instances.
235
+ drift_type : Literal["batch", "feature"], default "batch"
236
+ Predict drift at the 'feature' or 'batch' level. For 'batch', the test
237
+ statistics for each feature are aggregated using the Bonferroni or False
238
+ Discovery Rate correction (if n_features>1).
239
+
240
+ Returns
241
+ -------
+         Dictionary containing the drift prediction and, optionally, the feature-level
+         p-values, the threshold after multivariate correction (if applied), and the
+         test statistics.
+         """
+         # compute drift scores
+         p_vals, dist = self.score(x)
+
+         # TODO: return both feature-level and batch-level drift predictions by default
+         # values below p-value threshold are drift
+         if drift_type == "feature":
+             drift_pred = (p_vals < self.p_val).astype(int)
+             threshold = self.p_val
+         elif drift_type == "batch":
+             drift_pred, threshold = self._apply_correction(p_vals)
+         else:
+             raise ValueError("`drift_type` needs to be either `feature` or `batch`.")
+
+         # populate drift dict
+         return {
+             "is_drift": drift_pred,
+             "p_val": p_vals,
+             "threshold": threshold,
+             "distance": dist,
+         }
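
The update strategies above are the pluggable piece of every detector in this diff. A minimal sketch of their behavior follows, using the internal module path shown in the file list (the public package likely re-exports these classes, possibly under different names; treat any other import path as an assumption):

import numpy as np

# internal path per the file list above; public aliases are assumed, not confirmed
from dataeval._internal.detectors.drift.base import LastSeenUpdate, ReservoirSamplingUpdate

x_ref = np.random.randn(100, 3)  # current reference window
x_new = np.random.randn(40, 3)   # incoming batch

# keep only the 120 most recent instances
last_seen = LastSeenUpdate(n=120)
print(last_seen(x_ref, x_new, count=100).shape)   # -> (120, 3)

# keep a uniform random sample of size 120 over all instances seen so far;
# `count` is how many instances the detector has already processed (its `n`)
reservoir = ReservoirSamplingUpdate(n=120)
print(reservoir(x_ref, x_new, count=100).shape)   # -> (120, 3)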
dataeval/_internal/detectors/drift/cvm.py
@@ -0,0 +1,97 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from typing import Callable, Literal, Optional, Tuple
+
+ import numpy as np
+ from scipy.stats import cramervonmises_2samp
+
+ from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+
+
+ class DriftCVM(BaseUnivariateDrift):
+     """
+     Cramér-von Mises (CVM) data drift detector, which tests for any change in the
+     distribution of continuous univariate data. For multivariate data, a separate
+     CVM test is applied to each feature, and the obtained p-values are aggregated
+     via the Bonferroni or False Discovery Rate (FDR) corrections.
+
+     Parameters
+     ----------
+     x_ref : np.ndarray
+         Data used as reference distribution.
+     p_val : float, default 0.05
+         p-value used for significance of the statistical test for each feature.
+         If the FDR correction method is used, this corresponds to the acceptable
+         q-value.
+     x_ref_preprocessed : bool, default False
+         Whether the given reference data `x_ref` has been preprocessed yet. If
+         `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
+         prediction time. If `x_ref_preprocessed=False`, the reference data will also
+         be preprocessed.
+     update_x_ref : Optional[UpdateStrategy], default None
+         Reference data can optionally be updated using an UpdateStrategy class. Update
+         using the last n instances seen by the detector with
+         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
+         or via reservoir sampling with
+         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
+     preprocess_fn : Optional[Callable[[np.ndarray], np.ndarray]], default None
+         Function to preprocess the data before computing the data drift metrics.
+         Typically a dimensionality reduction technique.
+     correction : Literal["bonferroni", "fdr"], default "bonferroni"
+         Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
+         Discovery Rate).
+ Discovery Rate).
49
+ n_features
50
+ Number of features used in the statistical test. No need to pass it if no
51
+ preprocessing takes place. In case of a preprocessing step, this can also
52
+ be inferred automatically but could be more expensive to compute.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ x_ref: np.ndarray,
58
+ p_val: float = 0.05,
59
+ x_ref_preprocessed: bool = False,
60
+ update_x_ref: Optional[UpdateStrategy] = None,
61
+ preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
62
+ correction: Literal["bonferroni", "fdr"] = "bonferroni",
63
+ n_features: Optional[int] = None,
64
+ ) -> None:
65
+ super().__init__(
66
+ x_ref=x_ref,
67
+ p_val=p_val,
68
+ x_ref_preprocessed=x_ref_preprocessed,
69
+ update_x_ref=update_x_ref,
70
+ preprocess_fn=preprocess_fn,
71
+ correction=correction,
72
+ n_features=n_features,
73
+ )
74
+
75
+ @preprocess_x
76
+ def score(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
77
+ """
78
+ Performs the two-sample Cramér-von Mises test(s), computing the p-value and
79
+ test statistic per feature.
80
+
81
+ Parameters
82
+ ----------
83
+ x
84
+ Batch of instances.
85
+
86
+ Returns
87
+ -------
88
+ Feature level p-values and CVM statistics.
89
+ """
90
+ x = x.reshape(x.shape[0], -1)
91
+ x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
92
+ p_val = np.zeros(self.n_features, dtype=np.float32)
93
+ dist = np.zeros_like(p_val)
94
+ for f in range(self.n_features):
95
+ result = cramervonmises_2samp(x_ref[:, f], x[:, f], method="auto")
96
+ p_val[f], dist[f] = result.pvalue, result.statistic
97
+ return p_val, dist
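
A minimal usage sketch of DriftCVM follows, importing from the internal path confirmed by the file list (a public re-export via dataeval.detectors is assumed but not shown in this diff). The threshold in the comment follows directly from the Bonferroni correction above:

import numpy as np
from dataeval._internal.detectors.drift.cvm import DriftCVM  # public alias assumed

rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=(500, 8)).astype(np.float32)   # reference sample
x_test = rng.normal(0.5, 1.0, size=(200, 8)).astype(np.float32)  # mean-shifted sample

detector = DriftCVM(x_ref, p_val=0.05, correction="bonferroni")

# batch-level decision: per-feature p-values aggregated via Bonferroni
result = detector.predict(x_test)
print(result["is_drift"], result["threshold"])  # threshold == p_val / n_features == 0.00625

# feature-level decision: one 0/1 flag per feature at the raw p_val threshold
per_feature = detector.predict(x_test, drift_type="feature")
print(per_feature["is_drift"])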
dataeval/_internal/detectors/drift/ks.py
@@ -0,0 +1,100 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from typing import Callable, Literal, Optional, Tuple
+
+ import numpy as np
+ from scipy.stats import ks_2samp
+
+ from .base import BaseUnivariateDrift, UpdateStrategy, preprocess_x
+
+
+ class DriftKS(BaseUnivariateDrift):
+     """
+     Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
+     Rate (FDR) correction for multivariate data.
+
+     Parameters
+     ----------
+     x_ref : np.ndarray
+         Data used as reference distribution.
+     p_val : float, default 0.05
+         p-value used for significance of the statistical test for each feature.
+         If the FDR correction method is used, this corresponds to the acceptable
+         q-value.
+     x_ref_preprocessed : bool, default False
+         Whether the given reference data `x_ref` has been preprocessed yet. If
+         `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
+         prediction time. If `x_ref_preprocessed=False`, the reference data will also
+         be preprocessed.
+     update_x_ref : Optional[UpdateStrategy], default None
+         Reference data can optionally be updated using an UpdateStrategy class. Update
+         using the last n instances seen by the detector with
+         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
+         or via reservoir sampling with
+         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
+     preprocess_fn : Optional[Callable[[np.ndarray], np.ndarray]], default None
+         Function to preprocess the data before computing the data drift metrics.
+         Typically a dimensionality reduction technique.
+     correction : Literal["bonferroni", "fdr"], default "bonferroni"
+         Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
+         Discovery Rate).
+     alternative : Literal["two-sided", "less", "greater"], default "two-sided"
+         Defines the alternative hypothesis. Options are 'two-sided', 'less' or
+         'greater'.
+     n_features : Optional[int], default None
+         Number of features used in the statistical test. No need to pass it if no
+         preprocessing takes place. In case of a preprocessing step, this can also
+         be inferred automatically, but at extra computational cost.
+     """
+
+     def __init__(
+         self,
+         x_ref: np.ndarray,
+         p_val: float = 0.05,
+         x_ref_preprocessed: bool = False,
+         update_x_ref: Optional[UpdateStrategy] = None,
+         preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+         correction: Literal["bonferroni", "fdr"] = "bonferroni",
+         alternative: Literal["two-sided", "less", "greater"] = "two-sided",
+         n_features: Optional[int] = None,
+     ) -> None:
+         super().__init__(
+             x_ref=x_ref,
+             p_val=p_val,
+             x_ref_preprocessed=x_ref_preprocessed,
+             update_x_ref=update_x_ref,
+             preprocess_fn=preprocess_fn,
+             correction=correction,
+             n_features=n_features,
+         )
+
+         # Other attributes
+         self.alternative = alternative
+
+     @preprocess_x
+     def score(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Compute K-S scores and statistics per feature.
+
+         Parameters
+         ----------
+         x
+             Batch of instances.
+
+         Returns
+         -------
+         Feature level p-values and K-S statistics.
+         """
+         x = x.reshape(x.shape[0], -1)
+         x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
+         p_val = np.zeros(self.n_features, dtype=np.float32)
+         dist = np.zeros_like(p_val)
+         for f in range(self.n_features):
+             dist[f], p_val[f] = ks_2samp(x_ref[:, f], x[:, f], alternative=self.alternative, method="exact")
+         return p_val, dist
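
DriftKS composes the same way as DriftCVM. The sketch below (internal import paths per the file list; public aliases assumed) pairs it with the FDR correction and a LastSeenUpdate reference-update strategy, which exercises the base-class update machinery shown earlier:

import numpy as np
from dataeval._internal.detectors.drift.base import LastSeenUpdate
from dataeval._internal.detectors.drift.ks import DriftKS

rng = np.random.default_rng(1)
x_ref = rng.normal(size=(500, 4)).astype(np.float32)
x_test = rng.normal(scale=2.0, size=(200, 4)).astype(np.float32)  # variance change

# FDR is less conservative than Bonferroni when several features drift at once;
# after each predict() the reference window holds the 500 most recent instances
detector = DriftKS(x_ref, p_val=0.05, correction="fdr", update_x_ref=LastSeenUpdate(n=500))
print(detector.predict(x_test)["is_drift"])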
dataeval/_internal/detectors/drift/mmd.py
@@ -0,0 +1,166 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from typing import Callable, Dict, Optional, Tuple, Union
+
+ import numpy as np
+ import torch
+
+ from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
+ from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
+
+
+ class DriftMMD(BaseDrift):
+     """
+     Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
+
+     Parameters
+     ----------
+     x_ref : np.ndarray
+         Data used as reference distribution.
+     p_val : float, default 0.05
+         p-value used for the significance of the permutation test.
+     x_ref_preprocessed : bool, default False
+         Whether the given reference data `x_ref` has been preprocessed yet. If
+         `x_ref_preprocessed=True`, only the test data `x` will be preprocessed
+         at prediction time. If `x_ref_preprocessed=False`, the reference data
+         will also be preprocessed.
+     update_x_ref : Optional[UpdateStrategy], default None
+         Reference data can optionally be updated using an UpdateStrategy class. Update
+         using the last n instances seen by the detector with
+         :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
+         or via reservoir sampling with
+         :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
+     preprocess_fn : Optional[Callable], default None
+         Function to preprocess the data before computing the data drift metrics.
+     kernel : Callable, default :py:class:`dataeval.detectors.GaussianRBF`
+         Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
+     sigma : Optional[np.ndarray], default None
+         Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple
+         bandwidth values as an array. The kernel evaluation is then averaged over
+         those bandwidths.
+     configure_kernel_from_x_ref : bool, default True
+         Whether to already configure the kernel bandwidth from the reference data.
+     n_permutations : int, default 100
+         Number of permutations used in the permutation test.
+     device : Optional[str], default None
+         Device type used. The default None uses the GPU and falls back on CPU.
+         Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
+     """
+
+     def __init__(
+         self,
+         x_ref: np.ndarray,
+         p_val: float = 0.05,
+         x_ref_preprocessed: bool = False,
+         update_x_ref: Optional[UpdateStrategy] = None,
+         preprocess_fn: Optional[Callable[[np.ndarray], np.ndarray]] = None,
+         kernel: Callable = GaussianRBF,
+         sigma: Optional[np.ndarray] = None,
+         configure_kernel_from_x_ref: bool = True,
+         n_permutations: int = 100,
+         device: Optional[str] = None,
+     ) -> None:
+         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
+
+         self.infer_sigma = configure_kernel_from_x_ref
+         if configure_kernel_from_x_ref and isinstance(sigma, np.ndarray):
+             self.infer_sigma = False
+
+         self.n_permutations = n_permutations  # nb of iterations through permutation test
+
+         # set device
+         self.device = get_device(device)
+
+         # initialize kernel
+         sigma_tensor = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, np.ndarray) else None
+         self.kernel = kernel(sigma_tensor).to(self.device) if kernel == GaussianRBF else kernel
+
+         # compute kernel matrix for the reference data
+         if self.infer_sigma or isinstance(sigma_tensor, torch.Tensor):
+             x = torch.from_numpy(self.x_ref).to(self.device)
+             self.k_xx = self.kernel(x, x, infer_sigma=self.infer_sigma)
+             self.infer_sigma = False
+         else:
+             self.k_xx, self.infer_sigma = None, True
+
+     def _kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         """Compute and return full kernel matrix between arrays x and y."""
+         k_xy = self.kernel(x, y, self.infer_sigma)
+         k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
+         k_yy = self.kernel(y, y)
+         kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
+         return kernel_mat
+
+     @preprocess_x
+     def score(self, x: np.ndarray) -> Tuple[float, float, float]:
+         """
+         Compute the p-value resulting from a permutation test using the maximum mean
+         discrepancy as a distance measure between the reference data and the data to
+         be tested.
+
+         Parameters
+         ----------
+         x
+             Batch of instances.
+
+         Returns
+         -------
+         p-value obtained from the permutation test, the MMD^2 between the reference and
+         test set, and the MMD^2 threshold above which drift is flagged.
+         """
+         x_ref = torch.from_numpy(self.x_ref).to(self.device)
+         n = x.shape[0]
+         kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
+         kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())  # zero diagonal
+         mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
+         mmd2_permuted = torch.Tensor(
+             [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
+         )
+         mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
+         p_val = (mmd2 <= mmd2_permuted).float().mean()
+         # compute distance threshold
+         idx_threshold = int(self.p_val * len(mmd2_permuted))
+         distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
+         return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
+
+     @preprocess_x
+     @update_x_ref
+     def predict(
+         self,
+         x: np.ndarray,
+     ) -> Dict[str, Union[int, float]]:
+         """
+         Predict whether a batch of data has drifted from the reference data, then
+         update the reference data using the specified strategy.
+
+         Parameters
+         ----------
+         x
+             Batch of instances.
+
+         Returns
+         -------
+         Dictionary containing the drift prediction, p-value, threshold and MMD metric.
+         """
+         # compute drift scores
+         p_val, dist, distance_threshold = self.score(x)
+         drift_pred = int(p_val < self.p_val)
+
+         # populate drift dict
+         return {
+             "is_drift": drift_pred,
+             "p_val": p_val,
+             "threshold": self.p_val,
+             "distance": dist,
+             "distance_threshold": distance_threshold,
+         }
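
Unlike the univariate detectors above, DriftMMD runs a single multivariate permutation test over the whole representation, so no Bonferroni/FDR correction is involved. A usage sketch under the same import assumption (internal path confirmed by the file list, public alias assumed):

import numpy as np
from dataeval._internal.detectors.drift.mmd import DriftMMD  # public alias assumed

rng = np.random.default_rng(2)
x_ref = rng.normal(size=(200, 16)).astype(np.float32)
x_test = (rng.normal(size=(100, 16)) + 0.3).astype(np.float32)  # mean shift

# bandwidth is inferred from x_ref (configure_kernel_from_x_ref=True by default);
# device=None selects CUDA when available, otherwise CPU
detector = DriftMMD(x_ref, p_val=0.05, n_permutations=100)
result = detector.predict(x_test)
print(result["is_drift"], result["distance"], result["distance_threshold"])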