dataeval 0.64.0__py3-none-any.whl → 0.66.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. dataeval/__init__.py +13 -9
  2. dataeval/_internal/detectors/clusterer.py +63 -49
  3. dataeval/_internal/detectors/drift/base.py +248 -51
  4. dataeval/_internal/detectors/drift/cvm.py +28 -26
  5. dataeval/_internal/detectors/drift/ks.py +31 -28
  6. dataeval/_internal/detectors/drift/mmd.py +62 -42
  7. dataeval/_internal/detectors/drift/torch.py +69 -60
  8. dataeval/_internal/detectors/drift/uncertainty.py +32 -32
  9. dataeval/_internal/detectors/duplicates.py +67 -31
  10. dataeval/_internal/detectors/ood/ae.py +15 -29
  11. dataeval/_internal/detectors/ood/aegmm.py +33 -27
  12. dataeval/_internal/detectors/ood/base.py +86 -47
  13. dataeval/_internal/detectors/ood/llr.py +34 -31
  14. dataeval/_internal/detectors/ood/vae.py +32 -31
  15. dataeval/_internal/detectors/ood/vaegmm.py +34 -28
  16. dataeval/_internal/detectors/{linter.py → outliers.py} +60 -38
  17. dataeval/_internal/flags.py +44 -21
  18. dataeval/_internal/interop.py +5 -3
  19. dataeval/_internal/metrics/balance.py +42 -5
  20. dataeval/_internal/metrics/ber.py +11 -8
  21. dataeval/_internal/metrics/coverage.py +15 -8
  22. dataeval/_internal/metrics/divergence.py +41 -7
  23. dataeval/_internal/metrics/diversity.py +57 -19
  24. dataeval/_internal/metrics/parity.py +141 -66
  25. dataeval/_internal/metrics/stats.py +330 -313
  26. dataeval/_internal/metrics/uap.py +33 -4
  27. dataeval/_internal/metrics/utils.py +79 -40
  28. dataeval/_internal/models/pytorch/autoencoder.py +127 -22
  29. dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
  30. dataeval/_internal/models/tensorflow/gmm.py +4 -2
  31. dataeval/_internal/models/tensorflow/losses.py +17 -13
  32. dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
  33. dataeval/_internal/models/tensorflow/trainer.py +10 -7
  34. dataeval/_internal/models/tensorflow/utils.py +23 -20
  35. dataeval/_internal/output.py +85 -0
  36. dataeval/_internal/utils.py +5 -3
  37. dataeval/_internal/workflows/sufficiency.py +122 -121
  38. dataeval/detectors/__init__.py +6 -25
  39. dataeval/detectors/drift/__init__.py +16 -0
  40. dataeval/detectors/drift/kernels/__init__.py +6 -0
  41. dataeval/detectors/drift/updates/__init__.py +3 -0
  42. dataeval/detectors/linters/__init__.py +5 -0
  43. dataeval/detectors/ood/__init__.py +11 -0
  44. dataeval/flags/__init__.py +2 -2
  45. dataeval/metrics/__init__.py +2 -26
  46. dataeval/metrics/bias/__init__.py +14 -0
  47. dataeval/metrics/estimators/__init__.py +9 -0
  48. dataeval/metrics/stats/__init__.py +6 -0
  49. dataeval/tensorflow/__init__.py +3 -0
  50. dataeval/tensorflow/loss/__init__.py +3 -0
  51. dataeval/tensorflow/models/__init__.py +5 -0
  52. dataeval/tensorflow/recon/__init__.py +3 -0
  53. dataeval/torch/__init__.py +3 -0
  54. dataeval/{models/torch → torch/models}/__init__.py +1 -2
  55. dataeval/torch/trainer/__init__.py +3 -0
  56. dataeval/utils/__init__.py +3 -6
  57. dataeval/workflows/__init__.py +2 -4
  58. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/METADATA +1 -1
  59. dataeval-0.66.0.dist-info/RECORD +72 -0
  60. dataeval/_internal/metrics/base.py +0 -10
  61. dataeval/models/__init__.py +0 -15
  62. dataeval/models/tensorflow/__init__.py +0 -6
  63. dataeval-0.64.0.dist-info/RECORD +0 -60
  64. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/LICENSE.txt +0 -0
  65. {dataeval-0.64.0.dist-info → dataeval-0.66.0.dist-info}/WHEEL +0 -0
@@ -6,17 +6,47 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
- from typing import Callable, Dict, Optional, Tuple, Union
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Callable
10
13
 
11
14
  import torch
12
15
  from numpy.typing import ArrayLike
13
16
 
14
17
  from dataeval._internal.interop import to_numpy
18
+ from dataeval._internal.output import set_metadata
15
19
 
16
- from .base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
20
+ from .base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
17
21
  from .torch import GaussianRBF, get_device, mmd2_from_kernel_matrix
18
22
 
19
23
 
24
+ @dataclass(frozen=True)
25
+ class DriftMMDOutput(DriftBaseOutput):
26
+ """
27
+ Output class for DriftMMD
28
+
29
+ Attributes
30
+ ----------
31
+ is_drift : bool
32
+ Drift prediction for the images
33
+ threshold : float
34
+ P-value used for significance of the permutation test
35
+ p_val : float
36
+ P-value obtained from the permutation test
37
+ distance : float
38
+ MMD^2 between the reference and test set
39
+ distance_threshold : float
40
+ MMD^2 threshold above which drift is flagged
41
+ """
42
+
43
+ # is_drift: bool
44
+ # threshold: float
45
+ p_val: float
46
+ distance: float
47
+ distance_threshold: float
48
+
49
+
20
50
  class DriftMMD(BaseDrift):
21
51
  """
22
52
  Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.
@@ -25,28 +55,24 @@ class DriftMMD(BaseDrift):
25
55
  ----------
26
56
  x_ref : ArrayLike
27
57
  Data used as reference distribution.
28
- p_val : float, default 0.05
29
- p-value used for the significance of the permutation test.
58
+ p_val : float, default 0.05
59
+ p-value used for significance of the statistical test for each feature.
60
+ If the FDR correction method is used, this corresponds to the acceptable
61
+ q-value.
30
62
  x_ref_preprocessed : bool, default False
31
- Whether the given reference data `x_ref` has been preprocessed yet. If
32
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed
33
- at prediction time. If `x_ref_preprocessed=False`, the reference data
34
- will also be preprocessed.
35
- preprocess_at_init : bool, default True
36
- Whether to preprocess the reference data when the detector is instantiated.
37
- Otherwise, the reference data will be preprocessed at prediction time. Only
38
- applies if `x_ref_preprocessed=False`.
39
- update_x_ref : Optional[UpdateStrategy], default None
63
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
64
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
65
+ If ``False``, the reference data will also be preprocessed.
66
+ update_x_ref : UpdateStrategy | None, default None
40
67
  Reference data can optionally be updated using an UpdateStrategy class. Update
41
- using the last n instances seen by the detector with
42
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
43
- or via reservoir sampling with
44
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
45
- preprocess_fn : Optional[Callable], default None
68
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
69
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
70
+ preprocess_fn : Callable | None, default None
46
71
  Function to preprocess the data before computing the data drift metrics.
47
- kernel : Callable, default :py:class:`dataeval.detectors.GaussianRBF`
72
+ Typically a dimensionality reduction technique.
73
+ kernel : Callable, default GaussianRBF
48
74
  Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
49
- sigma : Optional[ArrayLike], default None
75
+ sigma : ArrayLike | None, default None
50
76
  Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple
51
77
  bandwidth values as an array. The kernel evaluation is then averaged over
52
78
  those bandwidths.
@@ -54,7 +80,7 @@ class DriftMMD(BaseDrift):
54
80
  Whether to already configure the kernel bandwidth from the reference data.
55
81
  n_permutations : int, default 100
56
82
  Number of permutations used in the permutation test.
57
- device : Optional[str], default None
83
+ device : str | None, default None
58
84
  Device type used. The default None uses the GPU and falls back on CPU.
59
85
  Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
60
86
  """
@@ -64,13 +90,13 @@ class DriftMMD(BaseDrift):
64
90
  x_ref: ArrayLike,
65
91
  p_val: float = 0.05,
66
92
  x_ref_preprocessed: bool = False,
67
- update_x_ref: Optional[UpdateStrategy] = None,
68
- preprocess_fn: Optional[Callable[[ArrayLike], ArrayLike]] = None,
93
+ update_x_ref: UpdateStrategy | None = None,
94
+ preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
69
95
  kernel: Callable = GaussianRBF,
70
- sigma: Optional[ArrayLike] = None,
96
+ sigma: ArrayLike | None = None,
71
97
  configure_kernel_from_x_ref: bool = True,
72
98
  n_permutations: int = 100,
73
- device: Optional[str] = None,
99
+ device: str | None = None,
74
100
  ) -> None:
75
101
  super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
76
102
 
@@ -104,7 +130,7 @@ class DriftMMD(BaseDrift):
104
130
  return kernel_mat
105
131
 
106
132
  @preprocess_x
107
- def score(self, x: ArrayLike) -> Tuple[float, float, float]:
133
+ def score(self, x: ArrayLike) -> tuple[float, float, float]:
108
134
  """
109
135
  Compute the p-value resulting from a permutation test using the maximum mean
110
136
  discrepancy as a distance measure between the reference data and the data to
@@ -117,8 +143,9 @@ class DriftMMD(BaseDrift):
117
143
 
118
144
  Returns
119
145
  -------
120
- p-value obtained from the permutation test, the MMD^2 between the reference and
121
- test set, and the MMD^2 threshold above which drift is flagged.
146
+ tuple(float, float, float)
147
+ p-value obtained from the permutation test, MMD^2 between the reference and test set,
148
+ and MMD^2 threshold above which drift is flagged
122
149
  """
123
150
  x = to_numpy(x)
124
151
  x_ref = torch.from_numpy(self.x_ref).to(self.device)
@@ -129,19 +156,17 @@ class DriftMMD(BaseDrift):
129
156
  mmd2_permuted = torch.Tensor(
130
157
  [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
131
158
  )
132
- mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
159
+ mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
133
160
  p_val = (mmd2 <= mmd2_permuted).float().mean()
134
161
  # compute distance threshold
135
162
  idx_threshold = int(self.p_val * len(mmd2_permuted))
136
163
  distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
137
164
  return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
138
165
 
166
+ @set_metadata("dataeval.detectors")
139
167
  @preprocess_x
140
168
  @update_x_ref
141
- def predict(
142
- self,
143
- x: ArrayLike,
144
- ) -> Dict[str, Union[int, float]]:
169
+ def predict(self, x: ArrayLike) -> DriftMMDOutput:
145
170
  """
146
171
  Predict whether a batch of data has drifted from the reference data and then
147
172
  updates reference data using specified strategy.
@@ -153,17 +178,12 @@ class DriftMMD(BaseDrift):
153
178
 
154
179
  Returns
155
180
  -------
156
- Dictionary containing the drift prediction, p-value, threshold and MMD metric.
181
+ DriftMMDOutput
182
+ Output class containing the drift prediction, p-value, threshold and MMD metric.
157
183
  """
158
184
  # compute drift scores
159
185
  p_val, dist, distance_threshold = self.score(x)
160
- drift_pred = int(p_val < self.p_val)
186
+ drift_pred = bool(p_val < self.p_val)
161
187
 
162
188
  # populate drift dict
163
- return {
164
- "is_drift": drift_pred,
165
- "p_val": p_val,
166
- "threshold": self.p_val,
167
- "distance": dist,
168
- "distance_threshold": distance_threshold,
169
- }
189
+ return DriftMMDOutput(drift_pred, self.p_val, p_val, dist, distance_threshold)
@@ -6,23 +6,26 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from functools import partial
10
- from typing import Callable, Optional, Type, Union
12
+ from typing import Callable
11
13
 
12
14
  import numpy as np
13
15
  import torch
14
16
  import torch.nn as nn
17
+ from numpy.typing import NDArray
15
18
 
16
19
 
17
- def get_device(device: Optional[Union[str, torch.device]] = None) -> torch.device:
20
+ def get_device(device: str | torch.device | None = None) -> torch.device:
18
21
  """
19
22
  Instantiates a PyTorch device object.
20
23
 
21
24
  Parameters
22
25
  ----------
23
- device
24
- Either `None`, a str ('gpu' or 'cpu') indicating the device to choose, or an
25
- already instantiated device object. If `None`, the GPU is selected if it is
26
+ device : str | torch.device | None, default None
27
+ Either ``None``, a str ('gpu' or 'cpu') indicating the device to choose, or an
28
+ already instantiated device object. If ``None``, the GPU is selected if it is
26
29
  detected, otherwise the CPU is used as a fallback.
27
30
 
28
31
  Returns
@@ -48,18 +51,19 @@ def mmd2_from_kernel_matrix(
48
51
 
49
52
  Parameters
50
53
  ----------
51
- kernel_mat
54
+ kernel_mat : torch.Tensor
52
55
  Kernel matrix between samples x and y.
53
- m
56
+ m : int
54
57
  Number of instances in y.
55
- permute
58
+ permute : bool, default False
56
59
  Whether to permute the row indices. Used for permutation tests.
57
- zero_diag
60
+ zero_diag : bool, default True
58
61
  Whether to zero out the diagonal of the kernel matrix.
59
62
 
60
63
  Returns
61
64
  -------
62
- MMD^2 between the samples from the kernel matrix.
65
+ torch.Tensor
66
+ MMD^2 between the samples from the kernel matrix.
63
67
  """
64
68
  n = kernel_mat.shape[0] - m
65
69
  if zero_diag:
@@ -74,35 +78,36 @@ def mmd2_from_kernel_matrix(
74
78
 
75
79
 
76
80
  def predict_batch(
77
- x: Union[np.ndarray, torch.Tensor],
78
- model: Union[Callable, nn.Module, nn.Sequential],
79
- device: Optional[torch.device] = None,
81
+ x: NDArray | torch.Tensor,
82
+ model: Callable | nn.Module | nn.Sequential,
83
+ device: torch.device | None = None,
80
84
  batch_size: int = int(1e10),
81
- preprocess_fn: Optional[Callable] = None,
82
- dtype: Union[Type[np.generic], torch.dtype] = np.float32,
83
- ) -> Union[np.ndarray, torch.Tensor, tuple]:
85
+ preprocess_fn: Callable | None = None,
86
+ dtype: type[np.generic] | torch.dtype = np.float32,
87
+ ) -> NDArray | torch.Tensor | tuple:
84
88
  """
85
89
  Make batch predictions on a model.
86
90
 
87
91
  Parameters
88
92
  ----------
89
- x
93
+ x : np.ndarray | torch.Tensor
90
94
  Batch of instances.
91
- model
95
+ model : Callable | nn.Module | nn.Sequential
92
96
  PyTorch model.
93
- device
97
+ device : torch.device | None, default None
94
98
  Device type used. The default None tries to use the GPU and falls back on CPU.
95
99
  Can be specified by passing either torch.device('cuda') or torch.device('cpu').
96
- batch_size
100
+ batch_size : int, default 1e10
97
101
  Batch size used during prediction.
98
- preprocess_fn
102
+ preprocess_fn : Callable | None, default None
99
103
  Optional preprocessing function for each batch.
100
- dtype
101
- Model output type, e.g. np.float32 or torch.float32.
104
+ dtype : np.dtype | torch.dtype, default np.float32
105
+ Model output type, either a numpy or torch dtype, e.g. np.float32 or torch.float32.
102
106
 
103
107
  Returns
104
108
  -------
105
- Numpy array, torch tensor or tuples of those with model outputs.
109
+ NDArray | torch.Tensor | tuple
110
+ Numpy array, torch tensor or tuples of those with model outputs.
106
111
  """
107
112
  device = get_device(device)
108
113
  if isinstance(x, np.ndarray):
@@ -138,47 +143,48 @@ def predict_batch(
138
143
  else:
139
144
  raise TypeError(
140
145
  f"Model output type {type(preds_tmp)} not supported. The model \
141
- output type needs to be one of list, tuple, np.ndarray or \
146
+ output type needs to be one of list, tuple, NDArray or \
142
147
  torch.Tensor."
143
148
  )
144
149
  concat = partial(np.concatenate, axis=0) if return_np else partial(torch.cat, dim=0)
145
- out: Union[tuple, np.ndarray, torch.Tensor] = (
150
+ out: tuple | np.ndarray | torch.Tensor = (
146
151
  tuple(concat(p) for p in preds) if isinstance(preds, tuple) else concat(preds) # type: ignore
147
152
  )
148
153
  return out
149
154
 
150
155
 
151
156
  def preprocess_drift(
152
- x: np.ndarray,
157
+ x: NDArray,
153
158
  model: nn.Module,
154
- device: Optional[torch.device] = None,
155
- preprocess_batch_fn: Optional[Callable] = None,
159
+ device: torch.device | None = None,
160
+ preprocess_batch_fn: Callable | None = None,
156
161
  batch_size: int = int(1e10),
157
- dtype: Union[Type[np.generic], torch.dtype] = np.float32,
158
- ) -> Union[np.ndarray, torch.Tensor, tuple]:
162
+ dtype: type[np.generic] | torch.dtype = np.float32,
163
+ ) -> NDArray | torch.Tensor | tuple:
159
164
  """
160
165
  Prediction function used for preprocessing step of drift detector.
161
166
 
162
167
  Parameters
163
168
  ----------
164
- x
169
+ x : NDArray
165
170
  Batch of instances.
166
- model
171
+ model : nn.Module
167
172
  Model used for preprocessing.
168
- device
173
+ device : torch.device | None, default None
169
174
  Device type used. The default None tries to use the GPU and falls back on CPU.
170
175
  Can be specified by passing either torch.device('cuda') or torch.device('cpu').
171
- preprocess_batch_fn
176
+ preprocess_batch_fn : Callable | None, default None
172
177
  Optional batch preprocessing function. For example to convert a list of objects
173
178
  to a batch which can be processed by the PyTorch model.
174
- batch_size
179
+ batch_size : int, default 1e10
175
180
  Batch size used during prediction.
176
- dtype
177
- Model output type, e.g. np.float32 or torch.float32.
181
+ dtype : np.dtype | torch.dtype, default np.float32
182
+ Model output type, either a numpy or torch dtype, e.g. np.float32 or torch.float32.
178
183
 
179
184
  Returns
180
185
  -------
181
- Numpy array or torch tensor with predictions.
186
+ NDArray | torch.Tensor | tuple
187
+ Numpy array, torch tensor or tuples of those with model outputs.
182
188
  """
183
189
  return predict_batch(
184
190
  x,
@@ -199,15 +205,17 @@ def squared_pairwise_distance(
199
205
 
200
206
  Parameters
201
207
  ----------
202
- x
208
+ x : torch.Tensor
203
209
  Batch of instances of shape [Nx, features].
204
- y
210
+ y : torch.Tensor
205
211
  Batch of instances of shape [Ny, features].
206
- a_min
212
+ a_min : float
207
213
  Lower bound to clip distance values.
214
+
208
215
  Returns
209
216
  -------
210
- Pairwise squared Euclidean distance [Nx, Ny].
217
+ torch.Tensor
218
+ Pairwise squared Euclidean distance [Nx, Ny].
211
219
  """
212
220
  x2 = x.pow(2).sum(dim=-1, keepdim=True)
213
221
  y2 = y.pow(2).sum(dim=-1, keepdim=True)
@@ -221,17 +229,18 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
221
229
 
222
230
  Parameters
223
231
  ----------
224
- x
232
+ x : torch.Tensor
225
233
  Tensor of instances with dimension [Nx, features].
226
- y
234
+ y : torch.Tensor
227
235
  Tensor of instances with dimension [Ny, features].
228
- dist
236
+ dist : torch.Tensor
229
237
  Tensor with dimensions [Nx, Ny], containing the pairwise distances
230
238
  between `x` and `y`.
231
239
 
232
240
  Returns
233
241
  -------
234
- The computed bandwidth, `sigma`.
242
+ torch.Tensor
243
+ The computed bandwidth, `sigma`.
235
244
  """
236
245
  n = min(x.shape[0], y.shape[0])
237
246
  n = n if (x[:n] == y[:n]).all() and x.shape == y.shape else 0
@@ -242,28 +251,28 @@ def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.
242
251
 
243
252
  class GaussianRBF(nn.Module):
244
253
  """
245
- Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)||x-y||^2). A forward pass
246
- takes a batch of instances x [Nx, features] and y [Ny, features] and returns
247
- the kernel matrix [Nx, Ny].
254
+ Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)) * ||x-y||^2).
255
+
256
+ A forward pass takes a batch of instances x [Nx, features] and
257
+ y [Ny, features] and returns the kernel matrix [Nx, Ny].
248
258
 
249
259
  Parameters
250
260
  ----------
251
- sigma : Optional[torch.Tensor], default None
261
+ sigma : torch.Tensor | None, default None
252
262
  Bandwidth used for the kernel. Needn't be specified if being inferred or
253
263
  trained. Can pass multiple values to eval kernel with and then average.
254
- init_sigma_fn : Optional[Callable], default None
255
- Function used to compute the bandwidth `sigma`. Used when `sigma` is to be
256
- inferred. The function's signature should take in the tensors `x`, `y` and
257
- `dist` and return `sigma`. If `None`, it is set to
258
- :func:`~dataeval._internal.detectors.drift.torch.sigma_median`.
264
+ init_sigma_fn : Callable | None, default None
265
+ Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
266
+ inferred. The function's signature should take in the tensors ``x``, ``y`` and
267
+ ``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
259
268
  trainable : bool, default False
260
269
  Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
261
270
  """
262
271
 
263
272
  def __init__(
264
273
  self,
265
- sigma: Optional[torch.Tensor] = None,
266
- init_sigma_fn: Optional[Callable] = None,
274
+ sigma: torch.Tensor | None = None,
275
+ init_sigma_fn: Callable | None = None,
267
276
  trainable: bool = False,
268
277
  ) -> None:
269
278
  super().__init__()
@@ -289,8 +298,8 @@ class GaussianRBF(nn.Module):
289
298
 
290
299
  def forward(
291
300
  self,
292
- x: Union[np.ndarray, torch.Tensor],
293
- y: Union[np.ndarray, torch.Tensor],
301
+ x: np.ndarray | torch.Tensor,
302
+ y: np.ndarray | torch.Tensor,
294
303
  infer_sigma: bool = False,
295
304
  ) -> torch.Tensor:
296
305
  x, y = torch.as_tensor(x), torch.as_tensor(y)
@@ -6,41 +6,44 @@ Original code Copyright (c) 2023 Seldon Technologies Ltd
6
6
  Licensed under Apache Software License (Apache 2.0)
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from functools import partial
10
- from typing import Callable, Dict, Literal, Optional, Union
12
+ from typing import Callable, Literal
11
13
 
12
14
  import numpy as np
13
- from numpy.typing import ArrayLike
15
+ from numpy.typing import ArrayLike, NDArray
14
16
  from scipy.special import softmax
15
17
  from scipy.stats import entropy
16
18
 
17
- from .base import UpdateStrategy
19
+ from .base import DriftOutput, UpdateStrategy
18
20
  from .ks import DriftKS
19
21
  from .torch import get_device, preprocess_drift
20
22
 
21
23
 
22
24
  def classifier_uncertainty(
23
- x: np.ndarray,
25
+ x: NDArray,
24
26
  model_fn: Callable,
25
27
  preds_type: Literal["probs", "logits"] = "probs",
26
- ) -> np.ndarray:
28
+ ) -> NDArray:
27
29
  """
28
30
  Evaluate model_fn on x and transform predictions to prediction uncertainties.
29
31
 
30
32
  Parameters
31
33
  ----------
32
- x
34
+ x : np.ndarray
33
35
  Batch of instances.
34
- model_fn
36
+ model_fn : Callable
35
37
  Function that evaluates a classification model on x in a single call (contains
36
38
  batching logic if necessary).
37
- preds_type
39
+ preds_type : "probs" | "logits", default "probs"
38
40
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
39
41
  'logits' (in [-inf,inf]).
40
42
 
41
43
  Returns
42
44
  -------
43
- A scalar indication of uncertainty of the model on each instance in x.
45
+ NDArray
46
+ A scalar indication of uncertainty of the model on each instance in x.
44
47
  """
45
48
 
46
49
  preds = model_fn(x)
@@ -61,42 +64,38 @@ def classifier_uncertainty(
61
64
  class DriftUncertainty:
62
65
  """
63
66
  Test for a change in the number of instances falling into regions on which the
64
- model is uncertain. Performs a K-S test on prediction entropies.
67
+ model is uncertain.
68
+
69
+ Performs a K-S test on prediction entropies.
65
70
 
66
71
  Parameters
67
72
  ----------
68
73
  x_ref : ArrayLike
69
- Data used as reference distribution. Should be disjoint from the data the
70
- model was trained on for accurate p-values.
74
+ Data used as reference distribution.
71
75
  model : Callable
72
76
  Classification model outputting class probabilities (or logits)
73
77
  p_val : float, default 0.05
74
78
  p-value used for the significance of the test.
75
79
  x_ref_preprocessed : bool, default False
76
- Whether the given reference data `x_ref` has been preprocessed yet. If
77
- `x_ref_preprocessed=True`, only the test data `x` will be preprocessed at
78
- prediction time. If `x_ref_preprocessed=False`, the reference data will
79
- also be preprocessed.
80
- update_x_ref : Optional[UpdateStrategy], default None
80
+ Whether the given reference data ``x_ref`` has been preprocessed yet.
81
+ If ``True``, only the test data ``x`` will be preprocessed at prediction time.
82
+ If ``False``, the reference data will also be preprocessed.
83
+ update_x_ref : UpdateStrategy | None, default None
81
84
  Reference data can optionally be updated using an UpdateStrategy class. Update
82
- using the last n instances seen by the detector with
83
- :py:class:`dataeval.detectors.LastSeenUpdateStrategy`
84
- or via reservoir sampling with
85
- :py:class:`dataeval.detectors.ReservoirSamplingUpdateStrategy`.
86
- preds_type : Literal["probs", "logits"], default "logits"
85
+ using the last n instances seen by the detector with LastSeenUpdateStrategy
86
+ or via reservoir sampling with ReservoirSamplingUpdateStrategy.
87
+ preds_type : "probs" | "logits", default "probs"
87
88
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
88
89
  'logits' (in [-inf,inf]).
89
90
  batch_size : int, default 32
90
91
  Batch size used to evaluate model. Only relevant when backend has been
91
92
  specified for batch prediction.
92
- preprocess_batch_fn : Optional[Callable], default None
93
+ preprocess_batch_fn : Callable | None, default None
93
94
  Optional batch preprocessing function. For example to convert a list of
94
95
  objects to a batch which can be processed by the model.
95
- device : Optional[str], default None
96
+ device : str | None, default None
96
97
  Device type used. The default None tries to use the GPU and falls back on
97
98
  CPU if needed. Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
98
- input_shape : Optional[tuple], default None
99
- Shape of input data.
100
99
  """
101
100
 
102
101
  def __init__(
@@ -105,13 +104,13 @@ class DriftUncertainty:
105
104
  model: Callable,
106
105
  p_val: float = 0.05,
107
106
  x_ref_preprocessed: bool = False,
108
- update_x_ref: Optional[UpdateStrategy] = None,
107
+ update_x_ref: UpdateStrategy | None = None,
109
108
  preds_type: Literal["probs", "logits"] = "probs",
110
109
  batch_size: int = 32,
111
- preprocess_batch_fn: Optional[Callable] = None,
112
- device: Optional[str] = None,
110
+ preprocess_batch_fn: Callable | None = None,
111
+ device: str | None = None,
113
112
  ) -> None:
114
- def model_fn(x: np.ndarray) -> np.ndarray:
113
+ def model_fn(x: NDArray) -> NDArray:
115
114
  return preprocess_drift(
116
115
  x,
117
116
  model, # type: ignore
@@ -134,7 +133,7 @@ class DriftUncertainty:
134
133
  preprocess_fn=preprocess_fn, # type: ignore
135
134
  )
136
135
 
137
- def predict(self, x: ArrayLike) -> Dict[str, Union[int, float, np.ndarray]]:
136
+ def predict(self, x: ArrayLike) -> DriftOutput:
138
137
  """
139
138
  Predict whether a batch of data has drifted from the reference data.
140
139
 
@@ -145,6 +144,7 @@ class DriftUncertainty:
145
144
 
146
145
  Returns
147
146
  -------
148
- Dictionary containing the drift prediction, p-value, and threshold statistics.
147
+ DriftOutput
148
+ Output class containing the drift prediction, p-value, and threshold statistics.
149
149
  """
150
150
  return self._detector.predict(x)