dataeval 0.82.0__py3-none-any.whl → 0.83.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +7 -2
- dataeval/config.py +78 -11
- dataeval/detectors/drift/_mmd.py +9 -9
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +4 -4
- dataeval/detectors/linters/duplicates.py +3 -3
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +5 -4
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/mixin.py +1 -1
- dataeval/detectors/ood/vae.py +2 -1
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_distance.py +11 -44
- dataeval/metadata/_ood.py +152 -33
- dataeval/metrics/bias/_balance.py +9 -5
- dataeval/metrics/bias/_diversity.py +3 -0
- dataeval/metrics/bias/_parity.py +2 -0
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +20 -21
- dataeval/metrics/stats/_boxratiostats.py +1 -1
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +8 -8
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +5 -0
- dataeval/outputs/_base.py +50 -21
- dataeval/outputs/_bias.py +1 -1
- dataeval/outputs/_linters.py +4 -2
- dataeval/outputs/_metadata.py +61 -0
- dataeval/outputs/_stats.py +12 -6
- dataeval/typing.py +40 -9
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_embeddings.py +23 -19
- dataeval/utils/data/_metadata.py +16 -7
- dataeval/utils/data/_selection.py +22 -15
- dataeval/utils/data/_split.py +3 -2
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +5 -3
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/torch/_gmm.py +3 -2
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
- dataeval-0.83.0.dist-info/RECORD +105 -0
- dataeval/detectors/ood/metadata_ood_mi.py +0 -93
- dataeval-0.82.0.dist-info/RECORD +0 -104
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.0.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
@@ -8,7 +8,7 @@ shifts that impact performance of deployed models.
 from __future__ import annotations
 
 __all__ = ["config", "detectors", "log", "metrics", "typing", "utils", "workflows"]
-__version__ = "0.82.0"
+__version__ = "0.83.0"
 
 import logging
 
@@ -34,7 +34,12 @@ def log(level: int = logging.DEBUG, handler: logging.Handler | None = None) -> None:
     logger = logging.getLogger(__name__)
     if handler is None:
         handler = logging.StreamHandler() if handler is None else handler
-        handler.setFormatter(
+        handler.setFormatter(
+            logging.Formatter(
+                "%(asctime)s %(levelname)-8s %(name)s.%(filename)s:%(lineno)s - %(funcName)10s() | %(message)s"
+            )
+        )
     logger.addHandler(handler)
     logger.setLevel(level)
+    logging.DEBUG
     logger.debug(f"Added logging handler {handler} to logger: {__name__}")
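Hedged usage sketch (not part of the diff): the snippet below assumes only the log() signature shown in the hunk header above, and demonstrates attaching the new 0.83.0 default formatter to the package logger.

import logging

import dataeval

# Attaches a StreamHandler carrying the new default formatter to the "dataeval" logger.
dataeval.log(level=logging.INFO)
logging.getLogger("dataeval").info("logging enabled")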
dataeval/config.py
CHANGED
@@ -4,36 +4,71 @@ Global configuration settings for DataEval.
 
 from __future__ import annotations
 
-__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
+__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
 
+import sys
+from typing import Union
+
+if sys.version_info >= (3, 10):
+    from typing import TypeAlias
+else:
+    from typing_extensions import TypeAlias
+
+import numpy as np
 import torch
-from torch import device
 
-
+### GLOBALS ###
+
+_device: torch.device | None = None
 _processes: int | None = None
+_seed: int | None = None
+
+### CONSTS ###
 
+EPSILON = 1e-10
 
-
+### TYPES ###
+
+DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
+"""
+Type alias for types that are acceptable for specifying a torch.device.
+
+See Also
+--------
+`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
+"""
+
+### FUNCS ###
+
+
+def _todevice(device: DeviceLike) -> torch.device:
+    return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
+
+
+def set_device(device: DeviceLike) -> None:
     """
     Sets the default device to use when executing against a PyTorch backend.
 
     Parameters
     ----------
-    device :
-        The default device to use. See
-
+    device : DeviceLike
+        The default device to use. See documentation for more information.
+
+    See Also
+    --------
+    `torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
     """
     global _device
-    _device =
+    _device = _todevice(device)
 
 
-def get_device(override: str | device | int | None = None) -> torch.device:
+def get_device(override: DeviceLike | None = None) -> torch.device:
     """
     Returns the PyTorch device to use.
 
     Parameters
     ----------
-    override :
+    override : DeviceLike or None, default None
         The user specified override if provided, otherwise returns the default device.
 
     Returns
@@ -44,7 +79,7 @@ def get_device(override: str | device | int | None = None) -> torch.device:
         global _device
         return torch.get_default_device() if _device is None else _device
     else:
-        return
+        return _todevice(override)
 
 
 def set_max_processes(processes: int | None) -> None:
@@ -75,3 +110,35 @@ def get_max_processes() -> int | None:
     """
     global _processes
     return _processes
+
+
+def set_seed(seed: int | None, all_generators: bool = False) -> None:
+    """
+    Sets the seed for use by classes that allow for a random state or seed.
+
+    Parameters
+    ----------
+    seed : int or None
+        The seed to use.
+    all_generators : bool, default False
+        Whether to set the seed for all generators, including NumPy and PyTorch.
+    """
+    global _seed
+    _seed = seed
+
+    if all_generators:
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+
+
+def get_seed() -> int | None:
+    """
+    Returns the seed for random state or seed.
+
+    Returns
+    -------
+    int or None
+        The seed to use.
+    """
+    global _seed
+    return _seed
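A short, hedged sketch of the new configuration API (assumption: set_seed/get_seed are importable from dataeval.config as defined above, even though only DeviceLike was added to __all__):

import torch

from dataeval import config

# DeviceLike accepts an int, a str, a (str, int) tuple, or a torch.device.
config.set_device(("cuda", 0) if torch.cuda.is_available() else "cpu")
print(config.get_device())

# New in 0.83.0: a global seed for classes that accept a random state.
config.set_seed(42, all_generators=True)  # also seeds NumPy and PyTorch generators
print(config.get_seed())  # 42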
dataeval/detectors/drift/_mmd.py
CHANGED
@@ -14,7 +14,7 @@ from typing import Callable
 
 import torch
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
 from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
 from dataeval.outputs import DriftMMDOutput
@@ -31,7 +31,7 @@ class DriftMMD(BaseDrift):
     ----------
     x_ref : ArrayLike
         Data used as reference distribution.
-    p_val : float
+    p_val : float or None, default 0.05
         :term:`P-value` used for significance of the statistical test for each feature.
         If the FDR correction method is used, this corresponds to the acceptable
         q-value.
@@ -39,14 +39,14 @@ class DriftMMD(BaseDrift):
         Whether the given reference data ``x_ref`` has been preprocessed yet.
         If ``True``, only the test data ``x`` will be preprocessed at prediction time.
         If ``False``, the reference data will also be preprocessed.
-    update_x_ref : UpdateStrategy
+    update_x_ref : UpdateStrategy or None, default None
         Reference data can optionally be updated using an UpdateStrategy class. Update
         using the last n instances seen by the detector with LastSeenUpdateStrategy
         or via reservoir sampling with ReservoirSamplingUpdateStrategy.
-    preprocess_fn : Callable
+    preprocess_fn : Callable or None, default None
         Function to preprocess the data before computing the data drift metrics.
         Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
-    sigma : ArrayLike
+    sigma : ArrayLike or None, default None
         Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
         bandwidth values as an array. The kernel evaluation is then averaged over
         those bandwidths.
@@ -54,9 +54,9 @@ class DriftMMD(BaseDrift):
         Whether to already configure the kernel bandwidth from the reference data.
     n_permutations : int, default 100
         Number of permutations used in the permutation test.
-    device :
-
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
 
     Example
     -------
@@ -84,7 +84,7 @@ class DriftMMD(BaseDrift):
         sigma: ArrayLike | None = None,
         configure_kernel_from_x_ref: bool = True,
         n_permutations: int = 100,
-        device:
+        device: DeviceLike | None = None,
     ) -> None:
         super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
 
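Hedged construction sketch for the widened device parameter; the public import path and the predict() call are assumed from DataEval's drift detector interface rather than taken from this diff, and the random arrays are purely illustrative.

import numpy as np

from dataeval.detectors.drift import DriftMMD  # assumed re-export of _mmd.DriftMMD

rng = np.random.default_rng(0)
x_ref = rng.normal(size=(128, 16)).astype(np.float32)
x_test = rng.normal(loc=0.5, size=(64, 16)).astype(np.float32)

# device now accepts any DeviceLike: "cpu", "cuda", an index, ("cuda", 0), or a torch.device
detector = DriftMMD(x_ref, p_val=0.05, n_permutations=100, device="cpu")
result = detector.predict(x_test)  # predict() assumed from the drift detector interface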
dataeval/detectors/drift/_torch.py
CHANGED
@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 from numpy.typing import NDArray
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.utils.torch._internal import predict_batch
 
 
@@ -59,7 +59,7 @@ def mmd2_from_kernel_matrix(
 def preprocess_drift(
     x: NDArray[Any],
     model: nn.Module,
-    device:
+    device: DeviceLike | None = None,
     preprocess_batch_fn: Callable | None = None,
     batch_size: int = int(1e10),
     dtype: type[np.generic] | torch.dtype = np.float32,
@@ -73,15 +73,15 @@ def preprocess_drift(
         Batch of instances.
     model : nn.Module
         Model used for preprocessing.
-    device :
-
-
-    preprocess_batch_fn : Callable
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
+    preprocess_batch_fn : Callable or None, default None
         Optional batch preprocessing function. For example to convert a list of objects
         to a batch which can be processed by the PyTorch model.
     batch_size : int, default 1e10
         Batch size used during prediction.
-    dtype : np.dtype
+    dtype : np.dtype or torch.dtype, default np.float32
         Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
 
     Returns
dataeval/detectors/drift/_uncertainty.py
CHANGED
@@ -85,20 +85,20 @@ class DriftUncertainty:
         Whether the given reference data ``x_ref`` has been preprocessed yet.
         If ``True``, only the test data ``x`` will be preprocessed at prediction time.
         If ``False``, the reference data will also be preprocessed.
-    update_x_ref : UpdateStrategy
+    update_x_ref : UpdateStrategy or None, default None
         Reference data can optionally be updated using an UpdateStrategy class. Update
         using the last n instances seen by the detector with LastSeenUpdateStrategy
         or via reservoir sampling with ReservoirSamplingUpdateStrategy.
-    preds_type : "probs"
+    preds_type : "probs" or "logits", default "probs"
         Type of prediction output by the model. Options are 'probs' (in [0,1]) or
         'logits' (in [-inf,inf]).
     batch_size : int, default 32
         Batch size used to evaluate model. Only relevant when backend has been
         specified for batch prediction.
-    preprocess_batch_fn : Callable
+    preprocess_batch_fn : Callable or None, default None
         Optional batch preprocessing function. For example to convert a list of
         objects to a batch which can be processed by the model.
-    device :
+    device : DeviceLike or None, default None
         Device type used. The default None tries to use the GPU and falls back on
         CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
 
dataeval/detectors/linters/duplicates.py
CHANGED
@@ -88,13 +88,13 @@ class Duplicates:
         """
 
         if isinstance(hashes, HashStatsOutput):
-            return DuplicatesOutput(**self._get_duplicates(hashes.
+            return DuplicatesOutput(**self._get_duplicates(hashes.data()))
 
         if not isinstance(hashes, Sequence):
             raise TypeError("Invalid stats output type; only use output from hashstats.")
 
         combined, dataset_steps = combine_stats(hashes)
-        duplicates = self._get_duplicates(combined.
+        duplicates = self._get_duplicates(combined.data())
 
         # split up results from combined dataset into individual dataset buckets
         for dup_type, dup_list in duplicates.items():
@@ -136,5 +136,5 @@ class Duplicates:
         """  # noqa: E501
         images = Images(data) if isinstance(data, Dataset) else data
         self.stats = hashstats(images)
-        duplicates = self._get_duplicates(self.stats.
+        duplicates = self._get_duplicates(self.stats.data())
         return DuplicatesOutput(**duplicates)
dataeval/detectors/linters/outliers.py
CHANGED
@@ -169,7 +169,7 @@ class Outliers:
         {}
         """  # noqa: E501
         if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
-            return OutliersOutput(self._get_outliers(stats.
+            return OutliersOutput(self._get_outliers(stats.data()))
 
         if not isinstance(stats, Sequence):
             raise TypeError(
@@ -189,7 +189,7 @@ class Outliers:
         output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
         for _, indices in stats_map.items():
             substats, dataset_steps = combine_stats([stats[i] for i in indices])
-            outliers = self._get_outliers(substats.
+            outliers = self._get_outliers(substats.data())
             for idx, issue in outliers.items():
                 k, v = get_dataset_step_from_idx(idx, dataset_steps)
                 output_list[indices[k]][v] = issue
@@ -225,5 +225,5 @@ class Outliers:
         """
         images = Images(data) if isinstance(data, Dataset) else data
         self.stats = imagestats(images)
-        outliers = self._get_outliers(self.stats.
+        outliers = self._get_outliers(self.stats.data())
         return OutliersOutput(outliers)
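Hedged note: both linters above now read their statistics through the outputs' data() accessor. A usage sketch under the assumption that the public entry point is an evaluate method taking a dataset of images (the method name and import path are assumed, not shown in this diff; `dataset` is a placeholder image Dataset):

from dataeval.detectors.linters import Duplicates, Outliers  # assumed re-exports

duplicate_report = Duplicates().evaluate(dataset)  # returns a DuplicatesOutput
outlier_report = Outliers().evaluate(dataset)      # returns an OutliersOutput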
dataeval/detectors/ood/ae.py
CHANGED
@@ -18,6 +18,7 @@ import numpy as np
 import torch
 from numpy.typing import NDArray
 
+from dataeval.config import DeviceLike
 from dataeval.detectors.ood.base import OODBase
 from dataeval.outputs import OODScoreOutput
 from dataeval.typing import ArrayLike
@@ -33,9 +34,9 @@ class OOD_AE(OODBase):
     model : torch.nn.Module
         An autoencoder model to use for encoding and reconstruction of images
         for detection of out-of-distribution samples.
-    device :
-        The device to use
-
+    device : DeviceLike or None, default None
+        The hardware device to use if specified, otherwise uses the DataEval
+        default or torch default.
 
     Example
     -------
@@ -57,7 +58,7 @@ class OOD_AE(OODBase):
     array([ True, True, False, True, True, True, True, True])
     """
 
-    def __init__(self, model: torch.nn.Module, device:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         super().__init__(model, device)
 
     def fit(
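Hedged sketch of constructing OOD_AE with the DeviceLike-typed device argument; the stand-in autoencoder and the public import path are assumptions made only to keep the snippet self-contained.

import torch.nn as nn

from dataeval.detectors.ood import OOD_AE  # assumed re-export of ae.OOD_AE

# Stand-in autoencoder; any torch.nn.Module that reconstructs its input will do.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

detector = OOD_AE(model, device="cpu")  # device now accepts any DeviceLike value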
dataeval/detectors/ood/base.py
CHANGED
@@ -14,7 +14,7 @@ from typing import Callable, cast
 
 import torch
 
-from dataeval.config import get_device
+from dataeval.config import DeviceLike, get_device
 from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
 from dataeval.typing import ArrayLike
 from dataeval.utils._array import to_numpy
@@ -23,7 +23,7 @@ from dataeval.utils.torch._internal import trainer
 
 
 class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
-    def __init__(self, model: torch.nn.Module, device:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         self.device: torch.device = get_device(device)
         super().__init__(model)
 
dataeval/detectors/ood/mixin.py
CHANGED
@@ -157,4 +157,4 @@ class OODBaseMixin(Generic[TModel], ABC):
         # compute outlier scores
         score = self.score(X, batch_size=batch_size)
         ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
-        return OODOutput(is_ood=ood_pred, **score.
+        return OODOutput(is_ood=ood_pred, **score.data())
dataeval/detectors/ood/vae.py
CHANGED
@@ -17,6 +17,7 @@ from typing import Callable
 import numpy as np
 import torch
 
+from dataeval.config import DeviceLike
 from dataeval.detectors.ood.base import OODBase
 from dataeval.outputs import OODScoreOutput
 from dataeval.typing import ArrayLike
@@ -34,7 +35,7 @@ class OOD_VAE(OODBase):
         An Autoencoder model.
     """
 
-    def __init__(self, model: torch.nn.Module, device:
+    def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
         super().__init__(model, device)
 
     def fit(
dataeval/metadata/__init__.py
CHANGED
@@ -1,6 +1,6 @@
 """Explanatory functions using metadata and additional features such as ood or drift"""
 
-__all__ = ["
+__all__ = ["find_ood_predictors", "metadata_distance", "find_most_deviated_factors"]
 
 from dataeval.metadata._distance import metadata_distance
-from dataeval.metadata._ood import
+from dataeval.metadata._ood import find_most_deviated_factors, find_ood_predictors
dataeval/metadata/_distance.py
CHANGED
@@ -10,7 +10,8 @@ from scipy.stats import iqr, ks_2samp
 from scipy.stats import wasserstein_distance as emd
 
 from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
-from dataeval.outputs
+from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
+from dataeval.outputs._base import set_metadata
 from dataeval.typing import ArrayLike
 from dataeval.utils.data import Metadata
 
@@ -23,41 +24,6 @@ class KSType(NamedTuple):
     pvalue: float
 
 
-class MetadataKSResult(NamedTuple):
-    """
-    Attributes
-    ----------
-    statistic : float
-        the KS statistic
-    location : float
-        The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
-        to the median of the reference distribution.
-    dist : float
-        The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
-    pvalue : float
-        The p-value from the KS two-sample test
-    """
-
-    statistic: float
-    location: float
-    dist: float
-    pvalue: float
-
-
-class KSOutput(MappingOutput[str, MetadataKSResult]):
-    """
-    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
-
-
-    Attributes
-    ----------
-    key: str
-        Metadata feature names
-    value: :class:`MetadataKSResult`
-        Output per feature name containing the statistic, statistic location, distance, and pvalue.
-    """
-
-
 def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
     """Calculates the shift magnitude between x1 and x2 scaled by x1"""
 
@@ -74,7 +40,8 @@ def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
     return distance if xmin == xmax else distance / (xmax - xmin)
 
 
-def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
+@set_metadata
+def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
     """
     Measures the feature-wise distance between two continuous metadata distributions and
     computes a p-value to evaluate its significance.
@@ -90,8 +57,8 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
 
     Returns
     -------
-
-    A
+    MetadataDistanceOutput
+        A mapping with keys corresponding to metadata feature names, and values that are KstestResult objects, as
         defined by scipy.stats.ks_2samp.
 
     See Also
@@ -110,7 +77,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
     >>> list(output)
     ['time', 'altitude']
    >>> output["time"]
-
+    MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
     """
 
     _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
@@ -134,7 +101,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
     )
 
     # Set default for statistic, location, and magnitude to zero and pvalue to one
-    results: dict[str,
+    results: dict[str, MetadataDistanceValues] = {}
 
     # Per factor
     for i, fname in enumerate(fnames):
@@ -147,7 +114,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
 
         # Default case
         if xmin == xmax:
-            results[fname] =
+            results[fname] = MetadataDistanceValues(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
             continue
 
         ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
@@ -157,11 +124,11 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
 
         drift = _calculate_drift(fdata1, fdata2)
 
-        results[fname] =
+        results[fname] = MetadataDistanceValues(
             statistic=ks_result.statistic,
             location=loc,
             dist=drift,
             pvalue=ks_result.pvalue,
         )
 
-    return
+    return MetadataDistanceOutput(results)
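A hedged sketch mirroring the doctest shown above; metadata1 and metadata2 stand in for dataeval.utils.data.Metadata objects built from two datasets and are not defined here.

from dataeval.metadata import metadata_distance

# metadata1 / metadata2: placeholder dataeval.utils.data.Metadata instances
output = metadata_distance(metadata1, metadata2)

print(list(output))    # e.g. ['time', 'altitude']
print(output["time"])  # MetadataDistanceValues(statistic=..., location=..., dist=..., pvalue=...)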
|