dataeval 0.65.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +24 -22
  3. dataeval/_internal/detectors/drift/base.py +206 -26
  4. dataeval/_internal/detectors/drift/cvm.py +25 -23
  5. dataeval/_internal/detectors/drift/ks.py +28 -25
  6. dataeval/_internal/detectors/drift/mmd.py +30 -29
  7. dataeval/_internal/detectors/drift/torch.py +66 -58
  8. dataeval/_internal/detectors/drift/uncertainty.py +28 -28
  9. dataeval/_internal/detectors/duplicates.py +28 -18
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +61 -43
  13. dataeval/_internal/detectors/ood/llr.py +27 -24
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +33 -27
  17. dataeval/_internal/flags.py +5 -3
  18. dataeval/_internal/interop.py +4 -2
  19. dataeval/_internal/metrics/balance.py +33 -4
  20. dataeval/_internal/metrics/ber.py +6 -4
  21. dataeval/_internal/metrics/diversity.py +45 -12
  22. dataeval/_internal/metrics/parity.py +114 -26
  23. dataeval/_internal/metrics/stats.py +154 -16
  24. dataeval/_internal/metrics/uap.py +28 -2
  25. dataeval/_internal/metrics/utils.py +20 -18
  26. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  27. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  28. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  29. dataeval/_internal/models/tensorflow/losses.py +15 -11
  30. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  31. dataeval/_internal/models/tensorflow/trainer.py +8 -6
  32. dataeval/_internal/models/tensorflow/utils.py +21 -19
  33. dataeval/_internal/output.py +13 -10
  34. dataeval/_internal/utils.py +5 -3
  35. dataeval/_internal/workflows/sufficiency.py +42 -30
  36. dataeval/detectors/__init__.py +6 -25
  37. dataeval/detectors/drift/__init__.py +16 -0
  38. dataeval/detectors/drift/kernels/__init__.py +6 -0
  39. dataeval/detectors/drift/updates/__init__.py +3 -0
  40. dataeval/detectors/linters/__init__.py +5 -0
  41. dataeval/detectors/ood/__init__.py +11 -0
  42. dataeval/metrics/__init__.py +2 -26
  43. dataeval/metrics/bias/__init__.py +14 -0
  44. dataeval/metrics/estimators/__init__.py +9 -0
  45. dataeval/metrics/stats/__init__.py +6 -0
  46. dataeval/tensorflow/__init__.py +3 -0
  47. dataeval/tensorflow/loss/__init__.py +3 -0
  48. dataeval/tensorflow/models/__init__.py +5 -0
  49. dataeval/tensorflow/recon/__init__.py +3 -0
  50. dataeval/torch/__init__.py +3 -0
  51. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  52. dataeval/torch/trainer/__init__.py +3 -0
  53. dataeval/utils/__init__.py +3 -6
  54. dataeval/workflows/__init__.py +2 -4
  55. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  56. dataeval-0.66.0.dist-info/RECORD +72 -0
  57. dataeval/models/__init__.py +0 -15
  58. dataeval/models/tensorflow/__init__.py +0 -6
  59. dataeval-0.65.0.dist-info/RECORD +0 -60
  60. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  61. {dataeval-0.65.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -6,7 +6,9 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
- from typing import Callable, Literal, Optional, Tuple
9
+ from __future__ import annotations
10
+
11
+ from typing import Callable, Literal
10
12
 
11
13
  import numpy as np
12
14
  from numpy.typing import ArrayLike, NDArray
@@ -19,38 +21,38 @@ from .base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
19
21
 
20
22
  class DriftKS(BaseDriftUnivariate):
21
23
  """
22
- Kolmogorov-Smirnov (K-S) data drift detector with Bonferroni or False Discovery
23
- Rate (FDR) correction for multivariate data.
24
+ Drift detector employing the Kolmogorov-Smirnov (KS) distribution test.
25
+
26
+ The KS test detects changes in the maximum distance between two data
27
+ distributions with Bonferroni or False Discovery Rate (FDR) correction
28
+ for multivariate data.
24
29
 
25
30
  Parameters
26
31
  ----------
27
- x_ref : NDArray
32
+ x_ref : ArrayLike
28
33
  Data used as reference distribution.
29
- p_val : float, default 0.05
34
+ p_val : float | None, default 0.05
30
35
  p-value used for significance of the statistical test for each feature.
31
36
  If the FDR correction method is used, this corresponds to the acceptable
32
37
  q-value.
33
38
  x_ref_preprocessed : bool, default False
34
- Whether the given reference data `x_ref` has been preprocessed yet. If
35
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
36
- prediction time. If `x_ref_preprocessed=False`, the reference data will also
37
- be preprocessed.
38
- update_x_ref : Optional[UpdateStrategy], default None
39
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
40
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
41
+ If ``False``, the reference data will also be preprocessed.
42
+ update_x_ref : UpdateStrategy | None, default None
39
43
  Reference data can optionally be updated using an UpdateStrategy class. Update
40
- using the last n instances seen by the detector with
41
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
42
- or via reservoir sampling with
43
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
44
- preprocess_fn : Optional[Callable[[NDArray], NDArray]], default None
44
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
45
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
46
+ preprocess_fn : Callable | None, default None
45
47
  Function to preprocess the data before computing the data drift metrics.
46
48
  Typically a dimensionality reduction technique.
47
- correction : Literal["bonferroni", "fdr"], default "bonferroni"
49
+ correction : "bonferroni" | "fdr", default "bonferroni"
48
50
  Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
49
51
  Discovery Rate).
50
- alternative : Literal["two-sided", "less", "greater"], default "two-sided"
52
+ alternative : "two-sided" | "less" | "greater", default "two-sided"
51
53
  Defines the alternative hypothesis. Options are 'two-sided', 'less' or
52
54
  'greater'.
53
- n_features
55
+ n_features : int | None, default None
54
56
  Number of features used in the statistical test. No need to pass it if no
55
57
  preprocessing takes place. In case of a preprocessing step, this can also
56
58
  be inferred automatically but could be more expensive to compute.
@@ -61,11 +63,11 @@ class DriftKS(BaseDriftUnivariate):
61
63
  x_ref: ArrayLike,
62
64
  p_val: float = 0.05,
63
65
  x_ref_preprocessed: bool = False,
64
- update_x_ref: Optional[UpdateStrategy] = None,
65
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
66
+ update_x_ref: UpdateStrategy | None = None,
67
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
66
68
  correction: Literal["bonferroni", "fdr"] = "bonferroni",
67
69
  alternative: Literal["two-sided", "less", "greater"] = "two-sided",
68
- n_features: Optional[int] = None,
70
+ n_features: int | None = None,
69
71
  ) -> None:
70
72
  super().__init__(
71
73
  x_ref=x_ref,
@@ -81,18 +83,19 @@ class DriftKS(BaseDriftUnivariate):
81
83
  self.alternative = alternative
82
84
 
83
85
  @preprocess_x
84
- def score(self, x: ArrayLike) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
86
+ def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
85
87
  """
86
- Compute K-S scores and statistics per feature.
88
+ Compute KS scores and statistics per feature.
87
89
 
88
90
  Parameters
89
91
  ----------
90
- x
92
+ x : ArrayLike
91
93
  Batch of instances.
92
94
 
93
95
  Returns
94
96
  -------
95
- Feature level p-values and K-S statistics.
97
+ tuple[NDArray, NDArray]
98
+ Feature level p-values and KS statistics.
96
99
  """
97
100
  x = to_numpy(x)
98
101
  x = x.reshape(x.shape[0], -1)
@@ -6,8 +6,10 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from dataclasses import dataclass
10
- from typing import Callable, Optional, Tuple
12
+ from typing import Callable
11
13
 
12
14
  import torch
13
15
  from numpy.typing import ArrayLike
@@ -15,13 +17,15 @@ from numpy.typing import ArrayLike
15
17
  from dataeval._internal.interop import to_numpy
16
18
  from dataeval._internal.output import set_metadata
17
19
 
18
- from .base import BaseDrift, DriftOutput, UpdateStrategy, preprocess_x, update_x_ref
20
+ from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
19
21
  from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
20
22
 
21
23
 
22
24
  @dataclass(frozen=True)
23
- class DriftMMDOutput(DriftOutput):
25
+ class DriftMMDOutput(DriftBaseOutput):
24
26
  """
27
+ Output class for DriftMMD
28
+
25
29
  Attributes
26
30
  ----------
27
31
  is_drift : bool
@@ -51,28 +55,24 @@ class DriftMMD(BaseDrift):
51
55
  ----------
52
56
  x_ref : ArrayLike
53
57
  Data used as reference distribution.
54
- p_val : float, default 0.05
55
- p-value used for the significance of the permutation test.
58
+ p_val : float | None, default 0.05
59
+ p-value used for significance of the statistical test for each feature.
60
+ If the FDR correction method is used, this corresponds to the acceptable
61
+ q-value.
56
62
  x_ref_preprocessed : bool, default False
57
- Whether the given reference data `x_ref` has been preprocessed yet. If
58
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed
59
- at prediction time. If `x_ref_preprocessed=False`, the reference data
60
- will also be preprocessed.
61
- preprocess_at_init : bool, default True
62
- Whether to preprocess the reference data when the detector is instantiated.
63
- Otherwise, the reference data will be preprocessed at prediction time. Only
64
- applies if `x_ref_preprocessed=False`.
65
- update_x_ref : Optional[UpdateStrategy], default None
63
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
64
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
65
+ If ``False``, the reference data will also be preprocessed.
66
+ update_x_ref : UpdateStrategy | None, default None
66
67
  Reference data can optionally be updated using an UpdateStrategy class. Update
67
- using the last n instances seen by the detector with
68
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
69
- or via reservoir sampling with
70
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
71
- preprocess_fn : Optional[Callable], default None
68
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
69
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
70
+ preprocess_fn : Callable | None, default None
72
71
  Function to preprocess the data before computing the data drift metrics.
73
- kernel : Callable, default :py:class:`dataeval.detectors.GaussianRBF`
72
+ Typically a dimensionality reduction technique.
73
+ kernel : Callable, default GaussianRBF
74
74
  Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
75
- sigma : Optional[ArrayLike], default None
75
+ sigma : ArrayLike | None, default None
76
76
  Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple
77
77
  bandwidth values as an array. The kernel evaluation is then averaged over
78
78
  those bandwidths.
@@ -80,7 +80,7 @@ class DriftMMD(BaseDrift):
80
80
  Whether to already configure the kernel bandwidth from the reference data.
81
81
  n_permutations : int, default 100
82
82
  Number of permutations used in the permutation test.
83
- device : Optional[str], default None
83
+ device : str | None, default None
84
84
  Device type used. The default None uses the GPU and falls back on CPU.
85
85
  Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
86
86
  """
@@ -90,13 +90,13 @@ class DriftMMD(BaseDrift):
90
90
  x_ref: ArrayLike,
91
91
  p_val: float = 0.05,
92
92
  x_ref_preprocessed: bool = False,
93
- update_x_ref: Optional[UpdateStrategy] = None,
94
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
93
+ update_x_ref: UpdateStrategy | None = None,
94
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
95
95
  kernel: Callable = GaussianRBF,
96
- sigma: Optional[ArrayLike] = None,
96
+ sigma: ArrayLike | None = None,
97
97
  configure_kernel_from_x_ref: bool = True,
98
98
  n_permutations: int = 100,
99
- device: Optional[str] = None,
99
+ device: str | None = None,
100
100
  ) -> None:
101
101
  super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
102
102
 
@@ -130,7 +130,7 @@ class DriftMMD(BaseDrift):
130
130
  return kernel_mat
131
131
 
132
132
  @preprocess_x
133
- def score(self, x: ArrayLike) -> Tuple[float, float, float]:
133
+ def score(self, x: ArrayLike) -> tuple[float, float, float]:
134
134
  """
135
135
  Compute the p-value resulting from a permutation test using the maximum mean
136
136
  discrepancy as a distance measure between the reference data and the data to
@@ -143,8 +143,9 @@ class DriftMMD(BaseDrift):
143
143
 
144
144
  Returns
145
145
  -------
146
- p-value obtained from the permutation test, the MMD^2 between the reference and
147
- test set, and the MMD^2 threshold above which drift is flagged.
146
+ tuple[float, float, float]
147
+ p-value obtained from the permutation test, MMD^2 between the reference and test set,
148
+ and MMD^2 threshold above which drift is flagged
148
149
  """
149
150
  x = to_numpy(x)
150
151
  x_ref = torch.from_numpy(self.x_ref).to(self.device)
@@ -6,8 +6,10 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from functools import partial
10
- from typing import Callable, Optional, Type, Union
12
+ from typing import Callable
11
13
 
12
14
  import numpy as np
13
15
  import torch
@@ -15,15 +17,15 @@ import torch.nn as nn
15
17
  from numpy.typing import NDArray
16
18
 
17
19
 
18
- def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
20
+ def get_device(device: str | torch.device | None = None) -> torch.device:
19
21
  """
20
22
  Instantiates a PyTorch device object.
21
23
 
22
24
  Parameters
23
25
  ----------
24
- device
25
- Either `None`, a str ('gpu' or 'cpu') indicating the device to choose, or an
26
- already instantiated device object. If `None`, the GPU is selected if it is
26
+ device : str | torch.device | None, default None
27
+ Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
28
+ already instantiated device object. If ``None``, the GPU is selected if it is
27
29
  detected, otherwise the CPU is used as a fallback.
28
30
 
29
31
  Returns
@@ -49,18 +51,19 @@ def mmd2_from_kernel_matrix(
49
51
 
50
52
  Parameters
51
53
  ----------
52
- kernel_mat
54
+ kernel_mat : torch.Tensor
53
55
  Kernel matrix between samples x and y.
54
- m
56
+ m : int
55
57
  Number of instances in y.
56
- permute
58
+ permute : bool, default False
57
59
  Whether to permute the row indices. Used for permutation tests.
58
- zero_diag
60
+ zero_diag : bool, default True
59
61
  Whether to zero out the diagonal of the kernel matrix.
60
62
 
61
63
  Returns
62
64
  -------
63
- MMD^2 between the samples from the kernel matrix.
65
+ torch.Tensor
66
+ MMD^2 between the samples from the kernel matrix.
64
67
  """
65
68
  n = kernel_mat.shape[0] - m
66
69
  if zero_diag:
@@ -75,35 +78,36 @@ def mmd2_from_kernel_matrix(
75
78
 
76
79
 
77
80
  def predict_batch(
78
- x: Union[NDArray, torch.Tensor],
79
- model: Union[Callable, nn.Module, nn.Sequential],
80
- device: Optional[torch.device] = None,
81
+ x: NDArray | torch.Tensor,
82
+ model: Callable | nn.Module | nn.Sequential,
83
+ device: torch.device | None = None,
81
84
  batch_size: int = int(1e10),
82
- preprocess_fn: Optional[Callable] = None,
83
- dtype: Union[Type[np.generic], torch.dtype] = np.float32,
84
- ) -> Union[NDArray, torch.Tensor, tuple]:
85
+ preprocess_fn: Callable | None = None,
86
+ dtype: type[np.generic] | torch.dtype = np.float32,
87
+ ) -> NDArray | torch.Tensor | tuple:
85
88
  """
86
89
  Make batch predictions on a model.
87
90
 
88
91
  Parameters
89
92
  ----------
90
- x
93
+ x : np.ndarray | torch.Tensor
91
94
  Batch of instances.
92
- model
95
+ model : Callable | nn.Module | nn.Sequential
93
96
  PyTorch model.
94
- device
97
+ device : torch.device | None, default None
95
98
  Device type used. The default None tries to use the GPU and falls back on CPU.
96
99
  Can be specified by passing either torch.device('cuda') or torch.device('cpu').
97
- batch_size
100
+ batch_size : int, default 1e10
98
101
  Batch size used during prediction.
99
- preprocess_fn
102
+ preprocess_fn : Callable | None, default None
100
103
  Optional preprocessing function for each batch.
101
- dtype
102
- Model output type, e.g. np.float32 or torch.float32.
104
+ dtype : np.dtype | torch.dtype, default np.float32
105
+ Model output type, either a numpy or torch dtype, e.g. np.float32 or torch.float32.
103
106
 
104
107
  Returns
105
108
  -------
106
- Numpy array, torch tensor or tuples of those with model outputs.
109
+ NDArray | torch.Tensor | tuple
110
+ Numpy array, torch tensor or tuples of those with model outputs.
107
111
  """
108
112
  device = get_device(device)
109
113
  if isinstance(x, np.ndarray):
@@ -143,7 +147,7 @@ def predict_batch(
143
147
  torch.Tensor."
144
148
  )
145
149
  concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
146
- out: Union[tuple, np.ndarray, torch.Tensor] = (
150
+ out: tuple | np.ndarray | torch.Tensor = (
147
151
  tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
148
152
  )
149
153
  return out
@@ -152,34 +156,35 @@ def predict_batch(
152
156
  def preprocess_drift(
153
157
  x: NDArray,
154
158
  model: nn.Module,
155
- device: Optional[torch.device] = None,
156
- preprocess_batch_fn: Optional[Callable] = None,
159
+ device: torch.device | None = None,
160
+ preprocess_batch_fn: Callable | None = None,
157
161
  batch_size: int = int(1e10),
158
- dtype: Union[Type[np.generic], torch.dtype] = np.float32,
159
- ) -> Union[NDArray, torch.Tensor, tuple]:
162
+ dtype: type[np.generic] | torch.dtype = np.float32,
163
+ ) -> NDArray | torch.Tensor | tuple:
160
164
  """
161
165
  Prediction function used for preprocessing step of drift detector.
162
166
 
163
167
  Parameters
164
168
  ----------
165
- x
169
+ x : NDArray
166
170
  Batch of instances.
167
- model
171
+ model : nn.Module
168
172
  Model used for preprocessing.
169
- device
173
+ device : torch.device | None, default None
170
174
  Device type used. The default None tries to use the GPU and falls back on CPU.
171
175
  Can be specified by passing either torch.device('cuda') or torch.device('cpu').
172
- preprocess_batch_fn
176
+ preprocess_batch_fn : Callable | None, default None
173
177
  Optional batch preprocessing function. For example to convert a list of objects
174
178
  to a batch which can be processed by the PyTorch model.
175
- batch_size
179
+ batch_size : int, default 1e10
176
180
  Batch size used during prediction.
177
- dtype
178
- Model output type, e.g. np.float32 or torch.float32.
181
+ dtype : np.dtype | torch.dtype, default np.float32
182
+ Model output type, either a numpy or torch dtype, e.g. np.float32 or torch.float32.
179
183
 
180
184
  Returns
181
185
  -------
182
- Numpy array or torch tensor with predictions.
186
+ NDArray | torch.Tensor | tuple
187
+ Numpy array, torch tensor or tuples of those with model outputs.
183
188
  """
184
189
  return predict_batch(
185
190
  x,
@@ -200,15 +205,17 @@ def squared_pairwise_distance(
200
205
 
201
206
  Parameters
202
207
  ----------
203
- x
208
+ x : torch.Tensor
204
209
  Batch of instances of shape [Nx, features].
205
- y
210
+ y : torch.Tensor
206
211
  Batch of instances of shape [Ny, features].
207
- a_min
212
+ a_min : float
208
213
  Lower bound to clip distance values.
214
+
209
215
  Returns
210
216
  -------
211
- Pairwise squared Euclidean distance [Nx, Ny].
217
+ torch.Tensor
218
+ Pairwise squared Euclidean distance [Nx, Ny].
212
219
  """
213
220
  x2 = x.pow(2).sum(dim=-1, keepdim=True)
214
221
  y2 = y.pow(2).sum(dim=-1, keepdim=True)
@@ -222,17 +229,18 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
222
229
 
223
230
  Parameters
224
231
  ----------
225
- x
232
+ x : torch.Tensor
226
233
  Tensor of instances with dimension [Nx, features].
227
- y
234
+ y : torch.Tensor
228
235
  Tensor of instances with dimension [Ny, features].
229
- dist
236
+ dist : torch.Tensor
230
237
  Tensor with dimensions [Nx, Ny], containing the pairwise distances
231
238
  between `x` and `y`.
232
239
 
233
240
  Returns
234
241
  -------
235
- The computed bandwidth, `sigma`.
242
+ torch.Tensor
243
+ The computed bandwidth, `sigma`.
236
244
  """
237
245
  n = min(x.shape[0], y.shape[0])
238
246
  n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
@@ -243,28 +251,28 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
243
251
 
244
252
  class GaussianRBF(nn.Module):
245
253
  """
246
- Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2). A forward pass
247
- takes a batch of instances x [Nx, features] and y [Ny, features] and returns
248
- the kernel matrix [Nx, Ny].
254
+ Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)) * ||x-y||^2).
255
+
256
+ A forward pass takes a batch of instances x [Nx, features] and
257
+ y [Ny, features] and returns the kernel matrix [Nx, Ny].
249
258
 
250
259
  Parameters
251
260
  ----------
252
- sigma : Optional[torch.Tensor], default None
261
+ sigma : torch.Tensor | None, default None
253
262
  Bandwidth used for the kernel. Needn't be specified if being inferred or
254
263
  trained. Can pass multiple values to eval kernel with and then average.
255
- init_sigma_fn : Optional[Callable], default None
256
- Function used to compute the bandwidth `sigma`. Used when `sigma` is to be
257
- inferred. The function's signature should take in the tensors `x`, `y` and
258
- `dist` and return `sigma`. If `None`, it is set to
259
- :func:`~dataeval._internal.detectors.drift.torch.sigma_median`.
264
+ init_sigma_fn : Callable | None, default None
265
+ Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
266
+ inferred. The function's signature should take in the tensors ``x``, ``y`` and
267
+ ``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
260
268
  trainable : bool, default False
261
269
  Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
262
270
  """
263
271
 
264
272
  def __init__(
265
273
  self,
266
- sigma: Optional[torch.Tensor] = None,
267
- init_sigma_fn: Optional[Callable] = None,
274
+ sigma: torch.Tensor | None = None,
275
+ init_sigma_fn: Callable | None = None,
268
276
  trainable: bool = False,
269
277
  ) -> None:
270
278
  super().__init__()
@@ -290,8 +298,8 @@ class GaussianRBF(nn.Module):
290
298
 
291
299
  def forward(
292
300
  self,
293
- x: Union[np.ndarray, torch.Tensor],
294
- y: Union[np.ndarray, torch.Tensor],
301
+ x: np.ndarray | torch.Tensor,
302
+ y: np.ndarray | torch.Tensor,
295
303
  infer_sigma: bool = False,
296
304
  ) -> torch.Tensor:
297
305
  x, y = torch.as_tensor(x), torch.as_tensor(y)
@@ -6,15 +6,17 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from functools import partial
10
- from typing import Callable, Literal, Optional
12
+ from typing import Callable, Literal
11
13
 
12
14
  import numpy as np
13
15
  from numpy.typing import ArrayLike, NDArray
14
16
  from scipy.special import softmax
15
17
  from scipy.stats import entropy
16
18
 
17
- from .base import DriftUnivariateOutput, UpdateStrategy
19
+ from .base import DriftOutput, UpdateStrategy
18
20
  from .ks import DriftKS
19
21
  from .torch import get_device, preprocess_drift
20
22
 
@@ -29,18 +31,19 @@ def classifier_uncertainty(
29
31
 
30
32
  Parameters
31
33
  ----------
32
- x
34
+ x : np.ndarray
33
35
  Batch of instances.
34
- model_fn
36
+ model_fn : Callable
35
37
  Function that evaluates a classification model on x in a single call (contains
36
38
  batching logic if necessary).
37
- preds_type
39
+ preds_type : "probs" | "logits", default "probs"
38
40
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
39
41
  'logits' (in [-inf,inf]).
40
42
 
41
43
  Returns
42
44
  -------
43
- A scalar indication of uncertainty of the model on each instance in x.
45
+ NDArray
46
+ A scalar indication of uncertainty of the model on each instance in x.
44
47
  """
45
48
 
46
49
  preds = model_fn(x)
@@ -61,42 +64,38 @@ def classifier_uncertainty(
61
64
  class DriftUncertainty:
62
65
  """
63
66
  Test for a change in the number of instances falling into regions on which the
64
- model is uncertain. Performs a K-S test on prediction entropies.
67
+ model is uncertain.
68
+
69
+ Performs a K-S test on prediction entropies.
65
70
 
66
71
  Parameters
67
72
  ----------
68
73
  x_ref : ArrayLike
69
- Data used as reference distribution. Should be disjoint from the data the
70
- model was trained on for accurate p-values.
74
+ Data used as reference distribution.
71
75
  model : Callable
72
76
  Classification model outputting class probabilities (or logits)
73
77
  p_val : float, default 0.05
74
78
  p-value used for the significance of the test.
75
79
  x_ref_preprocessed : bool, default False
76
- Whether the given reference data `x_ref` has been preprocessed yet. If
77
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
78
- prediction time. If `x_ref_preprocessed=False`, the reference data will
79
- also be preprocessed.
80
- update_x_ref : Optional[UpdateStrategy], default None
80
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
81
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
82
+ If ``False``, the reference data will also be preprocessed.
83
+ update_x_ref : UpdateStrategy | None, default None
81
84
  Reference data can optionally be updated using an UpdateStrategy class. Update
82
- using the last n instances seen by the detector with
83
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
84
- or via reservoir sampling with
85
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
86
- preds_type : Literal["probs", "logits"], default "logits"
85
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
86
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
87
+ preds_type : "probs" | "logits", default "logits"
87
88
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
88
89
  'logits' (in [-inf,inf]).
89
90
  batch_size : int, default 32
90
91
  Batch size used to evaluate model. Only relevant when backend has been
91
92
  specified for batch prediction.
92
- preprocess_batch_fn : Optional[Callable], default None
93
+ preprocess_batch_fn : Callable | None, default None
93
94
  Optional batch preprocessing function. For example to convert a list of
94
95
  objects to a batch which can be processed by the model.
95
- device : Optional[str], default None
96
+ device : str | None, default None
96
97
  Device type used. The default None tries to use the GPU and falls back on
97
98
  CPU if needed. Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
98
- input_shape : Optional[tuple], default None
99
- Shape of input data.
100
99
  """
101
100
 
102
101
  def __init__(
@@ -105,11 +104,11 @@ class DriftUncertainty:
105
104
  model: Callable,
106
105
  p_val: float = 0.05,
107
106
  x_ref_preprocessed: bool = False,
108
- update_x_ref: Optional[UpdateStrategy] = None,
107
+ update_x_ref: UpdateStrategy | None = None,
109
108
  preds_type: Literal["probs", "logits"] = "probs",
110
109
  batch_size: int = 32,
111
- preprocess_batch_fn: Optional[Callable] = None,
112
- device: Optional[str] = None,
110
+ preprocess_batch_fn: Callable | None = None,
111
+ device: str | None = None,
113
112
  ) -> None:
114
113
  def model_fn(x: NDArray) -> NDArray:
115
114
  return preprocess_drift(
@@ -134,7 +133,7 @@ class DriftUncertainty:
134
133
  preprocess_fn=preprocess_fn, # type: ignore
135
134
  )
136
135
 
137
- def predict(self, x: ArrayLike) -> DriftUnivariateOutput:
136
+ def predict(self, x: ArrayLike) -> DriftOutput:
138
137
  """
139
138
  Predict whether a batch of data has drifted from the reference data.
140
139
 
@@ -145,6 +144,7 @@ class DriftUncertainty:
145
144
 
146
145
  Returns
147
146
  -------
148
- Dictionary containing the drift prediction, p-value, and threshold statistics.
147
+ DriftOutput
148
+ Dictionary containing the drift prediction, p-value, and threshold statistics.
149
149
  """
150
150
  return self._detector.predict(x)