dataeval 0.76.1__py3-none-any.whl → 0.81.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. dataeval/__init__.py +3 -3
  2. dataeval/{output.py → _output.py} +14 -0
  3. dataeval/config.py +77 -0
  4. dataeval/detectors/__init__.py +1 -1
  5. dataeval/detectors/drift/__init__.py +6 -6
  6. dataeval/detectors/drift/{base.py → _base.py} +41 -30
  7. dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
  8. dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
  9. dataeval/detectors/drift/{mmd.py → _mmd.py} +33 -19
  10. dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
  11. dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +23 -7
  12. dataeval/detectors/drift/updates.py +1 -1
  13. dataeval/detectors/linters/__init__.py +0 -3
  14. dataeval/detectors/linters/duplicates.py +17 -8
  15. dataeval/detectors/linters/outliers.py +23 -14
  16. dataeval/detectors/ood/ae.py +29 -8
  17. dataeval/detectors/ood/base.py +5 -4
  18. dataeval/detectors/ood/metadata_ks_compare.py +1 -1
  19. dataeval/detectors/ood/mixin.py +20 -5
  20. dataeval/detectors/ood/output.py +1 -1
  21. dataeval/detectors/ood/vae.py +73 -0
  22. dataeval/metadata/__init__.py +5 -0
  23. dataeval/metadata/_ood.py +238 -0
  24. dataeval/metrics/__init__.py +1 -1
  25. dataeval/metrics/bias/__init__.py +5 -4
  26. dataeval/metrics/bias/{balance.py → _balance.py} +67 -17
  27. dataeval/metrics/bias/{coverage.py → _coverage.py} +41 -35
  28. dataeval/metrics/bias/{diversity.py → _diversity.py} +17 -12
  29. dataeval/metrics/bias/{parity.py → _parity.py} +89 -61
  30. dataeval/metrics/estimators/__init__.py +14 -4
  31. dataeval/metrics/estimators/{ber.py → _ber.py} +42 -11
  32. dataeval/metrics/estimators/_clusterer.py +104 -0
  33. dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -13
  34. dataeval/metrics/estimators/{uap.py → _uap.py} +4 -4
  35. dataeval/metrics/stats/__init__.py +7 -7
  36. dataeval/metrics/stats/{base.py → _base.py} +52 -16
  37. dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +6 -9
  38. dataeval/metrics/stats/{datasetstats.py → _datasetstats.py} +10 -14
  39. dataeval/metrics/stats/{dimensionstats.py → _dimensionstats.py} +6 -5
  40. dataeval/metrics/stats/{hashstats.py → _hashstats.py} +6 -6
  41. dataeval/metrics/stats/{labelstats.py → _labelstats.py} +4 -4
  42. dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +5 -4
  43. dataeval/metrics/stats/{visualstats.py → _visualstats.py} +9 -8
  44. dataeval/typing.py +54 -0
  45. dataeval/utils/__init__.py +2 -2
  46. dataeval/utils/_array.py +169 -0
  47. dataeval/utils/_bin.py +199 -0
  48. dataeval/utils/_clusterer.py +144 -0
  49. dataeval/utils/_fast_mst.py +189 -0
  50. dataeval/utils/{image.py → _image.py} +6 -4
  51. dataeval/utils/_method.py +18 -0
  52. dataeval/utils/{shared.py → _mst.py} +3 -65
  53. dataeval/utils/{plot.py → _plot.py} +4 -4
  54. dataeval/utils/data/__init__.py +22 -0
  55. dataeval/utils/data/_embeddings.py +105 -0
  56. dataeval/utils/data/_images.py +65 -0
  57. dataeval/utils/data/_metadata.py +352 -0
  58. dataeval/utils/data/_selection.py +119 -0
  59. dataeval/utils/{dataset/split.py → data/_split.py} +13 -14
  60. dataeval/utils/data/_targets.py +73 -0
  61. dataeval/utils/data/_types.py +58 -0
  62. dataeval/utils/data/collate.py +103 -0
  63. dataeval/utils/data/datasets/__init__.py +17 -0
  64. dataeval/utils/data/datasets/_base.py +254 -0
  65. dataeval/utils/data/datasets/_cifar10.py +134 -0
  66. dataeval/utils/data/datasets/_fileio.py +168 -0
  67. dataeval/utils/data/datasets/_milco.py +153 -0
  68. dataeval/utils/data/datasets/_mixin.py +56 -0
  69. dataeval/utils/data/datasets/_mnist.py +183 -0
  70. dataeval/utils/data/datasets/_ships.py +123 -0
  71. dataeval/utils/data/datasets/_voc.py +352 -0
  72. dataeval/utils/data/selections/__init__.py +15 -0
  73. dataeval/utils/data/selections/_classfilter.py +60 -0
  74. dataeval/utils/data/selections/_indices.py +26 -0
  75. dataeval/utils/data/selections/_limit.py +26 -0
  76. dataeval/utils/data/selections/_reverse.py +18 -0
  77. dataeval/utils/data/selections/_shuffle.py +29 -0
  78. dataeval/utils/metadata.py +51 -376
  79. dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
  80. dataeval/utils/torch/{internal.py → _internal.py} +21 -51
  81. dataeval/utils/torch/models.py +43 -2
  82. dataeval/workflows/sufficiency.py +10 -9
  83. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/METADATA +4 -1
  84. dataeval-0.81.0.dist-info/RECORD +94 -0
  85. dataeval/detectors/linters/clusterer.py +0 -512
  86. dataeval/detectors/linters/merged_stats.py +0 -49
  87. dataeval/detectors/ood/metadata_least_likely.py +0 -119
  88. dataeval/interop.py +0 -69
  89. dataeval/utils/dataset/__init__.py +0 -7
  90. dataeval/utils/dataset/datasets.py +0 -412
  91. dataeval/utils/dataset/read.py +0 -63
  92. dataeval-0.76.1.dist-info/RECORD +0 -67
  93. /dataeval/{log.py → _log.py} +0 -0
  94. /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
  95. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/LICENSE.txt +0 -0
  96. {dataeval-0.76.1.dist-info → dataeval-0.81.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py CHANGED
@@ -7,12 +7,12 @@ shifts that impact performance of deployed models.
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
- __all__ = ["detectors", "log", "metrics", "utils", "workflows"]
11
- __version__ = "0.76.1"
10
+ __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
11
+ __version__ = "0.81.0"
12
12
 
13
13
  import logging
14
14
 
15
- from dataeval import detectors, metrics, utils, workflows
15
+ from dataeval import config, detectors, metrics, typing, utils, workflows
16
16
 
17
17
  logging.getLogger(__name__).addHandler(logging.NullHandler())
18
18
 
dataeval/{output.py → _output.py} RENAMED
@@ -32,9 +32,23 @@ class Output:
32
32
  return f"{self.__class__.__name__}: {str(self.dict())}"
33
33
 
34
34
  def dict(self) -> dict[str, Any]:
35
+ """
36
+ Output attributes as a dictionary.
37
+
38
+ Returns
39
+ -------
40
+ dict[str, Any]
41
+ """
35
42
  return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
36
43
 
37
44
  def meta(self) -> dict[str, Any]:
45
+ """
46
+ Execution metadata as a dictionary.
47
+
48
+ Returns
49
+ -------
50
+ dict[str, Any]
51
+ """
38
52
  return {k.removeprefix("_"): v for k, v in self.__dict__.items() if k.startswith("_")}
39
53
 
40
54
 
dataeval/config.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Global configuration settings for DataEval.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
8
+
9
+ import torch
10
+ from torch import device
11
+
12
+ _device: device | None = None
13
+ _processes: int | None = None
14
+
15
+
16
+ def set_device(device: str | device | int) -> None:
17
+ """
18
+ Sets the default device to use when executing against a PyTorch backend.
19
+
20
+ Parameters
21
+ ----------
22
+ device : str or int or `torch.device`
23
+ The default device to use. See `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
24
+ documentation for more information.
25
+ """
26
+ global _device
27
+ _device = torch.device(device)
28
+
29
+
30
+ def get_device(override: str | device | int | None = None) -> torch.device:
31
+ """
32
+ Returns the PyTorch device to use.
33
+
34
+ Parameters
35
+ ----------
36
+ override : str or int or `torch.device` or None, default None
37
+ The user specified override if provided, otherwise returns the default device.
38
+
39
+ Returns
40
+ -------
41
+ `torch.device`
42
+ """
43
+ if override is None:
44
+ global _device
45
+ return torch.get_default_device() if _device is None else _device
46
+ else:
47
+ return torch.device(override)
48
+
49
+
50
+ def set_max_processes(processes: int | None) -> None:
51
+ """
52
+ Sets the maximum number of worker processes to use when running tasks that support parallel processing.
53
+
54
+ Parameters
55
+ ----------
56
+ processes : int or None
57
+ The maximum number of worker processes to use, or None to use
58
+ `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
59
+ to determine the number of worker processes.
60
+ """
61
+ global _processes
62
+ _processes = processes
63
+
64
+
65
+ def get_max_processes() -> int | None:
66
+ """
67
+ Returns the maximum number of worker processes to use when running tasks that support parallel processing.
68
+
69
+ Returns
70
+ -------
71
+ int or None
72
+ The maximum number of worker processes to use, or None to use
73
+ `os.process_cpu_count <https://docs.python.org/3/library/os.html#os.process_cpu_count>`_
74
+ to determine the number of worker processes.
75
+ """
76
+ global _processes
77
+ return _processes
dataeval/detectors/__init__.py CHANGED
@@ -4,4 +4,4 @@ Detectors can determine if a dataset or individual images in a dataset are indic
4
4
 
5
5
  __all__ = ["drift", "linters", "ood"]
6
6
 
7
- from dataeval.detectors import drift, linters, ood
7
+ from . import drift, linters, ood
dataeval/detectors/drift/__init__.py CHANGED
@@ -14,9 +14,9 @@ __all__ = [
14
14
  ]
15
15
 
16
16
  from dataeval.detectors.drift import updates
17
- from dataeval.detectors.drift.base import DriftOutput
18
- from dataeval.detectors.drift.cvm import DriftCVM
19
- from dataeval.detectors.drift.ks import DriftKS
20
- from dataeval.detectors.drift.mmd import DriftMMD, DriftMMDOutput
21
- from dataeval.detectors.drift.torch import preprocess_drift
22
- from dataeval.detectors.drift.uncertainty import DriftUncertainty
17
+ from dataeval.detectors.drift._base import DriftOutput
18
+ from dataeval.detectors.drift._cvm import DriftCVM
19
+ from dataeval.detectors.drift._ks import DriftKS
20
+ from dataeval.detectors.drift._mmd import DriftMMD, DriftMMDOutput
21
+ from dataeval.detectors.drift._torch import preprocess_drift
22
+ from dataeval.detectors.drift._uncertainty import DriftUncertainty
dataeval/detectors/drift/{base.py → _base.py} RENAMED
@@ -10,16 +10,18 @@ from __future__ import annotations
10
10
 
11
11
  __all__ = []
12
12
 
13
+ import math
13
14
  from abc import ABC, abstractmethod
14
15
  from dataclasses import dataclass
15
16
  from functools import wraps
16
17
  from typing import Any, Callable, Literal, TypeVar
17
18
 
18
19
  import numpy as np
19
- from numpy.typing import ArrayLike, NDArray
20
+ from numpy.typing import NDArray
20
21
 
21
- from dataeval.interop import as_numpy
22
- from dataeval.output import Output, set_metadata
22
+ from dataeval._output import Output, set_metadata
23
+ from dataeval.typing import Array, ArrayLike
24
+ from dataeval.utils._array import as_numpy, to_numpy
23
25
 
24
26
  R = TypeVar("R")
25
27
 
@@ -46,16 +48,9 @@ class UpdateStrategy(ABC):
46
48
  class DriftBaseOutput(Output):
47
49
  """
48
50
  Base output class for Drift Detector classes
49
-
50
- Attributes
51
- ----------
52
- is_drift : bool
53
- Drift prediction for the images
54
- threshold : float
55
- Threshold after multivariate correction if needed
56
51
  """
57
52
 
58
- is_drift: bool
53
+ drifted: bool
59
54
  threshold: float
60
55
  p_val: float
61
56
  distance: float
@@ -64,14 +59,18 @@ class DriftBaseOutput(Output):
64
59
  @dataclass(frozen=True)
65
60
  class DriftOutput(DriftBaseOutput):
66
61
  """
67
- Output class for :class:`DriftCVM`, :class:`DriftKS`, and :class:`DriftUncertainty` drift detectors.
62
+ Output class for :class:`.DriftCVM`, :class:`.DriftKS`, and :class:`.DriftUncertainty` drift detectors.
68
63
 
69
64
  Attributes
70
65
  ----------
71
- is_drift : bool
66
+ drifted : bool
72
67
  :term:`Drift` prediction for the images
73
68
  threshold : float
74
69
  Threshold after multivariate correction if needed
70
+ p_val : float
71
+ Instance-level p-value
72
+ distance : float
73
+ Instance-level distance
75
74
  feature_drift : NDArray
76
75
  Feature-level array of images detected to have drifted
77
76
  feature_threshold : float
@@ -82,7 +81,7 @@ class DriftOutput(DriftBaseOutput):
82
81
  Feature-level distances
83
82
  """
84
83
 
85
- # is_drift: bool
84
+ # drifted: bool
86
85
  # threshold: float
87
86
  # p_val: float
88
87
  # distance: float
@@ -196,7 +195,7 @@ class BaseDrift:
196
195
  if correction not in ["bonferroni", "fdr"]:
197
196
  raise ValueError("`correction` must be `bonferroni` or `fdr`.")
198
197
 
199
- self._x_ref = as_numpy(x_ref)
198
+ self._x_ref = x_ref
200
199
  self.x_ref_preprocessed: bool = x_ref_preprocessed
201
200
 
202
201
  # Other attributes
@@ -204,25 +203,25 @@ class BaseDrift:
204
203
  self.update_x_ref = update_x_ref
205
204
  self.preprocess_fn = preprocess_fn
206
205
  self.correction = correction
207
- self.n: int = len(self._x_ref)
206
+ self.n: int = len(x_ref)
208
207
 
209
208
  # Ref counter for preprocessed x
210
209
  self._x_refcount = 0
211
210
 
212
211
  @property
213
- def x_ref(self) -> NDArray[Any]:
212
+ def x_ref(self) -> ArrayLike:
214
213
  """
215
214
  Retrieve the reference data, applying preprocessing if not already done.
216
215
 
217
216
  Returns
218
217
  -------
219
- NDArray
218
+ ArrayLike
220
219
  The reference dataset (`x_ref`), preprocessed if needed.
221
220
  """
222
221
  if not self.x_ref_preprocessed:
223
222
  self.x_ref_preprocessed = True
224
223
  if self.preprocess_fn is not None:
225
- self._x_ref = as_numpy(self.preprocess_fn(self._x_ref))
224
+ self._x_ref = self.preprocess_fn(self._x_ref)
226
225
 
227
226
  return self._x_ref
228
227
 
@@ -323,32 +322,44 @@ class BaseDriftUnivariate(BaseDrift):
323
322
  # lazy process n_features as needed
324
323
  if not isinstance(self._n_features, int):
325
324
  # compute number of features for the univariate tests
326
- if not isinstance(self.preprocess_fn, Callable) or self.x_ref_preprocessed:
327
- # infer features from preprocessed reference data
328
- self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
329
- else:
330
- # infer number of features after applying preprocessing step
331
- x = as_numpy(self.preprocess_fn(self._x_ref[0:1])) # type: ignore
332
- self._n_features = x.reshape(x.shape[0], -1).shape[-1]
325
+ x_ref = (
326
+ self.x_ref
327
+ if self.preprocess_fn is None or self.x_ref_preprocessed
328
+ else self.preprocess_fn(self._x_ref[0:1])
329
+ )
330
+ # infer features from preprocessed reference data
331
+ shape = x_ref.shape if isinstance(x_ref, Array) else as_numpy(x_ref).shape
332
+ self._n_features = int(math.prod(shape[1:])) # Multiplies all channel sizes after first
333
333
 
334
334
  return self._n_features
335
335
 
336
336
  @preprocess_x
337
- @abstractmethod
338
337
  def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
339
338
  """
340
- Abstract method to calculate feature scores after preprocessing.
339
+ Calculates p-values and test statistics per feature.
341
340
 
342
341
  Parameters
343
342
  ----------
344
343
  x : ArrayLike
345
- The batch of data to calculate univariate :term:`drift<Drift>` scores for each feature.
344
+ Batch of instances
346
345
 
347
346
  Returns
348
347
  -------
349
348
  tuple[NDArray, NDArray]
350
- A tuple containing p-values and distance :term:`statistics<Statistics>` for each feature.
349
+ Feature level p-values and test statistics
351
350
  """
351
+ x_np = to_numpy(x)
352
+ x_np = x_np.reshape(x_np.shape[0], -1)
353
+ x_ref_np = as_numpy(self.x_ref)
354
+ x_ref_np = x_ref_np.reshape(x_ref_np.shape[0], -1)
355
+ p_val = np.zeros(self.n_features, dtype=np.float32)
356
+ dist = np.zeros_like(p_val)
357
+ for f in range(self.n_features):
358
+ dist[f], p_val[f] = self._score_fn(x_ref_np[:, f], x_np[:, f])
359
+ return p_val, dist
360
+
361
+ @abstractmethod
362
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
352
363
 
353
364
  def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
354
365
  """
dataeval/detectors/drift/{cvm.py → _cvm.py} RENAMED
@@ -13,11 +13,11 @@ __all__ = []
13
13
  from typing import Callable, Literal
14
14
 
15
15
  import numpy as np
16
- from numpy.typing import ArrayLike, NDArray
16
+ from numpy.typing import NDArray
17
17
  from scipy.stats import cramervonmises_2samp
18
18
 
19
- from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
20
- from dataeval.interop import to_numpy
19
+ from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
+ from dataeval.typing import ArrayLike
21
21
 
22
22
 
23
23
  class DriftCVM(BaseDriftUnivariate):
@@ -55,6 +55,21 @@ class DriftCVM(BaseDriftUnivariate):
55
55
  Number of features used in the statistical test. No need to pass it if no
56
56
  preprocessing takes place. In case of a preprocessing step, this can also
57
57
  be inferred automatically but could be more expensive to compute.
58
+
59
+ Example
60
+ -------
61
+ >>> from functools import partial
62
+ >>> from dataeval.detectors.drift import preprocess_drift
63
+
64
+ Use a preprocess function to encode images before testing for drift
65
+
66
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
67
+ >>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
68
+
69
+ Test incoming images for drift
70
+
71
+ >>> drift.predict(test_images).drifted
72
+ True
58
73
  """
59
74
 
60
75
  def __init__(
@@ -77,28 +92,6 @@ class DriftCVM(BaseDriftUnivariate):
77
92
  n_features=n_features,
78
93
  )
79
94
 
80
- @preprocess_x
81
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
82
- """
83
- Performs the two-sample Cramér-von Mises test(s), computing the :term:`p-value<P-value>` and
84
- test statistic per feature.
85
-
86
- Parameters
87
- ----------
88
- x : ArrayLike
89
- Batch of instances.
90
-
91
- Returns
92
- -------
93
- tuple[NDArray, NDArray]
94
- Feature level p-values and CVM statistic
95
- """
96
- x_np = to_numpy(x)
97
- x_np = x_np.reshape(x_np.shape[0], -1)
98
- x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
99
- p_val = np.zeros(self.n_features, dtype=np.float32)
100
- dist = np.zeros_like(p_val)
101
- for f in range(self.n_features):
102
- result = cramervonmises_2samp(x_ref[:, f], x_np[:, f], method="auto")
103
- p_val[f], dist[f] = result.pvalue, result.statistic
104
- return p_val, dist
95
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
96
+ result = cramervonmises_2samp(x, y, method="auto")
97
+ return np.float32(result.statistic), np.float32(result.pvalue)
dataeval/detectors/drift/{ks.py → _ks.py} RENAMED
@@ -13,11 +13,11 @@ __all__ = []
13
13
  from typing import Callable, Literal
14
14
 
15
15
  import numpy as np
16
- from numpy.typing import ArrayLike, NDArray
16
+ from numpy.typing import NDArray
17
17
  from scipy.stats import ks_2samp
18
18
 
19
- from dataeval.detectors.drift.base import BaseDriftUnivariate, UpdateStrategy, preprocess_x
20
- from dataeval.interop import to_numpy
19
+ from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
20
+ from dataeval.typing import ArrayLike
21
21
 
22
22
 
23
23
  class DriftKS(BaseDriftUnivariate):
@@ -58,6 +58,21 @@ class DriftKS(BaseDriftUnivariate):
58
58
  Number of features used in the statistical test. No need to pass it if no
59
59
  preprocessing takes place. In case of a preprocessing step, this can also
60
60
  be inferred automatically but could be more expensive to compute.
61
+
62
+ Example
63
+ -------
64
+ >>> from functools import partial
65
+ >>> from dataeval.detectors.drift import preprocess_drift
66
+
67
+ Use a preprocess function to encode images before testing for drift
68
+
69
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
70
+ >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
71
+
72
+ Test incoming images for drift
73
+
74
+ >>> drift.predict(test_images).drifted
75
+ True
61
76
  """
62
77
 
63
78
  def __init__(
@@ -84,26 +99,5 @@ class DriftKS(BaseDriftUnivariate):
84
99
  # Other attributes
85
100
  self.alternative = alternative
86
101
 
87
- @preprocess_x
88
- def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
89
- """
90
- Compute KS scores and :term:Statistics` per feature.
91
-
92
- Parameters
93
- ----------
94
- x : ArrayLike
95
- Batch of instances.
96
-
97
- Returns
98
- -------
99
- tuple[NDArray, NDArray]
100
- Feature level :term:p-values and KS statistic
101
- """
102
- x = to_numpy(x)
103
- x = x.reshape(x.shape[0], -1)
104
- x_ref = self.x_ref.reshape(self.x_ref.shape[0], -1)
105
- p_val = np.zeros(self.n_features, dtype=np.float32)
106
- dist = np.zeros_like(p_val)
107
- for f in range(self.n_features):
108
- dist[f], p_val[f] = ks_2samp(x_ref[:, f], x[:, f], alternative=self.alternative, method="exact")
109
- return p_val, dist
102
+ def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]:
103
+ return ks_2samp(x, y, alternative=self.alternative, method="exact")
dataeval/detectors/drift/{mmd.py → _mmd.py} RENAMED
@@ -14,23 +14,22 @@ from dataclasses import dataclass
14
14
  from typing import Callable
15
15
 
16
16
  import torch
17
- from numpy.typing import ArrayLike
18
17
 
19
- from dataeval.detectors.drift.base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
20
- from dataeval.detectors.drift.torch import GaussianRBF, mmd2_from_kernel_matrix
21
- from dataeval.interop import as_numpy
22
- from dataeval.output import set_metadata
23
- from dataeval.utils.torch.internal import get_device
18
+ from dataeval._output import set_metadata
19
+ from dataeval.config import get_device
20
+ from dataeval.detectors.drift._base import BaseDrift, DriftBaseOutput, UpdateStrategy, preprocess_x, update_x_ref
21
+ from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
22
+ from dataeval.typing import ArrayLike
24
23
 
25
24
 
26
25
  @dataclass(frozen=True)
27
26
  class DriftMMDOutput(DriftBaseOutput):
28
27
  """
29
- Output class for :class:`DriftMMD` :term:`drift<Drift>` detector.
28
+ Output class for :class:`.DriftMMD` :term:`drift<Drift>` detector.
30
29
 
31
30
  Attributes
32
31
  ----------
33
- is_drift : bool
32
+ drifted : bool
34
33
  Drift prediction for the images
35
34
  threshold : float
36
35
  :term:`P-Value` used for significance of the permutation test
@@ -42,7 +41,7 @@ class DriftMMDOutput(DriftBaseOutput):
42
41
  MMD^2 threshold above which drift is flagged
43
42
  """
44
43
 
45
- # is_drift: bool
44
+ # drifted: bool
46
45
  # threshold: float
47
46
  # p_val: float
48
47
  # distance: float
@@ -84,6 +83,21 @@ class DriftMMD(BaseDrift):
84
83
  device : str | None, default None
85
84
  Device type used. The default None uses the GPU and falls back on CPU.
86
85
  Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
86
+
87
+ Example
88
+ -------
89
+ >>> from functools import partial
90
+ >>> from dataeval.detectors.drift import preprocess_drift
91
+
92
+ Use a preprocess function to encode images before testing for drift
93
+
94
+ >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
95
+ >>> drift = DriftMMD(train_images, preprocess_fn=preprocess_fn)
96
+
97
+ Test incoming images for drift
98
+
99
+ >>> drift.predict(test_images).drifted
100
+ True
87
101
  """
88
102
 
89
103
  def __init__(
@@ -110,12 +124,12 @@ class DriftMMD(BaseDrift):
110
124
  self.device: torch.device = get_device(device)
111
125
 
112
126
  # initialize kernel
113
- sigma_tensor = torch.from_numpy(as_numpy(sigma)).to(self.device) if sigma is not None else None
127
+ sigma_tensor = torch.as_tensor(sigma, device=self.device) if sigma is not None else None
114
128
  self._kernel = GaussianRBF(sigma_tensor).to(self.device)
115
129
 
116
130
  # compute kernel matrix for the reference data
117
131
  if self._infer_sigma or isinstance(sigma_tensor, torch.Tensor):
118
- x = torch.from_numpy(self.x_ref).to(self.device)
132
+ x = torch.as_tensor(self.x_ref, device=self.device)
119
133
  self._k_xx = self._kernel(x, x, infer_sigma=self._infer_sigma)
120
134
  self._infer_sigma = False
121
135
  else:
@@ -147,21 +161,21 @@ class DriftMMD(BaseDrift):
147
161
  p-value obtained from the permutation test, MMD^2 between the reference and test set,
148
162
  and MMD^2 threshold above which :term:`drift<Drift>` is flagged
149
163
  """
150
- x = as_numpy(x)
151
- x_ref = torch.from_numpy(self.x_ref).to(self.device)
152
- n = x.shape[0]
153
- kernel_mat = self._kernel_matrix(x_ref, torch.from_numpy(x).to(self.device))
164
+ x_ref = torch.as_tensor(self.x_ref, device=self.device)
165
+ x_test = torch.as_tensor(x, device=self.device)
166
+ n = x_test.shape[0]
167
+ kernel_mat = self._kernel_matrix(x_ref, x_test)
154
168
  kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
155
169
  mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
156
- mmd2_permuted = torch.Tensor(
157
- [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
170
+ mmd2_permuted = torch.tensor(
171
+ [mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)] * self.n_permutations,
172
+ device=self.device,
158
173
  )
159
- mmd2, mmd2_permuted = mmd2.detach().cpu(), mmd2_permuted.detach().cpu()
160
174
  p_val = (mmd2 <= mmd2_permuted).float().mean()
161
175
  # compute distance threshold
162
176
  idx_threshold = int(self.p_val * len(mmd2_permuted))
163
177
  distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
164
- return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy().item()
178
+ return float(p_val.item()), float(mmd2.item()), float(distance_threshold.item())
165
179
 
166
180
  @set_metadata
167
181
  @preprocess_x
dataeval/detectors/drift/{torch.py → _torch.py} RENAMED
@@ -17,7 +17,8 @@ import torch
17
17
  import torch.nn as nn
18
18
  from numpy.typing import NDArray
19
19
 
20
- from dataeval.utils.torch.internal import get_device, predict_batch
20
+ from dataeval.config import get_device
21
+ from dataeval.utils.torch._internal import predict_batch
21
22
 
22
23
 
23
24
  def mmd2_from_kernel_matrix(
dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} RENAMED
@@ -14,14 +14,15 @@ from functools import partial
14
14
  from typing import Callable, Literal
15
15
 
16
16
  import numpy as np
17
- from numpy.typing import ArrayLike, NDArray
17
+ from numpy.typing import NDArray
18
18
  from scipy.special import softmax
19
19
  from scipy.stats import entropy
20
20
 
21
- from dataeval.detectors.drift.base import DriftOutput, UpdateStrategy
22
- from dataeval.detectors.drift.ks import DriftKS
23
- from dataeval.detectors.drift.torch import preprocess_drift
24
- from dataeval.utils.torch.internal import get_device
21
+ from dataeval.config import get_device
22
+ from dataeval.detectors.drift._base import DriftOutput, UpdateStrategy
23
+ from dataeval.detectors.drift._ks import DriftKS
24
+ from dataeval.detectors.drift._torch import preprocess_drift
25
+ from dataeval.typing import ArrayLike
25
26
 
26
27
 
27
28
  def classifier_uncertainty(
@@ -87,7 +88,7 @@ class DriftUncertainty:
87
88
  Reference data can optionally be updated using an UpdateStrategy class. Update
88
89
  using the last n instances seen by the detector with LastSeenUpdateStrategy
89
90
  or via reservoir sampling with ReservoirSamplingUpdateStrategy.
90
- preds_type : "probs" | "logits", default "logits"
91
+ preds_type : "probs" | "logits", default "probs"
91
92
  Type of prediction output by the model. Options are 'probs' (in [0,1]) or
92
93
  'logits' (in [-inf,inf]).
93
94
  batch_size : int, default 32
@@ -98,7 +99,22 @@ class DriftUncertainty:
98
99
  objects to a batch which can be processed by the model.
99
100
  device : str | None, default None
100
101
  Device type used. The default None tries to use the GPU and falls back on
101
- CPU if needed. Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
102
+ CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
103
+
104
+ Example
105
+ -------
106
+ >>> model = ClassificationModel()
107
+ >>> drift = DriftUncertainty(x_ref, model=model, batch_size=20)
108
+
109
+ Verify reference images have not drifted
110
+
111
+ >>> drift.predict(x_ref.copy()).drifted
112
+ False
113
+
114
+ Test incoming images for drift
115
+
116
+ >>> drift.predict(x_test).drifted
117
+ True
102
118
  """
103
119
 
104
120
  def __init__(
dataeval/detectors/drift/updates.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any
12
12
  import numpy as np
13
13
  from numpy.typing import NDArray
14
14
 
15
- from dataeval.detectors.drift.base import UpdateStrategy
15
+ from dataeval.detectors.drift._base import UpdateStrategy
16
16
 
17
17
 
18
18
  class LastSeenUpdate(UpdateStrategy):
dataeval/detectors/linters/__init__.py CHANGED
@@ -3,14 +3,11 @@ Linters help identify potential issues in training and test data and are an impo
3
3
  """
4
4
 
5
5
  __all__ = [
6
- "Clusterer",
7
- "ClustererOutput",
8
6
  "Duplicates",
9
7
  "DuplicatesOutput",
10
8
  "Outliers",
11
9
  "OutliersOutput",
12
10
  ]
13
11
 
14
- from dataeval.detectors.linters.clusterer import Clusterer, ClustererOutput
15
12
  from dataeval.detectors.linters.duplicates import Duplicates, DuplicatesOutput
16
13
  from dataeval.detectors.linters.outliers import Outliers, OutliersOutput