dataeval 0.84.0__tar.gz → 0.84.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (109)
  1. {dataeval-0.84.0 → dataeval-0.84.1}/PKG-INFO +2 -3
  2. {dataeval-0.84.0 → dataeval-0.84.1}/README.md +1 -2
  3. {dataeval-0.84.0 → dataeval-0.84.1}/pyproject.toml +1 -1
  4. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/__init__.py +1 -1
  5. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/__init__.py +2 -2
  6. dataeval-0.84.1/src/dataeval/detectors/drift/_base.py +226 -0
  7. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/_cvm.py +19 -30
  8. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/_ks.py +18 -30
  9. dataeval-0.84.0/src/dataeval/detectors/drift/_torch.py → dataeval-0.84.1/src/dataeval/detectors/drift/_mmd.py +167 -75
  10. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/_uncertainty.py +52 -56
  11. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/updates.py +13 -12
  12. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/linters/duplicates.py +5 -3
  13. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/linters/outliers.py +2 -2
  14. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/ood/ae.py +1 -1
  15. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_base.py +7 -7
  16. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_dimensionstats.py +2 -2
  17. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_hashstats.py +2 -2
  18. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_imagestats.py +4 -4
  19. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_pixelstats.py +2 -2
  20. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_visualstats.py +2 -2
  21. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/typing.py +22 -19
  22. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_array.py +18 -7
  23. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_dataset.py +6 -4
  24. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_embeddings.py +46 -7
  25. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_images.py +2 -2
  26. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_metadata.py +5 -4
  27. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_base.py +7 -4
  28. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_milco.py +42 -14
  29. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_mnist.py +9 -5
  30. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_ships.py +8 -4
  31. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_voc.py +40 -19
  32. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/selections/__init__.py +2 -0
  33. dataeval-0.84.1/src/dataeval/utils/data/selections/_classbalance.py +38 -0
  34. dataeval-0.84.1/src/dataeval/utils/data/selections/_classfilter.py +44 -0
  35. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/selections/_prioritize.py +1 -1
  36. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/selections/_shuffle.py +2 -2
  37. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/torch/_internal.py +12 -35
  38. dataeval-0.84.0/src/dataeval/detectors/drift/_base.py +0 -374
  39. dataeval-0.84.0/src/dataeval/detectors/drift/_mmd.py +0 -178
  40. dataeval-0.84.0/src/dataeval/utils/data/selections/_classfilter.py +0 -59
  41. {dataeval-0.84.0 → dataeval-0.84.1}/LICENSE.txt +0 -0
  42. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/_log.py +0 -0
  43. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/config.py +0 -0
  44. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/__init__.py +0 -0
  45. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/linters/__init__.py +0 -0
  46. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/ood/__init__.py +0 -0
  47. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/ood/base.py +0 -0
  48. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/ood/mixin.py +0 -0
  49. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/ood/vae.py +0 -0
  50. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metadata/__init__.py +0 -0
  51. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metadata/_distance.py +0 -0
  52. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metadata/_ood.py +0 -0
  53. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metadata/_utils.py +0 -0
  54. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/__init__.py +0 -0
  55. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/bias/__init__.py +0 -0
  56. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/bias/_balance.py +0 -0
  57. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/bias/_completeness.py +0 -0
  58. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/bias/_coverage.py +0 -0
  59. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/bias/_diversity.py +0 -0
  60. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/bias/_parity.py +0 -0
  61. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/estimators/__init__.py +0 -0
  62. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/estimators/_ber.py +0 -0
  63. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/estimators/_clusterer.py +0 -0
  64. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/estimators/_divergence.py +0 -0
  65. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/estimators/_uap.py +0 -0
  66. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/__init__.py +0 -0
  67. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_boxratiostats.py +0 -0
  68. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/metrics/stats/_labelstats.py +0 -0
  69. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/__init__.py +0 -0
  70. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_base.py +0 -0
  71. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_bias.py +0 -0
  72. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_drift.py +0 -0
  73. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_estimators.py +0 -0
  74. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_linters.py +0 -0
  75. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_metadata.py +0 -0
  76. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_ood.py +0 -0
  77. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_stats.py +0 -0
  78. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_utils.py +0 -0
  79. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/outputs/_workflows.py +0 -0
  80. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/py.typed +0 -0
  81. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/__init__.py +0 -0
  82. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_bin.py +0 -0
  83. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_clusterer.py +0 -0
  84. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_fast_mst.py +0 -0
  85. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_image.py +0 -0
  86. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_method.py +0 -0
  87. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_mst.py +0 -0
  88. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/_plot.py +0 -0
  89. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/__init__.py +0 -0
  90. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_selection.py +0 -0
  91. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_split.py +0 -0
  92. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/_targets.py +0 -0
  93. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/collate.py +0 -0
  94. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/__init__.py +0 -0
  95. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_cifar10.py +9 -9
  96. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_fileio.py +0 -0
  97. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_mixin.py +0 -0
  98. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/datasets/_types.py +0 -0
  99. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/selections/_indices.py +0 -0
  100. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/selections/_limit.py +0 -0
  101. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/data/selections/_reverse.py +0 -0
  102. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/metadata.py +0 -0
  103. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/torch/__init__.py +0 -0
  104. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/torch/_blocks.py +0 -0
  105. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/torch/_gmm.py +0 -0
  106. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/torch/models.py +0 -0
  107. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/utils/torch/trainer.py +0 -0
  108. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/workflows/__init__.py +0 -0
  109. {dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/workflows/sufficiency.py +0 -0
{dataeval-0.84.0 → dataeval-0.84.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dataeval
- Version: 0.84.0
+ Version: 0.84.1
  Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
  Home-page: https://dataeval.ai/
  License: MIT
@@ -82,8 +82,7 @@ using MAITE-compliant datasets and models.

  **Python versions:** 3.9 - 3.12

- **Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*,
- *Gradient*
+ **Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*

  Choose your preferred method of installation below or follow our
  [installation guide](https://dataeval.readthedocs.io/en/v0.74.2/installation.html).
{dataeval-0.84.0 → dataeval-0.84.1}/README.md
@@ -40,8 +40,7 @@ using MAITE-compliant datasets and models.

  **Python versions:** 3.9 - 3.12

- **Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*,
- *Gradient*
+ **Supported packages**: *NumPy*, *Pandas*, *Sci-kit learn*, *MAITE*, *NRTK*

  Choose your preferred method of installation below or follow our
  [installation guide](https://dataeval.readthedocs.io/en/v0.74.2/installation.html).
{dataeval-0.84.0 → dataeval-0.84.1}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "dataeval"
- version = "0.84.0" # dynamic
+ version = "0.84.1" # dynamic
  description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
  license = "MIT"
  readme = "README.md"
{dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/__init__.py
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
  from __future__ import annotations

  __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
- __version__ = "0.84.0"
+ __version__ = "0.84.1"

  import logging

{dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/__init__.py
@@ -9,14 +9,14 @@ __all__ = [
      "DriftMMDOutput",
      "DriftOutput",
      "DriftUncertainty",
-     "preprocess_drift",
+     "UpdateStrategy",
      "updates",
  ]

  from dataeval.detectors.drift import updates
+ from dataeval.detectors.drift._base import UpdateStrategy
  from dataeval.detectors.drift._cvm import DriftCVM
  from dataeval.detectors.drift._ks import DriftKS
  from dataeval.detectors.drift._mmd import DriftMMD
- from dataeval.detectors.drift._torch import preprocess_drift
  from dataeval.detectors.drift._uncertainty import DriftUncertainty
  from dataeval.outputs._drift import DriftMMDOutput, DriftOutput
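The public drift API now exposes the UpdateStrategy protocol and drops the preprocess_drift helper; image encoding moves into the Embeddings wrapper. A minimal migration sketch, assuming a torch module named encoder and datasets train_images/test_images (names taken from the docstring examples further down, not defined here):

```python
# dataeval 0.84.0 (removed API): wrap the encoder in a preprocess function
#   preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
#   drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)

# dataeval 0.84.1: encode the reference data up front with Embeddings
from dataeval.detectors.drift import DriftCVM
from dataeval.utils.data import Embeddings

train_emb = Embeddings(train_images, model=encoder, batch_size=64)
drift = DriftCVM(train_emb)
result = drift.predict(test_images)  # returns a DriftOutput
```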
dataeval-0.84.1/src/dataeval/detectors/drift/_base.py (new file)
@@ -0,0 +1,226 @@
+ """
+ Source code derived from Alibi-Detect 0.11.4
+ https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+ Original code Copyright (c) 2023 Seldon Technologies Ltd
+ Licensed under Apache Software License (Apache 2.0)
+ """
+
+ from __future__ import annotations
+
+ __all__ = []
+
+ import math
+ from abc import abstractmethod
+ from functools import wraps
+ from typing import Callable, Literal, Protocol, TypeVar, runtime_checkable
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from dataeval.outputs import DriftOutput
+ from dataeval.outputs._base import set_metadata
+ from dataeval.typing import Array
+ from dataeval.utils._array import as_numpy, flatten
+ from dataeval.utils.data import Embeddings
+
+ R = TypeVar("R")
+
+
+ @runtime_checkable
+ class UpdateStrategy(Protocol):
+     """
+     Protocol for reference dataset update strategy for drift detectors
+     """
+
+     def __call__(self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int) -> NDArray[np.float32]: ...
+
+
+ def update_strategy(fn: Callable[..., R]) -> Callable[..., R]:
+     """Decorator to update x_ref with x using selected update methodology"""
+
+     @wraps(fn)
+     def _(self: BaseDrift, data: Embeddings | Array, *args, **kwargs) -> R:
+         output = fn(self, data, *args, **kwargs)
+
+         # update reference dataset
+         if self.update_strategy is not None:
+             self._x_ref = self.update_strategy(self.x_ref, self._encode(data), self.n)
+             self.n += len(data)
+
+         return output
+
+     return _
+
+
+ class BaseDrift:
+     p_val: float
+     update_strategy: UpdateStrategy | None
+     correction: Literal["bonferroni", "fdr"]
+     n: int
+
+     def __init__(
+         self,
+         data: Embeddings | Array,
+         p_val: float = 0.05,
+         update_strategy: UpdateStrategy | None = None,
+         correction: Literal["bonferroni", "fdr"] = "bonferroni",
+     ) -> None:
+         # Type checking
+         if update_strategy is not None and not isinstance(update_strategy, UpdateStrategy):
+             raise ValueError("`update_strategy` is not a valid UpdateStrategy class.")
+         if correction not in ["bonferroni", "fdr"]:
+             raise ValueError("`correction` must be `bonferroni` or `fdr`.")
+
+         self._data = data
+         self.p_val = p_val
+         self.update_strategy = update_strategy
+         self.correction = correction
+         self.n = len(data)
+
+         self._x_ref: NDArray[np.float32] | None = None
+
+     @property
+     def x_ref(self) -> NDArray[np.float32]:
+         """
+         Retrieve the reference data of the drift detector.
+
+         Returns
+         -------
+         NDArray[np.float32]
+             The reference data as a 32-bit floating point numpy array.
+         """
+         if self._x_ref is None:
+             self._x_ref = self._encode(self._data)
+         return self._x_ref
+
+     def _encode(self, data: Embeddings | Array) -> NDArray[np.float32]:
+         array = (
+             data.to_numpy().astype(np.float32)
+             if isinstance(data, Embeddings)
+             else self._data.new(data).to_numpy().astype(np.float32)
+             if isinstance(self._data, Embeddings)
+             else as_numpy(data).astype(np.float32)
+         )
+         return flatten(array)
+
+
+ class BaseDriftUnivariate(BaseDrift):
+     def __init__(
+         self,
+         data: Embeddings | Array,
+         p_val: float = 0.05,
+         update_strategy: UpdateStrategy | None = None,
+         correction: Literal["bonferroni", "fdr"] = "bonferroni",
+         n_features: int | None = None,
+     ) -> None:
+         super().__init__(data, p_val, update_strategy, correction)
+
+         self._n_features = n_features
+
+     @property
+     def n_features(self) -> int:
+         """
+         Get the number of features in the reference data.
+
+         If the number of features is not provided during initialization, it will be inferred
+         from the reference data (``x_ref``).
+
+         Returns
+         -------
+         int
+             Number of features in the reference data.
+         """
+         # lazy process n_features as needed
+         if self._n_features is None:
+             self._n_features = int(math.prod(self._data[0].shape))
+
+         return self._n_features
+
+     def score(self, data: Embeddings | Array) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
+         """
+         Calculates p-values and test statistics per feature.
+
+         Parameters
+         ----------
+         data : Embeddings or Array
+             Batch of instances to score.
+
+         Returns
+         -------
+         tuple[NDArray, NDArray]
+             Feature level p-values and test statistics
+         """
+         x_np = self._encode(data)
+         p_val = np.zeros(self.n_features, dtype=np.float32)
+         dist = np.zeros_like(p_val)
+         for f in range(self.n_features):
+             dist[f], p_val[f] = self._score_fn(self.x_ref[:, f], x_np[:, f])
+         return p_val, dist
+
+     @abstractmethod
+     def _score_fn(self, x: NDArray[np.float32], y: NDArray[np.float32]) -> tuple[np.float32, np.float32]: ...
+
+     def _apply_correction(self, p_vals: NDArray[np.float32]) -> tuple[bool, float]:
+         """
+         Apply the specified correction method (Bonferroni or FDR) to the p-values.
+
+         If the correction method is Bonferroni, the threshold for detecting :term:`drift<Drift>`
+         is divided by the number of features. For FDR, the correction is applied
+         using the Benjamini-Hochberg procedure.
+
+         Parameters
+         ----------
+         p_vals : NDArray
+             Array of p-values from the univariate tests for each feature.
+
+         Returns
+         -------
+         tuple[bool, float]
+             A tuple containing a boolean indicating if drift was detected and the
+             threshold after correction.
+         """
+         if self.correction == "bonferroni":
+             threshold = self.p_val / self.n_features
+             drift_pred = bool((p_vals < threshold).any())
+             return drift_pred, threshold
+         elif self.correction == "fdr":
+             n = p_vals.shape[0]
+             i = np.arange(n) + np.int_(1)
+             p_sorted = np.sort(p_vals)
+             q_threshold = self.p_val * i / n
+             below_threshold = p_sorted < q_threshold
+             try:
+                 idx_threshold = int(np.where(below_threshold)[0].max())
+             except ValueError:  # sorted p-values not below thresholds
+                 return bool(below_threshold.any()), q_threshold.min()
+             return bool(below_threshold.any()), q_threshold[idx_threshold]
+         else:
+             raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
+
+     @set_metadata
+     @update_strategy
+     def predict(self, data: Embeddings | Array) -> DriftOutput:
+         """
+         Predict whether a batch of data has drifted from the reference data and update
+         reference data using specified update strategy.
+
+         Parameters
+         ----------
+         data : Embeddings or Array
+             Batch of instances to predict drift on.
+
+         Returns
+         -------
+         DriftOutput
+             Dictionary containing the :term:`drift<Drift>` prediction and optionally the feature level
+             p-values, threshold after multivariate correction if needed and test :term:`statistics<Statistics>`.
+         """
+         # compute drift scores
+         p_vals, dist = self.score(data)
+
+         feature_drift = (p_vals < self.p_val).astype(np.bool_)
+         drift_pred, threshold = self._apply_correction(p_vals)
+         return DriftOutput(
+             drift_pred, threshold, float(np.mean(p_vals)), float(np.mean(dist)), feature_drift, self.p_val, p_vals, dist
+         )
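UpdateStrategy is a runtime-checkable Protocol, so the isinstance check in BaseDrift.__init__ only requires a matching __call__ method; the package's own strategies (last-seen and reservoir sampling, per the docstrings) live in dataeval.detectors.drift.updates. A minimal sketch of a custom strategy written against the signature above (the class name and window size are illustrative, not part of the package):

```python
import numpy as np
from numpy.typing import NDArray


class KeepLastN:
    """Illustrative update strategy: retain only the most recent n reference rows."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __call__(
        self, x_ref: NDArray[np.float32], x_new: NDArray[np.float32], count: int
    ) -> NDArray[np.float32]:
        # append the newly encoded batch and keep only the trailing window
        return np.concatenate([x_ref, x_new], axis=0)[-self.n :]
```

Passing an instance via update_strategy= lets the update_strategy decorator shown above refresh x_ref after every predict() call.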
{dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/_cvm.py
@@ -10,14 +10,15 @@ from __future__ import annotations

  __all__ = []

- from typing import Callable, Literal
+ from typing import Literal

  import numpy as np
  from numpy.typing import NDArray
  from scipy.stats import cramervonmises_2samp

  from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
- from dataeval.typing import ArrayLike
+ from dataeval.typing import Array
+ from dataeval.utils.data._embeddings import Embeddings


  class DriftCVM(BaseDriftUnivariate):
@@ -31,40 +32,32 @@ class DriftCVM(BaseDriftUnivariate):

      Parameters
      ----------
-     x_ref : ArrayLike
+     data : Embeddings or Array
          Data used as reference distribution.
-     p_val : float | None, default 0.05
+     p_val : float or None, default 0.05
          :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
          If the FDR correction method is used, this corresponds to the acceptable
          q-value.
-     x_ref_preprocessed : bool, default False
-         Whether the given reference data ``x_ref`` has been preprocessed yet.
-         If ``True``, only the test data ``x`` will be preprocessed at prediction time.
-         If ``False``, the reference data will also be preprocessed.
-     update_x_ref : UpdateStrategy | None, default None
+     update_strategy : UpdateStrategy or None, default None
          Reference data can optionally be updated using an UpdateStrategy class. Update
          using the last n instances seen by the detector with LastSeenUpdateStrategy
          or via reservoir sampling with ReservoirSamplingUpdateStrategy.
-     preprocess_fn : Callable | None, default None
-         Function to preprocess the data before computing the data drift metrics.
-         Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
-     correction : "bonferroni" | "fdr", default "bonferroni"
+     correction : "bonferroni" or "fdr", default "bonferroni"
          Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
          Discovery Rate).
-     n_features : int | None, default None
-         Number of features used in the statistical test. No need to pass it if no
-         preprocessing takes place. In case of a preprocessing step, this can also
-         be inferred automatically but could be more expensive to compute.
+     n_features : int or None, default None
+         Number of features used in the univariate drift tests. If not provided, it will
+         be inferred from the data.
+

      Example
      -------
-     >>> from functools import partial
-     >>> from dataeval.detectors.drift import preprocess_drift
+     >>> from dataeval.utils.data import Embeddings

-     Use a preprocess function to encode images before testing for drift
+     Use Embeddings to encode images before testing for drift

-     >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
-     >>> drift = DriftCVM(train_images, preprocess_fn=preprocess_fn)
+     >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
+     >>> drift = DriftCVM(train_emb)

      Test incoming images for drift

@@ -74,20 +67,16 @@ class DriftCVM(BaseDriftUnivariate):

      def __init__(
          self,
-         x_ref: ArrayLike,
+         data: Embeddings | Array,
          p_val: float = 0.05,
-         x_ref_preprocessed: bool = False,
-         update_x_ref: UpdateStrategy | None = None,
-         preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
+         update_strategy: UpdateStrategy | None = None,
          correction: Literal["bonferroni", "fdr"] = "bonferroni",
          n_features: int | None = None,
      ) -> None:
          super().__init__(
-             x_ref=x_ref,
+             data=data,
              p_val=p_val,
-             x_ref_preprocessed=x_ref_preprocessed,
-             update_x_ref=update_x_ref,
-             preprocess_fn=preprocess_fn,
+             update_strategy=update_strategy,
              correction=correction,
              n_features=n_features,
          )
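Because DriftCVM now accepts either Embeddings or a raw Array as data, flat feature vectors can be passed without an encoder. A short sketch with synthetic data (illustrative values only, not from the package):

```python
import numpy as np
from dataeval.detectors.drift import DriftCVM

rng = np.random.default_rng(0)
x_ref = rng.normal(0.0, 1.0, size=(500, 8)).astype(np.float32)   # reference features
x_test = rng.normal(0.5, 1.0, size=(200, 8)).astype(np.float32)  # shifted test batch

drift = DriftCVM(x_ref, p_val=0.05, correction="fdr")
result = drift.predict(x_test)  # DriftOutput with per-feature p-values and CVM statistics
```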
{dataeval-0.84.0 → dataeval-0.84.1}/src/dataeval/detectors/drift/_ks.py
@@ -10,14 +10,15 @@ from __future__ import annotations

  __all__ = []

- from typing import Callable, Literal
+ from typing import Literal

  import numpy as np
  from numpy.typing import NDArray
  from scipy.stats import ks_2samp

  from dataeval.detectors.drift._base import BaseDriftUnivariate, UpdateStrategy
- from dataeval.typing import ArrayLike
+ from dataeval.typing import Array
+ from dataeval.utils.data._embeddings import Embeddings


  class DriftKS(BaseDriftUnivariate):
@@ -31,43 +32,34 @@ class DriftKS(BaseDriftUnivariate):

      Parameters
      ----------
-     x_ref : ArrayLike
+     data : Embeddings or Array
          Data used as reference distribution.
-     p_val : float | None, default 0.05
+     p_val : float or None, default 0.05
          :term:`p-value<P-Value>` used for significance of the statistical test for each feature.
          If the FDR correction method is used, this corresponds to the acceptable
          q-value.
-     x_ref_preprocessed : bool, default False
-         Whether the given reference data ``x_ref`` has been preprocessed yet.
-         If ``True``, only the test data ``x`` will be preprocessed at prediction time.
-         If ``False``, the reference data will also be preprocessed.
-     update_x_ref : UpdateStrategy | None, default None
+     update_strategy : UpdateStrategy or None, default None
          Reference data can optionally be updated using an UpdateStrategy class. Update
          using the last n instances seen by the detector with LastSeenUpdateStrategy
          or via reservoir sampling with ReservoirSamplingUpdateStrategy.
-     preprocess_fn : Callable | None, default None
-         Function to preprocess the data before computing the data :term:`drift<Drift>` metrics.
-         Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
-     correction : "bonferroni" | "fdr", default "bonferroni"
+     correction : "bonferroni" or "fdr", default "bonferroni"
          Correction type for multivariate data. Either 'bonferroni' or 'fdr' (False
          Discovery Rate).
-     alternative : "two-sided" | "less" | "greater", default "two-sided"
+     alternative : "two-sided", "less" or "greater", default "two-sided"
          Defines the alternative hypothesis. Options are 'two-sided', 'less' or
          'greater'.
      n_features : int | None, default None
-         Number of features used in the statistical test. No need to pass it if no
-         preprocessing takes place. In case of a preprocessing step, this can also
-         be inferred automatically but could be more expensive to compute.
+         Number of features used in the univariate drift tests. If not provided, it will
+         be inferred from the data.

      Example
      -------
-     >>> from functools import partial
-     >>> from dataeval.detectors.drift import preprocess_drift
+     >>> from dataeval.utils.data import Embeddings

-     Use a preprocess function to encode images before testing for drift
+     Use Embeddings to encode images before testing for drift

-     >>> preprocess_fn = partial(preprocess_drift, model=encoder, batch_size=64)
-     >>> drift = DriftKS(train_images, preprocess_fn=preprocess_fn)
+     >>> train_emb = Embeddings(train_images, model=encoder, batch_size=64)
+     >>> drift = DriftKS(train_emb)

      Test incoming images for drift

@@ -77,21 +69,17 @@ class DriftKS(BaseDriftUnivariate):

      def __init__(
          self,
-         x_ref: ArrayLike,
+         data: Embeddings | Array,
          p_val: float = 0.05,
-         x_ref_preprocessed: bool = False,
-         update_x_ref: UpdateStrategy | None = None,
-         preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
+         update_strategy: UpdateStrategy | None = None,
          correction: Literal["bonferroni", "fdr"] = "bonferroni",
          alternative: Literal["two-sided", "less", "greater"] = "two-sided",
          n_features: int | None = None,
      ) -> None:
          super().__init__(
-             x_ref=x_ref,
+             data=data,
              p_val=p_val,
-             x_ref_preprocessed=x_ref_preprocessed,
-             update_x_ref=update_x_ref,
-             preprocess_fn=preprocess_fn,
+             update_strategy=update_strategy,
              correction=correction,
              n_features=n_features,
          )
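DriftKS keeps its extra alternative argument on top of the shared univariate signature. A brief sketch reusing the hypothetical train_emb and test_images from the docstring example above:

```python
from dataeval.detectors.drift import DriftKS

# one-sided Kolmogorov-Smirnov test instead of the default "two-sided"
drift = DriftKS(train_emb, alternative="greater", correction="bonferroni")
result = drift.predict(test_images)
```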
  )