dataeval-0.76.1-py3-none-any.whl → dataeval-0.82.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. dataeval/__init__.py +3 -3
  2. dataeval/config.py +77 -0
  3. dataeval/detectors/__init__.py +1 -1
  4. dataeval/detectors/drift/__init__.py +6 -6
  5. dataeval/detectors/drift/{base.py → _base.py} +40 -85
  6. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  7. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  8. dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
  9. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  10. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
  11. dataeval/detectors/drift/updates.py +20 -3
  12. dataeval/detectors/linters/__init__.py +3 -5
  13. dataeval/detectors/linters/duplicates.py +13 -36
  14. dataeval/detectors/linters/outliers.py +23 -148
  15. dataeval/detectors/ood/__init__.py +1 -1
  16. dataeval/detectors/ood/ae.py +30 -9
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/mixin.py +21 -7
  19. dataeval/detectors/ood/vae.py +73 -0
  20. dataeval/metadata/__init__.py +6 -0
  21. dataeval/metadata/_distance.py +167 -0
  22. dataeval/metadata/_ood.py +217 -0
  23. dataeval/metadata/_utils.py +44 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +6 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
  27. dataeval/metrics/bias/_coverage.py +98 -0
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
  29. dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
  30. dataeval/metrics/estimators/__init__.py +15 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
  32. dataeval/metrics/estimators/_clusterer.py +44 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
  35. dataeval/metrics/stats/__init__.py +16 -13
  36. dataeval/metrics/stats/{base.py → _base.py} +82 -133
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
  38. dataeval/metrics/stats/_dimensionstats.py +75 -0
  39. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
  40. dataeval/metrics/stats/_imagestats.py +94 -0
  41. dataeval/metrics/stats/_labelstats.py +131 -0
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
  44. dataeval/outputs/__init__.py +53 -0
  45. dataeval/{output.py → outputs/_base.py} +55 -25
  46. dataeval/outputs/_bias.py +381 -0
  47. dataeval/outputs/_drift.py +83 -0
  48. dataeval/outputs/_estimators.py +114 -0
  49. dataeval/outputs/_linters.py +184 -0
  50. dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
  51. dataeval/outputs/_stats.py +387 -0
  52. dataeval/outputs/_utils.py +44 -0
  53. dataeval/outputs/_workflows.py +364 -0
  54. dataeval/typing.py +234 -0
  55. dataeval/utils/__init__.py +2 -2
  56. dataeval/utils/_array.py +169 -0
  57. dataeval/utils/_bin.py +199 -0
  58. dataeval/utils/_clusterer.py +144 -0
  59. dataeval/utils/_fast_mst.py +189 -0
  60. dataeval/utils/{image.py → _image.py} +6 -4
  61. dataeval/utils/_method.py +14 -0
  62. dataeval/utils/{shared.py → _mst.py} +3 -65
  63. dataeval/utils/{plot.py → _plot.py} +6 -6
  64. dataeval/utils/data/__init__.py +26 -0
  65. dataeval/utils/data/_dataset.py +217 -0
  66. dataeval/utils/data/_embeddings.py +104 -0
  67. dataeval/utils/data/_images.py +68 -0
  68. dataeval/utils/data/_metadata.py +360 -0
  69. dataeval/utils/data/_selection.py +126 -0
  70. dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
  71. dataeval/utils/data/_targets.py +85 -0
  72. dataeval/utils/data/collate.py +103 -0
  73. dataeval/utils/data/datasets/__init__.py +17 -0
  74. dataeval/utils/data/datasets/_base.py +254 -0
  75. dataeval/utils/data/datasets/_cifar10.py +134 -0
  76. dataeval/utils/data/datasets/_fileio.py +168 -0
  77. dataeval/utils/data/datasets/_milco.py +153 -0
  78. dataeval/utils/data/datasets/_mixin.py +56 -0
  79. dataeval/utils/data/datasets/_mnist.py +183 -0
  80. dataeval/utils/data/datasets/_ships.py +123 -0
  81. dataeval/utils/data/datasets/_types.py +52 -0
  82. dataeval/utils/data/datasets/_voc.py +352 -0
  83. dataeval/utils/data/selections/__init__.py +15 -0
  84. dataeval/utils/data/selections/_classfilter.py +57 -0
  85. dataeval/utils/data/selections/_indices.py +26 -0
  86. dataeval/utils/data/selections/_limit.py +26 -0
  87. dataeval/utils/data/selections/_reverse.py +18 -0
  88. dataeval/utils/data/selections/_shuffle.py +29 -0
  89. dataeval/utils/metadata.py +51 -376
  90. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  91. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  92. dataeval/utils/torch/models.py +43 -2
  93. dataeval/workflows/__init__.py +2 -1
  94. dataeval/workflows/sufficiency.py +11 -346
  95. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
  96. dataeval-0.82.0.dist-info/RECORD +104 -0
  97. dataeval/detectors/linters/clusterer.py +0 -512
  98. dataeval/detectors/linters/merged_stats.py +0 -49
  99. dataeval/detectors/ood/metadata_ks_compare.py +0 -129
  100. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  101. dataeval/interop.py +0 -69
  102. dataeval/metrics/bias/coverage.py +0 -194
  103. dataeval/metrics/stats/datasetstats.py +0 -202
  104. dataeval/metrics/stats/dimensionstats.py +0 -115
  105. dataeval/metrics/stats/labelstats.py +0 -210
  106. dataeval/utils/dataset/__init__.py +0 -7
  107. dataeval/utils/dataset/datasets.py +0 -412
  108. dataeval/utils/dataset/read.py +0 -63
  109. dataeval-0.76.1.dist-info/RECORD +0 -67
  110. dataeval/{log.py → _log.py} +0 -0
  111. dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  112. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
  113. {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -7,12 +7,12 @@ shifts that impact performance of deployed models.
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
- __all__ = ["detectors", "log", "metrics", "utils", "workflows"]
11
- __version__ = "0.76.1"
10
+ __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
+ __version__ = "0.82.0"
12
12
 
13
13
  import logging
14
14
 
15
- from dataeval import detectors, metrics, utils, workflows
15
+ from dataeval import config, detectors, metrics, typing, utils, workflows
16
16
 
17
17
  logging.getLogger(__name__).addHandler(logging.NullHandler())
18
18
 
dataeval/config.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Global configuration settings for DataEval.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
8
+
9
+ import torch
10
+ from torch import device
11
+
12
+ _device: device | None = None
13
+ _processes: int | None = None
14
+
15
+
16
+ def set_device(device: str | device | int) -> None:
17
+ """
18
+ Sets the default device to use when executing against a PyTorch backend.
19
+
20
+ Parameters
21
+ ----------
22
+ device : str or int or `torch.device`
23
+ The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
24
+ documentation for more information.
25
+ """
26
+ global _device
27
+ _device = torch.device(device)
28
+
29
+
30
+ def get_device(override: str | device | int | None = None) -> torch.device:
31
+ """
32
+ Returns the PyTorch device to use.
33
+
34
+ Parameters
35
+ ----------
36
+ override : str or int or `torch.device` or None, default None
37
+ The user specified override if provided, otherwise returns the default device.
38
+
39
+ Returns
40
+ -------
41
+ `torch.device`
42
+ """
43
+ if override is None:
44
+ global _device
45
+ return torch.get_default_device() if _device is None else _device
46
+ else:
47
+ return torch.device(override)
48
+
49
+
50
+ def set_max_processes(processes: int | None) -> None:
51
+ """
52
+ Sets the maximum number of worker processes to use when running tasks that support parallel processing.
53
+
54
+ Parameters
55
+ ----------
56
+ processes : int or None
57
+ The maximum number of worker processes to use, or None to use
58
+ `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
59
+ to determine the number of worker processes.
60
+ """
61
+ global _processes
62
+ _processes = processes
63
+
64
+
65
+ def get_max_processes() -> int | None:
66
+ """
67
+ Returns the maximum number of worker processes to use when running tasks that support parallel processing.
68
+
69
+ Returns
70
+ -------
71
+ int or None
72
+ The maximum number of worker processes to use, or None to use
73
+ `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
74
+ to determine the number of worker processes.
75
+ """
76
+ global _processes
77
+ return _processes
dataeval/detectors/__init__.py CHANGED
@@ -4,4 +4,4 @@ Detectors can determine if a dataset or individual images in a dataset are indic
4
4
 
5
5
  __all__ = ["drift", "linters", "ood"]
6
6
 
7
- from dataeval.detectors import drift, linters, ood
7
+ from . import drift, linters, ood
dataeval/detectors/drift/__init__.py CHANGED
@@ -14,9 +14,9 @@ __all__ = [
14
14
  ]
15
15
 
16
16
  from dataeval.detectors.drift import updates
17
- from dataeval.detectors.drift.base import DriftOutput
18
- from dataeval.detectors.drift.cvm import DriftCVM
19
- from dataeval.detectors.drift.ks import DriftKS
20
- from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
21
- from dataeval.detectors.drift.torch import preprocess_drift
22
- from dataeval.detectors.drift.uncertainty import DriftUncertainty
17
+ from dataeval.detectors.drift._cvm import DriftCVM
18
+ from dataeval.detectors.drift._ks import DriftKS
19
+ from dataeval.detectors.drift._mmd import DriftMMD
20
+ from dataeval.detectors.drift._torch import preprocess_drift
21
+ from dataeval.detectors.drift._uncertainty import DriftUncertainty
22
+ from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
dataeval/detectors/drift/{base.py → _base.py} RENAMED
@@ -10,86 +10,29 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from abc import ABC, abstractmethod
14
- from dataclasses import dataclass
13
+ import math
14
+ from abc import abstractmethod
15
15
  from functools import wraps
16
- from typing import Any, Callable, Literal, TypeVar
16
+ from typing import Any, Callable, Literal, Protocol, TypeVar, runtime_checkable
17
17
 
18
18
  import numpy as np
19
- from numpy.typing import ArrayLike, NDArray
19
+ from numpy.typing import NDArray
20
20
 
21
- from dataeval.interop import as_numpy
22
- from dataeval.output import Output, set_metadata
21
+ from dataeval.outputs import DriftOutput
22
+ from dataeval.outputs._base import set_metadata
23
+ from dataeval.typing import Array, ArrayLike
24
+ from dataeval.utils._array import as_numpy, to_numpy
23
25
 
24
26
  R = TypeVar("R")
25
27
 
26
28
 
27
- class UpdateStrategy(ABC):
29
+ @runtime_checkable
30
+ class UpdateStrategy(Protocol):
28
31
  """
29
- Updates reference dataset for drift detector
30
-
31
- Parameters
32
- ----------
33
- n : int
34
- Update with last n instances seen by the detector.
35
- """
36
-
37
- def __init__(self, n: int) -> None:
38
- self.n = n
39
-
40
- @abstractmethod
41
- def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
42
- """Abstract implementation of update strategy"""
43
-
44
-
45
- @dataclass(frozen=True)
46
- class DriftBaseOutput(Output):
47
- """
48
- Base output class for Drift Detector classes
49
-
50
- Attributes
51
- ----------
52
- is_drift : bool
53
- Drift prediction for the images
54
- threshold : float
55
- Threshold after multivariate correction if needed
32
+ Protocol for reference dataset update strategy for drift detectors
56
33
  """
57
34
 
58
- is_drift: bool
59
- threshold: float
60
- p_val: float
61
- distance: float
62
-
63
-
64
- @dataclass(frozen=True)
65
- class DriftOutput(DriftBaseOutput):
66
- """
67
- Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
68
-
69
- Attributes
70
- ----------
71
- is_drift : bool
72
- :term:`Drift` prediction for the images
73
- threshold : float
74
- Threshold after multivariate correction if needed
75
- feature_drift : NDArray
76
- Feature-level array of images detected to have drifted
77
- feature_threshold : float
78
- Feature-level threshold to determine drift
79
- p_vals : NDArray
80
- Feature-level p-values
81
- distances : NDArray
82
- Feature-level distances
83
- """
84
-
85
- # is_drift: bool
86
- # threshold: float
87
- # p_val: float
88
- # distance: float
89
- feature_drift: NDArray[np.bool_]
90
- feature_threshold: float
91
- p_vals: NDArray[np.float32]
92
- distances: NDArray[np.float32]
35
+ def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]: ...
93
36
 
94
37
 
95
38
  def update_x_ref(fn: Callable[..., R]) -> Callable[..., R]:
@@ -196,7 +139,7 @@ class BaseDrift:
196
139
  if correction not in ["bonferroni", "fdr"]:
197
140
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")
198
141
 
199
- self._x_ref = as_numpy(x_ref)
142
+ self._x_ref = x_ref
200
143
  self.x_ref_preprocessed: bool = x_ref_preprocessed
201
144
 
202
145
  # Other attributes
@@ -204,25 +147,25 @@ class BaseDrift:
204
147
  self.update_x_ref = update_x_ref
205
148
  self.preprocess_fn = preprocess_fn
206
149
  self.correction = correction
207
- self.n: int = len(self._x_ref)
150
+ self.n: int = len(x_ref)
208
151
 
209
152
  # Ref counter for preprocessed x
210
153
  self._x_refcount = 0
211
154
 
212
155
  @property
213
- def x_ref(self) -> NDArray[Any]:
156
+ def x_ref(self) -> ArrayLike:
214
157
  """
215
158
  Retrieve the reference data, applying preprocessing if not already done.
216
159
 
217
160
  Returns
218
161
  -------
219
- NDArray
162
+ ArrayLike
220
163
  The reference dataset (`x_ref`), preprocessed if needed.
221
164
  """
222
165
  if not self.x_ref_preprocessed:
223
166
  self.x_ref_preprocessed = True
224
167
  if self.preprocess_fn is not None:
225
- self._x_ref = as_numpy(self.preprocess_fn(self._x_ref))
168
+ self._x_ref = self.preprocess_fn(self._x_ref)
226
169
 
227
170
  return self._x_ref
228
171
 
@@ -323,32 +266,44 @@ class BaseDriftUnivariate(BaseDrift):
323
266
  # lazy process n_features as needed
324
267
  if not isinstance(self._n_features, int):
325
268
  # compute number of features for the univariate tests
326
- if not isinstance(self.preprocess_fn, Callable) or self.x_ref_preprocessed:
327
- # infer features from preprocessed reference data
328
- self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
329
- else:
330
- # infer number of features after applying preprocessing step
331
- x = as_numpy(self.preprocess_fn(self._x_ref[0:1])) # type: ignore
332
- self._n_features = x.reshape(x.shape[0], -1).shape[-1]
269
+ x_ref = (
270
+ self.x_ref
271
+ if self.preprocess_fn is None or self.x_ref_preprocessed
272
+ else self.preprocess_fn(self._x_ref[0:1])
273
+ )
274
+ # infer features from preprocessed reference data
275
+ shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
276
+ self._n_features = int(math.prod(shape[1:])) # Multiplies all channel sizes after first
333
277
 
334
278
  return self._n_features
335
279
 
336
280
  @preprocess_x
337
- @abstractmethod
338
281
  def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
339
282
  """
340
- Abstract method to calculate feature scores after preprocessing.
283
+ Calculates p-values and test statistics per feature.
341
284
 
342
285
  Parameters
343
286
  ----------
344
287
  x : ArrayLike
345
- The batch of data to calculate univariate :term:`drift<Drift>` scores for each feature.
288
+ Batch of instances
346
289
 
347
290
  Returns
348
291
  -------
349
292
  tuple[NDArray, NDArray]
350
- A tuple containing p-values and distance :term:`statistics<Statistics>` for each feature.
293
+ Feature level p-values and test statistics
351
294
  """
295
+ x_np = to_numpy(x)
296
+ x_np = x_np.reshape(x_np.shape[0], -1)
297
+ x_ref_np = as_numpy(self.x_ref)
298
+ x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
299
+ p_val = np.zeros(self.n_features, dtype=np.float32)
300
+ dist = np.zeros_like(p_val)
301
+ for f in range(self.n_features):
302
+ dist[f], p_val[f] = self._score_fn(x_ref_np[:, f], x_np[:, f])
303
+ return p_val, dist
304
+
305
+ @abstractmethod
306
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
352
307
 
353
308
  def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
354
309
  """
dataeval/detectors/drift/{cvm.py → _cvm.py} RENAMED
@@ -13,11 +13,11 @@ __all__ = []
13
13
  from typing import Callable, Literal
14
14
 
15
15
  import numpy as np
16
- from numpy.typing import ArrayLike, NDArray
16
+ from numpy.typing import NDArray
17
17
  from scipy.stats import cramervonmises_2samp
18
18
 
19
- from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
20
- from dataeval.interop import to_numpy
19
+ from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
+ from dataeval.typing import ArrayLike
21
21
 
22
22
 
23
23
  class DriftCVM(BaseDriftUnivariate):
@@ -55,6 +55,21 @@ class DriftCVM(BaseDriftUnivariate):
55
55
  Number of features used in the statistical test. No need to pass it if no
56
56
  preprocessing takes place. In case of a preprocessing step, this can also
57
57
  be inferred automatically but could be more expensive to compute.
58
+
59
+ Example
60
+ -------
61
+ >>> from functools import partial
62
+ >>> from dataeval.detectors.drift import preprocess_drift
63
+
64
+ Use a preprocess function to encode images before testing for drift
65
+
66
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
67
+ >>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
68
+
69
+ Test incoming images for drift
70
+
71
+ >>> drift.predict(test_images).drifted
72
+ True
58
73
  """
59
74
 
60
75
  def __init__(
@@ -77,28 +92,6 @@ class DriftCVM(BaseDriftUnivariate):
77
92
  n_features=n_features,
78
93
  )
79
94
 
80
- @preprocess_x
81
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
82
- """
83
- Performs the two-sample Cramér-von Mises test(s), computing the :term:`p-value<P-value>` and
84
- test statistic per feature.
85
-
86
- Parameters
87
- ----------
88
- x : ArrayLike
89
- Batch of instances.
90
-
91
- Returns
92
- -------
93
- tuple[NDArray, NDArray]
94
- Feature level p-values and CVM statistic
95
- """
96
- x_np = to_numpy(x)
97
- x_np = x_np.reshape(x_np.shape[0], -1)
98
- x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
99
- p_val = np.zeros(self.n_features, dtype=np.float32)
100
- dist = np.zeros_like(p_val)
101
- for f in range(self.n_features):
102
- result = cramervonmises_2samp(x_ref[:, f], x_np[:, f], method="auto")
103
- p_val[f], dist[f] = result.pvalue, result.statistic
104
- return p_val, dist
95
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
96
+ result = cramervonmises_2samp(x, y, method="auto")
97
+ return np.float32(result.statistic), np.float32(result.pvalue)
dataeval/detectors/drift/{ks.py → _ks.py} RENAMED
@@ -13,11 +13,11 @@ __all__ = []
13
13
  from typing import Callable, Literal
14
14
 
15
15
  import numpy as np
16
- from numpy.typing import ArrayLike, NDArray
16
+ from numpy.typing import NDArray
17
17
  from scipy.stats import ks_2samp
18
18
 
19
- from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
20
- from dataeval.interop import to_numpy
19
+ from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
+ from dataeval.typing import ArrayLike
21
21
 
22
22
 
23
23
  class DriftKS(BaseDriftUnivariate):
@@ -58,6 +58,21 @@ class DriftKS(BaseDriftUnivariate):
58
58
  Number of features used in the statistical test. No need to pass it if no
59
59
  preprocessing takes place. In case of a preprocessing step, this can also
60
60
  be inferred automatically but could be more expensive to compute.
61
+
62
+ Example
63
+ -------
64
+ >>> from functools import partial
65
+ >>> from dataeval.detectors.drift import preprocess_drift
66
+
67
+ Use a preprocess function to encode images before testing for drift
68
+
69
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
70
+ >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
71
+
72
+ Test incoming images for drift
73
+
74
+ >>> drift.predict(test_images).drifted
75
+ True
61
76
  """
62
77
 
63
78
  def __init__(
@@ -84,26 +99,5 @@ class DriftKS(BaseDriftUnivariate):
84
99
  # Other attributes
85
100
  self.alternative = alternative
86
101
 
87
- @preprocess_x
88
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
89
- """
90
- Compute KS scores and :term:Statistics` per feature.
91
-
92
- Parameters
93
- ----------
94
- x : ArrayLike
95
- Batch of instances.
96
-
97
- Returns
98
- -------
99
- tuple[NDArray, NDArray]
100
- Feature level :term:p-values and KS statistic
101
- """
102
- x = to_numpy(x)
103
- x = x.reshape(x.shape[0], -1)
104
- x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
105
- p_val = np.zeros(self.n_features, dtype=np.float32)
106
- dist = np.zeros_like(p_val)
107
- for f in range(self.n_features):
108
- dist[f], p_val[f] = ks_2samp(x_ref[:, f], x[:, f], alternative=self.alternative, method="exact")
109
- return p_val, dist
102
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
103
+ return ks_2samp(x, y, alternative=self.alternative, method="exact")
dataeval/detectors/drift/{mmd.py → _mmd.py} RENAMED
@@ -10,43 +10,16 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
- from dataclasses import dataclass
14
13
  from typing import Callable
15
14
 
16
15
  import torch
17
- from numpy.typing import ArrayLike
18
16
 
19
- from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
20
- from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix
21
- from dataeval.interop import as_numpy
22
- from dataeval.output import set_metadata
23
- from dataeval.utils.torch.internal import get_device
24
-
25
-
26
- @dataclass(frozen=True)
27
- class DriftMMDOutput(DriftBaseOutput):
28
- """
29
- Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
30
-
31
- Attributes
32
- ----------
33
- is_drift : bool
34
- Drift prediction for the images
35
- threshold : float
36
- :term:`P-Value` used for significance of the permutation test
37
- p_val : float
38
- P-value obtained from the permutation test
39
- distance : float
40
- MMD^2 between the reference and test set
41
- distance_threshold : float
42
- MMD^2 threshold above which drift is flagged
43
- """
44
-
45
- # is_drift: bool
46
- # threshold: float
47
- # p_val: float
48
- # distance: float
49
- distance_threshold: float
17
+ from dataeval.config import get_device
18
+ from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
19
+ from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
20
+ from dataeval.outputs import DriftMMDOutput
21
+ from dataeval.outputs._base import set_metadata
22
+ from dataeval.typing import ArrayLike
50
23
 
51
24
 
52
25
  class DriftMMD(BaseDrift):
@@ -84,6 +57,21 @@ class DriftMMD(BaseDrift):
84
57
  device : str | None, default None
85
58
  Device type used. The default None uses the GPU and falls back on CPU.
86
59
  Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
60
+
61
+ Example
62
+ -------
63
+ >>> from functools import partial
64
+ >>> from dataeval.detectors.drift import preprocess_drift
65
+
66
+ Use a preprocess function to encode images before testing for drift
67
+
68
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
69
+ >>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
70
+
71
+ Test incoming images for drift
72
+
73
+ >>> drift.predict(test_images).drifted
74
+ True
87
75
  """
88
76
 
89
77
  def __init__(
@@ -110,12 +98,12 @@ class DriftMMD(BaseDrift):
110
98
  self.device: torch.device = get_device(device)
111
99
 
112
100
  # initialize kernel
113
- sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
101
+ sigma_tensor = torch.as_tensor(sigma, device=self.device) if sigma is not None else None
114
102
  self._kernel = GaussianRBF(sigma_tensor).to(self.device)
115
103
 
116
104
  # compute kernel matrix for the reference data
117
105
  if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
118
- x = torch.from_numpy(self.x_ref).to(self.device)
106
+ x = torch.as_tensor(self.x_ref, device=self.device)
119
107
  self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
120
108
  self._infer_sigma = False
121
109
  else:
@@ -147,21 +135,21 @@ class DriftMMD(BaseDrift):
147
135
  p-value obtained from the permutation test, MMD^2 between the reference and test set,
148
136
  and MMD^2 threshold above which :term:`drift<Drift>` is flagged
149
137
  """
150
- x = as_numpy(x)
151
- x_ref = torch.from_numpy(self.x_ref).to(self.device)
152
- n = x.shape[0]
153
- kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
138
+ x_ref = torch.as_tensor(self.x_ref, device=self.device)
139
+ x_test = torch.as_tensor(x, device=self.device)
140
+ n = x_test.shape[0]
141
+ kernel_mat = self._kernel_matrix(x_ref, x_test)
154
142
  kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
155
143
  mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
156
- mmd2_permuted = torch.Tensor(
157
- [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
144
+ mmd2_permuted = torch.tensor(
145
+ [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)] * self.n_permutations,
146
+ device=self.device,
158
147
  )
159
- mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
160
148
  p_val = (mmd2 <= mmd2_permuted).float().mean()
161
149
  # compute distance threshold
162
150
  idx_threshold = int(self.p_val * len(mmd2_permuted))
163
151
  distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
164
- return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
152
+ return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())
165
153
 
166
154
  @set_metadata
167
155
  @preprocess_x
dataeval/detectors/drift/{torch.py → _torch.py} RENAMED
@@ -17,7 +17,8 @@ import torch
17
17
  import torch.nn as nn
18
18
  from numpy.typing import NDArray
19
19
 
20
- from dataeval.utils.torch.internal import get_device, predict_batch
20
+ from dataeval.config import get_device
21
+ from dataeval.utils.torch._internal import predict_batch
21
22
 
22
23
 
23
24
  def mmd2_from_kernel_matrix(
dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} RENAMED
@@ -14,14 +14,16 @@ from functools import partial
14
14
  from typing import Callable, Literal
15
15
 
16
16
  import numpy as np
17
- from numpy.typing import ArrayLike, NDArray
17
+ from numpy.typing import NDArray
18
18
  from scipy.special import softmax
19
19
  from scipy.stats import entropy
20
20
 
21
- from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
22
- from dataeval.detectors.drift.ks import DriftKS
23
- from dataeval.detectors.drift.torch import preprocess_drift
24
- from dataeval.utils.torch.internal import get_device
21
+ from dataeval.config import get_device
22
+ from dataeval.detectors.drift._base import UpdateStrategy
23
+ from dataeval.detectors.drift._ks import DriftKS
24
+ from dataeval.detectors.drift._torch import preprocess_drift
25
+ from dataeval.outputs import DriftOutput
26
+ from dataeval.typing import ArrayLike
25
27
 
26
28
 
27
29
  def classifier_uncertainty(
@@ -87,7 +89,7 @@ class DriftUncertainty:
87
89
  Reference data can optionally be updated using an UpdateStrategy class. Update
88
90
  using the last n instances seen by the detector with LastSeenUpdateStrategy
89
91
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
90
- preds_type : "probs" | "logits", default "logits"
92
+ preds_type : "probs" | "logits", default "probs"
91
93
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
92
94
  'logits' (in [-inf,inf]).
93
95
  batch_size : int, default 32
@@ -98,7 +100,22 @@ class DriftUncertainty:
98
100
  objects to a batch which can be processed by the model.
99
101
  device : str | None, default None
100
102
  Device type used. The default None tries to use the GPU and falls back on
101
- CPU if needed. Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
103
+ CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
104
+
105
+ Example
106
+ -------
107
+ >>> model = ClassificationModel()
108
+ >>> drift = DriftUncertainty(x_ref, model=model, batch_size=20)
109
+
110
+ Verify reference images have not drifted
111
+
112
+ >>> drift.predict(x_ref.copy()).drifted
113
+ False
114
+
115
+ Test incoming images for drift
116
+
117
+ >>> drift.predict(x_test).drifted
118
+ True
102
119
  """
103
120
 
104
121
  def __init__(
dataeval/detectors/drift/updates.py CHANGED
@@ -7,15 +7,32 @@ from __future__ import annotations
7
7
 
8
8
  __all__ = ["LastSeenUpdate", "ReservoirSamplingUpdate"]
9
9
 
10
+ from abc import ABC, abstractmethod
10
11
  from typing import Any
11
12
 
12
13
  import numpy as np
13
14
  from numpy.typing import NDArray
14
15
 
15
- from dataeval.detectors.drift.base import UpdateStrategy
16
+
17
+ class BaseUpdateStrategy(ABC):
18
+ """
19
+ Updates reference dataset for drift detector
20
+
21
+ Parameters
22
+ ----------
23
+ n : int
24
+ Update with last n instances seen by the detector.
25
+ """
26
+
27
+ def __init__(self, n: int) -> None:
28
+ self.n = n
29
+
30
+ @abstractmethod
31
+ def __call__(self, x_ref: NDArray[Any], x: NDArray[Any], count: int) -> NDArray[Any]:
32
+ """Abstract implementation of update strategy"""
16
33
 
17
34
 
18
- class LastSeenUpdate(UpdateStrategy):
35
+ class LastSeenUpdate(BaseUpdateStrategy):
19
36
  """
20
37
  Updates reference dataset for :term:`drift<Drift>` detector using last seen method.
21
38
 
@@ -30,7 +47,7 @@ class LastSeenUpdate(UpdateStrategy):
30
47
  return x_updated[-self.n :]
31
48
 
32
49
 
33
- class ReservoirSamplingUpdate(UpdateStrategy):
50
+ class ReservoirSamplingUpdate(BaseUpdateStrategy):
34
51
  """
35
52
  Updates reference dataset for :term:`drift<Drift>` detector using reservoir sampling method.
36
53