dataeval 0.65.0__tar.gz → 0.67.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dataeval-0.65.0 → dataeval-0.67.0}/PKG-INFO +1 -1
- {dataeval-0.65.0 → dataeval-0.67.0}/pyproject.toml +2 -3
- dataeval-0.67.0/src/dataeval/__init__.py +22 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/clusterer.py +24 -22
- dataeval-0.67.0/src/dataeval/_internal/detectors/drift/base.py +465 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/drift/cvm.py +25 -23
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/drift/ks.py +28 -25
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/drift/mmd.py +30 -29
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/drift/torch.py +66 -58
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/drift/uncertainty.py +28 -28
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/duplicates.py +28 -18
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/ae.py +15 -29
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/aegmm.py +33 -27
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/base.py +61 -43
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/llr.py +27 -24
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/vae.py +32 -31
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/vaegmm.py +34 -28
- dataeval-0.65.0/src/dataeval/_internal/detectors/linter.py → dataeval-0.67.0/src/dataeval/_internal/detectors/outliers.py +33 -27
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/flags.py +5 -3
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/interop.py +4 -2
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/balance.py +33 -4
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/ber.py +6 -4
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/diversity.py +70 -27
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/parity.py +114 -26
- dataeval-0.67.0/src/dataeval/_internal/metrics/stats.py +362 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/uap.py +28 -2
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/utils.py +20 -18
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/pytorch/autoencoder.py +127 -22
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/autoencoder.py +33 -30
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/gmm.py +4 -2
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/losses.py +15 -11
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/pixelcnn.py +19 -18
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/trainer.py +8 -6
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/utils.py +21 -19
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/output.py +13 -10
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/utils.py +5 -3
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/workflows/sufficiency.py +42 -30
- dataeval-0.67.0/src/dataeval/detectors/__init__.py +10 -0
- dataeval-0.67.0/src/dataeval/detectors/drift/__init__.py +16 -0
- dataeval-0.67.0/src/dataeval/detectors/drift/kernels/__init__.py +6 -0
- dataeval-0.67.0/src/dataeval/detectors/drift/updates/__init__.py +3 -0
- dataeval-0.67.0/src/dataeval/detectors/linters/__init__.py +5 -0
- dataeval-0.67.0/src/dataeval/detectors/ood/__init__.py +11 -0
- dataeval-0.67.0/src/dataeval/metrics/__init__.py +3 -0
- dataeval-0.67.0/src/dataeval/metrics/bias/__init__.py +14 -0
- dataeval-0.67.0/src/dataeval/metrics/estimators/__init__.py +9 -0
- dataeval-0.67.0/src/dataeval/metrics/stats/__init__.py +6 -0
- dataeval-0.67.0/src/dataeval/tensorflow/__init__.py +3 -0
- dataeval-0.67.0/src/dataeval/tensorflow/loss/__init__.py +3 -0
- dataeval-0.67.0/src/dataeval/tensorflow/models/__init__.py +5 -0
- dataeval-0.67.0/src/dataeval/tensorflow/recon/__init__.py +3 -0
- dataeval-0.67.0/src/dataeval/torch/__init__.py +3 -0
- {dataeval-0.65.0/src/dataeval/models/torch → dataeval-0.67.0/src/dataeval/torch/models}/__init__.py +1 -2
- dataeval-0.67.0/src/dataeval/torch/trainer/__init__.py +3 -0
- dataeval-0.67.0/src/dataeval/utils/__init__.py +6 -0
- dataeval-0.67.0/src/dataeval/workflows/__init__.py +6 -0
- dataeval-0.65.0/src/dataeval/__init__.py +0 -18
- dataeval-0.65.0/src/dataeval/_internal/detectors/drift/base.py +0 -285
- dataeval-0.65.0/src/dataeval/_internal/metrics/stats.py +0 -224
- dataeval-0.65.0/src/dataeval/detectors/__init__.py +0 -29
- dataeval-0.65.0/src/dataeval/metrics/__init__.py +0 -27
- dataeval-0.65.0/src/dataeval/models/__init__.py +0 -15
- dataeval-0.65.0/src/dataeval/models/tensorflow/__init__.py +0 -6
- dataeval-0.65.0/src/dataeval/utils/__init__.py +0 -9
- dataeval-0.65.0/src/dataeval/workflows/__init__.py +0 -8
- {dataeval-0.65.0 → dataeval-0.67.0}/LICENSE.txt +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/README.md +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/drift/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/detectors/ood/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/coverage.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/metrics/divergence.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/pytorch/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/pytorch/blocks.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/pytorch/utils.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/models/tensorflow/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/_internal/workflows/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/flags/__init__.py +0 -0
- {dataeval-0.65.0 → dataeval-0.67.0}/src/dataeval/py.typed +0 -0
--- dataeval-0.65.0/PKG-INFO
+++ dataeval-0.67.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dataeval
-Version: 0.65.0
+Version: 0.67.0
 Summary: DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks
 Home-page: https://dataeval.ai/
 License: MIT
--- dataeval-0.65.0/pyproject.toml
+++ dataeval-0.67.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "dataeval"
-version = "0.65.0"
+version = "0.67.0" # dynamic
 description = "DataEval provides a simple interface to characterize image data and its impact on model performance across classification and object-detection tasks"
 license = "MIT"
 readme = "README.md"
@@ -94,8 +94,7 @@ jupyter-client = {version = "8.6.0", python = "~3.11"}
 jupyter-cache = {version = "*", python = "~3.11"}
 myst-nb = {version = "1.0.0", python = "~3.11"}
 protobuf = {version = "4.25.3", python = "~3.11"}
-sphinx-
-sphinx-rtd-theme = {version = "1.3.0", python = "~3.11"}
+pydata-sphinx-theme = "^0.15.4"
 sphinx-design = {version = "*", python = "~3.11"}
 sphinx-tabs = {version = "*", python = "~3.11"}
 Sphinx = {version = "7.2.6", python = "~3.11"}
--- /dev/null
+++ dataeval-0.67.0/src/dataeval/__init__.py
@@ -0,0 +1,22 @@
+__version__ = "0.67.0"
+
+from importlib.util import find_spec
+
+_IS_TORCH_AVAILABLE = find_spec("torch") is not None
+_IS_TENSORFLOW_AVAILABLE = find_spec("tensorflow") is not None and find_spec("tensorflow_probability") is not None
+
+del find_spec
+
+from . import detectors, flags, metrics  # noqa: E402
+
+__all__ = ["detectors", "flags", "metrics"]
+
+if _IS_TORCH_AVAILABLE:  # pragma: no cover
+    from . import torch, utils, workflows
+
+    __all__ += ["torch", "utils", "workflows"]
+
+if _IS_TENSORFLOW_AVAILABLE:  # pragma: no cover
+    from . import tensorflow
+
+    __all__ += ["tensorflow"]
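The rewritten top-level `__init__.py` only exposes the torch- and tensorflow-backed subpackages when those frameworks are installed, probing with `importlib.util.find_spec` so that nothing heavy is imported at startup. A minimal sketch of the same pattern (the module and symbol names below are illustrative, not part of dataeval):

```python
from importlib.util import find_spec

# find_spec returns None when a package cannot be located, without
# importing it, so the probe is cheap and side-effect free.
_HAS_SCIPY = find_spec("scipy") is not None

__all__ = ["core_feature"]


def core_feature() -> str:
    return "always available"


if _HAS_SCIPY:  # expose the optional API only when its backend exists
    from scipy import stats  # noqa: F401

    __all__ += ["stats"]
```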
--- dataeval-0.65.0/src/dataeval/_internal/detectors/clusterer.py
+++ dataeval-0.67.0/src/dataeval/_internal/detectors/clusterer.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import
+from typing import Iterable, NamedTuple, cast
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
@@ -26,10 +28,10 @@ class ClustererOutput(OutputMetadata):
         Groups of indices which are not exact but closely related data points
     """
 
-    outliers:
-    potential_outliers:
-    duplicates:
-    potential_duplicates:
+    outliers: list[int]
+    potential_outliers: list[int]
+    duplicates: list[list[int]]
+    potential_duplicates: list[list[int]]
 
 
 def extend_linkage(link_arr: NDArray) -> NDArray:
@@ -59,7 +61,7 @@ def extend_linkage(link_arr: NDArray) -> NDArray:
 class Cluster:
     __slots__ = "merged", "samples", "sample_dist", "is_copy", "count", "dist_avg", "dist_std", "out1", "out2"
 
-    def __init__(self, merged: int, samples: NDArray, sample_dist:
+    def __init__(self, merged: int, samples: NDArray, sample_dist: float | NDArray, is_copy: bool = False):
         self.merged = merged
         self.samples = np.array(samples, dtype=np.int32)
         self.sample_dist = np.array([sample_dist] if np.isscalar(sample_dist) else sample_dist)
@@ -81,7 +83,7 @@ class Cluster:
         self.out1 = dist > out1
         self.out2 = dist > out2
 
-    def copy(self) ->
+    def copy(self) -> Cluster:
         return Cluster(False, self.samples, self.sample_dist, True)
 
     def __repr__(self) -> str:
@@ -94,7 +96,7 @@ class Cluster:
         return f"{self.__class__.__name__}(**{repr(_params)})"
 
 
-class Clusters(
+class Clusters(dict[int, dict[int, Cluster]]):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.max_level: int = 1
@@ -116,10 +118,10 @@ class ClusterMergeEntry:
         self.inner_cluster = inner_cluster
         self.status = status
 
-    def __lt__(self, value:
+    def __lt__(self, value: ClusterMergeEntry) -> bool:
         return self.level.__lt__(value.level)
 
-    def __gt__(self, value:
+    def __gt__(self, value: ClusterMergeEntry) -> bool:
         return self.level.__gt__(value.level)
 
 
@@ -184,7 +186,7 @@ class Clusterer:
         return self._clusters
 
     @property
-    def last_good_merge_levels(self) ->
+    def last_good_merge_levels(self) -> dict[int, int]:
        if self._last_good_merge_levels is None:
            self._last_good_merge_levels = self._get_last_merge_levels()
        return self._last_good_merge_levels
@@ -208,7 +210,7 @@ class Clusterer:
     def _create_clusters(self) -> Clusters:
         """Generates clusters based on linkage matrix"""
         next_cluster_id = 0
-        cluster_map:
+        cluster_map: dict[int, ClusterPosition] = {}  # Dictionary to associate new cluster ids with actual clusters
         clusters: Clusters = Clusters()
 
         # Walking through the linkage array to generate clusters
@@ -236,7 +238,7 @@ class Clusterer:
             # Update clusters to include previously skipped levels
             clusters = self._fill_levels(clusters, left, right)
         elif left or right:
-            child, other_id = cast(
+            child, other_id = cast(tuple[ClusterPosition, int], (left, right_id) if left else (right, left_id))
             cc = clusters[child.level][child.cid]
             samples = np.concatenate([cc.samples, [other_id]])
             sample_dist = np.concatenate([cc.sample_dist, sample_dist])
@@ -285,7 +287,7 @@ class Clusterer:
 
         return cluster_matrix
 
-    def _calc_merge_indices(self, merge_mean:
+    def _calc_merge_indices(self, merge_mean: list[NDArray], intra_max: list[float]) -> NDArray:
         """
         Determine what clusters should be merged and return their indices
         """
@@ -308,7 +310,7 @@ class Clusterer:
         mask2 = mask2_vals < one_std_check
         return np.logical_or(desired_merge, mask2)
 
-    def _generate_merge_list(self, cluster_matrix: NDArray) ->
+    def _generate_merge_list(self, cluster_matrix: NDArray) -> list[ClusterMergeEntry]:
         """
         Runs through the clusters dictionary determining when clusters merge,
         and how close are those clusters when they merge.
@@ -325,7 +327,7 @@ class Clusterer:
         """
         intra_max = []
         merge_mean = []
-        merge_list:
+        merge_list: list[ClusterMergeEntry] = []
 
         for level, cluster_set in self.clusters.items():
             for outer_cluster, cluster in cluster_set.items():
@@ -363,7 +365,7 @@ class Clusterer:
 
         return merge_list
 
-    def _get_last_merge_levels(self) ->
+    def _get_last_merge_levels(self) -> dict[int, int]:
         """
         Creates a dictionary for important cluster ids mapped to their last good merge level
 
@@ -372,7 +374,7 @@ class Clusterer:
         Dict[int, int]
             A mapping of a cluster id to its last good merge level
         """
-        last_merge_levels:
+        last_merge_levels: dict[int, int] = {}
 
         if self._max_clusters <= 1:
             last_merge_levels = {0: int(self._num_samples * 0.1)}
@@ -395,7 +397,7 @@ class Clusterer:
 
         return last_merge_levels
 
-    def find_outliers(self, last_merge_levels:
+    def find_outliers(self, last_merge_levels: dict[int, int]) -> tuple[list[int], list[int]]:
         """
         Retrieves outliers based on when the sample was added to the cluster
         and how far it was from the cluster when it was added
@@ -439,9 +441,9 @@ class Clusterer:
 
         return sorted(outliers), sorted(possible_outliers)
 
-    def _sorted_union_find(self, index_groups: Iterable[Iterable[int]]) ->
+    def _sorted_union_find(self, index_groups: Iterable[Iterable[int]]) -> list[list[int]]:
         """Merges and sorts groups of indices that share any common index"""
-        groups:
+        groups: list[list[int]] = []
         for indices in zip(*index_groups):
             indices = set(indices)
             temp = []
@@ -454,7 +456,7 @@ class Clusterer:
             groups = temp
         return sorted(groups)
 
-    def find_duplicates(self, last_merge_levels:
+    def find_duplicates(self, last_merge_levels: dict[int, int]) -> tuple[list[list[int]], list[list[int]]]:
         """
         Finds duplicate and near duplicate data based on the last good merge levels when building the cluster
 
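Most of the clusterer diff is annotation modernization: `from __future__ import annotations` (PEP 563) stores every annotation as a lazily evaluated string, which lets the file use PEP 604 unions (`float | NDArray`), PEP 585 builtin generics (`list[int]`, `dict[int, int]`), and forward references to the enclosing class even on Python versions older than 3.10. A minimal sketch of why the future import permits this syntax:

```python
from __future__ import annotations  # annotations stored as strings, never evaluated

import numpy as np
from numpy.typing import NDArray


class Cluster:
    def __init__(self, sample_dist: float | NDArray) -> None:
        # Without the future import, `float | NDArray` raises TypeError at
        # class-creation time on Python 3.8/3.9; as a string it is harmless.
        self.sample_dist = np.atleast_1d(np.asarray(sample_dist))

    def copy(self) -> Cluster:  # forward reference to a class still being defined
        return Cluster(self.sample_dist)
```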
--- /dev/null
+++ dataeval-0.67.0/src/dataeval/_internal/detectors/drift/base.py
@@ -0,0 +1,465 @@
+"""
+Source code derived from Alibi-Detect 0.11.4
+https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
+
+Original code Copyright (c) 2023 Seldon Technologies Ltd
+Licensed under Apache Software License (Apache 2.0)
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import wraps
+from typing import Callable, Literal
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+
+from dataeval._internal.interop import to_numpy
+from dataeval._internal.output import OutputMetadata, set_metadata
+
+
+@dataclass(frozen=True)
+class DriftBaseOutput(OutputMetadata):
+    """
+    Output class for Drift
+
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    """
+
+    is_drift: bool
+    threshold: float
+
+
+@dataclass(frozen=True)
+class DriftOutput(DriftBaseOutput):
+    """
+    Output class for DriftCVM and DriftKS
+
+    Attributes
+    ----------
+    is_drift : bool
+        Drift prediction for the images
+    threshold : float
+        Threshold after multivariate correction if needed
+    feature_drift : NDArray
+        Feature-level array of images detected to have drifted
+    feature_threshold : float
+        Feature-level threshold to determine drift
+    p_vals : NDArray
+        Feature-level p-values
+    distances : NDArray
+        Feature-level distances
+    """
+
+    # is_drift: bool
+    # threshold: float
+    feature_drift: NDArray[np.bool_]
+    feature_threshold: float
+    p_vals: NDArray[np.float32]
+    distances: NDArray[np.float32]
+
+
+def update_x_ref(fn):
+    @wraps(fn)
+    def _(self, x, *args, **kwargs):
+        output = fn(self, x, *args, **kwargs)
+
+        # update reference dataset
+        if self.update_x_ref is not None:
+            self._x_ref = self.update_x_ref(self.x_ref, x, self.n)
+
+        # used for reservoir sampling
+        self.n += len(x)
+        return output
+
+    return _
+
+
+def preprocess_x(fn):
+    @wraps(fn)
+    def _(self, x, *args, **kwargs):
+        if self._x_refcount == 0:
+            self._x = self._preprocess(x)
+        self._x_refcount += 1
+        output = fn(self, self._x, *args, **kwargs)
+        self._x_refcount -= 1
+        if self._x_refcount == 0:
+            del self._x
+        return output
+
+    return _
+
+
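The `preprocess_x` decorator reference-counts one shared preprocessed batch so that nested decorated methods (`predict` calls `score`, and both are wrapped) run the preprocessing exactly once per outer call. A runnable toy reduction of that mechanism; the `Demo` class and its values are illustrative, not dataeval's API:

```python
from functools import wraps


def preprocess_x(fn):
    """Share one preprocessed copy of `x` across nested decorated calls."""

    @wraps(fn)
    def _(self, x, *args, **kwargs):
        if self._x_refcount == 0:      # outermost call: preprocess once
            self._x = self._preprocess(x)
        self._x_refcount += 1
        output = fn(self, self._x, *args, **kwargs)
        self._x_refcount -= 1
        if self._x_refcount == 0:      # outermost call returned: drop the cache
            del self._x
        return output

    return _


class Demo:
    _x_refcount = 0

    def _preprocess(self, x):
        print("preprocess runs once")
        return [v * 2 for v in x]

    @preprocess_x
    def predict(self, x):
        return self.score(x)  # nested decorated call reuses self._x

    @preprocess_x
    def score(self, x):
        return sum(x)


assert Demo().predict([1, 2, 3]) == 12  # "preprocess runs once" prints a single time
```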
+class UpdateStrategy(ABC):
+    """
+    Updates reference dataset for drift detector
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
+    def __init__(self, n: int):
+        self.n = n
+
+    @abstractmethod
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
+        """Abstract implementation of update strategy"""
+
+
+class LastSeenUpdate(UpdateStrategy):
+    """
+    Updates reference dataset for drift detector using last seen method.
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
+        x_updated = np.concatenate([x_ref, x], axis=0)
+        return x_updated[-self.n :]
+
+
+class ReservoirSamplingUpdate(UpdateStrategy):
+    """
+    Updates reference dataset for drift detector using reservoir sampling method.
+
+    Parameters
+    ----------
+    n : int
+        Update with last n instances seen by the detector.
+    """
+
+    def __call__(self, x_ref: NDArray, x: NDArray, count: int) -> NDArray:
+        if x.shape[0] + count <= self.n:
+            return np.concatenate([x_ref, x], axis=0)
+
+        n_ref = x_ref.shape[0]
+        output_size = min(self.n, n_ref + x.shape[0])
+        shape = (output_size,) + x.shape[1:]
+        x_reservoir = np.zeros(shape, dtype=x_ref.dtype)
+        x_reservoir[:n_ref] = x_ref
+        for item in x:
+            count += 1
+            if n_ref < self.n:
+                x_reservoir[n_ref, :] = item
+                n_ref += 1
+            else:
+                r = np.random.randint(0, count)
+                if r < self.n:
+                    x_reservoir[r, :] = item
+        return x_reservoir
+
+
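`ReservoirSamplingUpdate` implements reservoir sampling (Algorithm R): once the reservoir is full, the k-th stream item survives with probability n/k, so the retained reference set stays a uniform sample of everything seen so far. A standalone sketch of the same update rule; the function name and toy stream are ours, not dataeval's:

```python
import numpy as np


def reservoir_update(x_ref: np.ndarray, x: np.ndarray, count: int, n: int) -> np.ndarray:
    """Mirror of the ReservoirSamplingUpdate rule above; `count` is the
    number of instances seen before this batch."""
    if x.shape[0] + count <= n:
        return np.concatenate([x_ref, x], axis=0)

    n_ref = x_ref.shape[0]
    reservoir = np.zeros((min(n, n_ref + x.shape[0]),) + x.shape[1:], dtype=x_ref.dtype)
    reservoir[:n_ref] = x_ref
    for item in x:
        count += 1
        if n_ref < n:              # still filling the reservoir
            reservoir[n_ref] = item
            n_ref += 1
        else:                      # replace slot r with probability n/count
            r = np.random.randint(0, count)
            if r < n:
                reservoir[r] = item
    return reservoir


stream = np.random.default_rng(0).normal(size=(10_000, 3)).astype(np.float32)
ref = reservoir_update(stream[:100], stream[100:], count=100, n=500)
assert ref.shape == (500, 3)  # bounded reference set over the whole stream
```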
+class BaseDrift:
+    """
+    A generic drift detection component for preprocessing data and applying statistical correction.
+
+    This class handles common tasks related to drift detection, such as preprocessing
+    the reference data (`x_ref`), performing statistical correction (e.g., Bonferroni, FDR),
+    and updating the reference data if needed.
+
+    Parameters
+    ----------
+    x_ref : ArrayLike
+        The reference dataset used for drift detection. This is the baseline data against
+        which new data points will be compared.
+    p_val : float, optional
+        The significance level for detecting drift, by default 0.05.
+    x_ref_preprocessed : bool, optional
+        Flag indicating whether the reference data has already been preprocessed, by default False.
+    update_x_ref : UpdateStrategy, optional
+        A strategy object specifying how the reference data should be updated when drift is detected,
+        by default None.
+    preprocess_fn : Callable[[ArrayLike], ArrayLike], optional
+        A function to preprocess the data before drift detection, by default None.
+    correction : {'bonferroni', 'fdr'}, optional
+        Statistical correction method applied to p-values, by default "bonferroni".
+
+    Attributes
+    ----------
+    _x_ref : ArrayLike
+        The reference dataset that is either raw or preprocessed.
+    p_val : float
+        The significance level for drift detection.
+    update_x_ref : UpdateStrategy or None
+        The strategy for updating the reference data if applicable.
+    preprocess_fn : Callable or None
+        Function used for preprocessing input data before drift detection.
+    correction : str
+        Statistical correction method applied to p-values.
+    n : int
+        The number of samples in the reference dataset (`x_ref`).
+    x_ref_preprocessed : bool
+        A flag that indicates whether the reference dataset has been preprocessed.
+    _x_refcount : int
+        Counter for how many times the reference data has been accessed after preprocessing.
+
+    Methods
+    -------
+    x_ref:
+        Property that returns the reference dataset, and applies preprocessing if not already done.
+    _preprocess(x):
+        Preprocesses the given data using the specified `preprocess_fn` if provided.
+    """
+
+    def __init__(
+        self,
+        x_ref: ArrayLike,
+        p_val: float = 0.05,
+        x_ref_preprocessed: bool = False,
+        update_x_ref: UpdateStrategy | None = None,
+        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
+        correction: Literal["bonferroni", "fdr"] = "bonferroni",
+    ) -> None:
+        # Type checking
+        if preprocess_fn is not None and not isinstance(preprocess_fn, Callable):
+            raise ValueError("`preprocess_fn` is not a valid Callable.")
+        if update_x_ref is not None and not isinstance(update_x_ref, UpdateStrategy):
+            raise ValueError("`update_x_ref` is not a valid ReferenceUpdate class.")
+        if correction not in ["bonferroni", "fdr"]:
+            raise ValueError("`correction` must be `bonferroni` or `fdr`.")
+
+        self._x_ref = x_ref
+        self.x_ref_preprocessed = x_ref_preprocessed
+
+        # Other attributes
+        self.p_val = p_val
+        self.update_x_ref = update_x_ref
+        self.preprocess_fn = preprocess_fn
+        self.correction = correction
+        self.n = len(self._x_ref)  # type: ignore
+
+        # Ref counter for preprocessed x
+        self._x_refcount = 0
+
+    @property
+    def x_ref(self) -> NDArray:
+        """
+        Retrieve the reference data, applying preprocessing if not already done.
+
+        Returns
+        -------
+        NDArray
+            The reference dataset (`x_ref`), preprocessed if needed.
+        """
+        if not self.x_ref_preprocessed:
+            self.x_ref_preprocessed = True
+            if self.preprocess_fn is not None:
+                self._x_ref = self.preprocess_fn(self._x_ref)
+
+        self._x_ref = to_numpy(self._x_ref)
+        return self._x_ref
+
+    def _preprocess(self, x: ArrayLike) -> ArrayLike:
+        """
+        Preprocess the given data before computing the drift scores.
+
+        Parameters
+        ----------
+        x : ArrayLike
+            The input data to preprocess.
+
+        Returns
+        -------
+        ArrayLike
+            The preprocessed input data.
+        """
+        if self.preprocess_fn is not None:
+            x = self.preprocess_fn(x)
+        return x
+
+
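With `BaseDrift` owning the shared state (reference data, p-value, update strategy, correction), concrete detectors only supply a `score` implementation. A hedged usage sketch under the new 0.67.0 layout; the import paths follow the `dataeval.detectors.drift` packages in the file list above, but the exported names (`DriftKS`, `LastSeenUpdate`) are assumed here rather than verified:

```python
import numpy as np
from dataeval.detectors.drift import DriftKS                 # assumed export
from dataeval.detectors.drift.updates import LastSeenUpdate  # assumed export

x_ref = np.random.default_rng(0).normal(size=(512, 3, 32, 32)).astype(np.float32)

detector = DriftKS(
    x_ref,                             # baseline data
    p_val=0.05,                        # significance level before correction
    update_x_ref=LastSeenUpdate(512),  # keep the 512 most recent instances
    correction="bonferroni",           # per-feature threshold = p_val / n_features
)

batch = np.random.default_rng(1).normal(loc=0.5, size=(128, 3, 32, 32)).astype(np.float32)
result = detector.predict(batch)       # DriftOutput(is_drift=..., threshold=..., ...)
print(result.is_drift, result.threshold)
```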
+class BaseDriftUnivariate(BaseDrift):
+    """
+    Base class for drift detection methods using univariate statistical tests.
+
+    This class inherits from `BaseDrift` and serves as a generic component for detecting
+    distribution drift in univariate features. If the number of features `n_features` is greater
+    than 1, a multivariate correction method (e.g., Bonferroni or FDR) is applied to control
+    the false positive rate, ensuring it does not exceed the specified p-value.
+
+    Parameters
+    ----------
+    x_ref : ArrayLike
+        Reference data used as the baseline to compare against when detecting drift.
+    p_val : float, default 0.05
+        Significance level used for detecting drift.
+    x_ref_preprocessed : bool, default False
+        Indicates whether the reference data has been preprocessed.
+    update_x_ref : UpdateStrategy | None, default None
+        Strategy for updating the reference data when drift is detected.
+    preprocess_fn : Callable[ArrayLike] | None, default None
+        Function used to preprocess input data before detecting drift.
+    correction : 'bonferroni' | 'fdr', default 'bonferroni'
+        Multivariate correction method applied to p-values.
+    n_features : int | None, default None
+        Number of features used in the univariate drift tests. If not provided, it will
+        be inferred from the data.
+
+    Attributes
+    ----------
+    _n_features : int | None
+        Number of features in the data. If not provided, it is lazily inferred from the
+        input data and any preprocessing function.
+    p_val : float
+        The significance level for drift detection.
+    correction : str
+        The method for controlling the false discovery rate or applying a Bonferroni correction.
+    update_x_ref : UpdateStrategy | None
+        Strategy for updating the reference data if applicable.
+    preprocess_fn : Callable | None
+        Function used for preprocessing input data before drift detection.
+
+    Methods
+    -------
+    n_features:
+        Property that returns the number of features, inferring it if necessary.
+    score(x):
+        Abstract method to compute univariate feature scores after preprocessing.
+    _apply_correction(p_vals):
+        Apply a statistical correction to p-values to account for multiple testing.
+    predict(x):
+        Predict whether drift has occurred on a batch of data, applying multivariate correction if needed.
+    """
+
+    def __init__(
+        self,
+        x_ref: ArrayLike,
+        p_val: float = 0.05,
+        x_ref_preprocessed: bool = False,
+        update_x_ref: UpdateStrategy | None = None,
+        preprocess_fn: Callable[[ArrayLike], ArrayLike] | None = None,
+        correction: Literal["bonferroni", "fdr"] = "bonferroni",
+        n_features: int | None = None,
+    ) -> None:
+        super().__init__(
+            x_ref,
+            p_val,
+            x_ref_preprocessed,
+            update_x_ref,
+            preprocess_fn,
+            correction,
+        )
+
+        self._n_features = n_features
+
+    @property
+    def n_features(self) -> int:
+        """
+        Get the number of features in the reference data.
+
+        If the number of features is not provided during initialization, it will be inferred
+        from the reference data (``x_ref``). If a preprocessing function is provided, the number
+        of features will be inferred after applying the preprocessing function.
+
+        Returns
+        -------
+        int
+            Number of features in the reference data.
+        """
+        # lazy process n_features as needed
+        if not isinstance(self._n_features, int):
+            # compute number of features for the univariate tests
+            if not isinstance(self.preprocess_fn, Callable) or self.x_ref_preprocessed:
+                # infer features from preprocessed reference data
+                self._n_features = self.x_ref.reshape(self.x_ref.shape[0], -1).shape[-1]
+            else:
+                # infer number of features after applying preprocessing step
+                x = to_numpy(self.preprocess_fn(self._x_ref[0:1]))  # type: ignore
+                self._n_features = x.reshape(x.shape[0], -1).shape[-1]
+
+        return self._n_features
+
+    @preprocess_x
+    @abstractmethod
+    def score(self, x: ArrayLike) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
+        """
+        Abstract method to calculate feature scores after preprocessing.
+
+        Parameters
+        ----------
+        x : ArrayLike
+            The batch of data to calculate univariate drift scores for each feature.
+
+        Returns
+        -------
+        tuple[NDArray, NDArray]
+            A tuple containing p-values and distance statistics for each feature.
+        """
+
+    def _apply_correction(self, p_vals: NDArray) -> tuple[bool, float]:
+        """
+        Apply the specified correction method (Bonferroni or FDR) to the p-values.
+
+        If the correction method is Bonferroni, the threshold for detecting drift
+        is divided by the number of features. For FDR, the correction is applied
+        using the Benjamini-Hochberg procedure.
+
+        Parameters
+        ----------
+        p_vals : NDArray
+            Array of p-values from the univariate tests for each feature.
+
+        Returns
+        -------
+        tuple[bool, float]
+            A tuple containing a boolean indicating if drift was detected and the
+            threshold after correction.
+        """
+        if self.correction == "bonferroni":
+            threshold = self.p_val / self.n_features
+            drift_pred = bool((p_vals < threshold).any())
+            return drift_pred, threshold
+        elif self.correction == "fdr":
+            n = p_vals.shape[0]
+            i = np.arange(n) + 1
+            p_sorted = np.sort(p_vals)
+            q_threshold = self.p_val * i / n
+            below_threshold = p_sorted < q_threshold
+            try:
+                idx_threshold = int(np.where(below_threshold)[0].max())
+            except ValueError:  # sorted p-values not below thresholds
+                return bool(below_threshold.any()), q_threshold.min()
+            return bool(below_threshold.any()), q_threshold[idx_threshold]
+        else:
+            raise ValueError("`correction` needs to be either `bonferroni` or `fdr`.")
+
+    @set_metadata("dataeval.detectors")
+    @preprocess_x
+    @update_x_ref
+    def predict(
+        self,
+        x: ArrayLike,
+    ) -> DriftOutput:
+        """
+        Predict whether a batch of data has drifted from the reference data and update
+        reference data using specified update strategy.
+
+        Parameters
+        ----------
+        x : ArrayLike
+            Batch of instances.
+
+        Returns
+        -------
+        DriftOutput
+            Dictionary containing the drift prediction and optionally the feature level
+            p-values, threshold after multivariate correction if needed and test statistics.
+        """
+        # compute drift scores
+        p_vals, dist = self.score(x)
+
+        feature_drift = (p_vals < self.p_val).astype(np.bool_)
+        drift_pred, threshold = self._apply_correction(p_vals)
+        return DriftOutput(drift_pred, threshold, feature_drift, self.p_val, p_vals, dist)
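The `fdr` branch of `_apply_correction` is the Benjamini-Hochberg step-up procedure: sorted p-values are compared against the ramp `p_val * i / n`, and the largest passing rank sets the threshold. A small worked example of the same arithmetic on four p-values:

```python
import numpy as np

alpha = 0.05
p_vals = np.array([0.001, 0.30, 0.04, 0.012])

n = p_vals.shape[0]
i = np.arange(n) + 1
p_sorted = np.sort(p_vals)                  # [0.001, 0.012, 0.04, 0.30]
q_threshold = alpha * i / n                 # [0.0125, 0.025, 0.0375, 0.05]
below = p_sorted < q_threshold              # [True, True, False, False]
idx = int(np.where(below)[0].max())         # 1: largest rank still below the ramp
print(bool(below.any()), q_threshold[idx])  # True 0.025
```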