dataeval 0.82.0__py3-none-any.whl → 0.82.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/config.py +68 -11
- dataeval/detectors/drift/_mmd.py +9 -9
- dataeval/detectors/drift/_torch.py +7 -7
- dataeval/detectors/drift/_uncertainty.py +4 -4
- dataeval/detectors/linters/duplicates.py +3 -3
- dataeval/detectors/linters/outliers.py +3 -3
- dataeval/detectors/ood/ae.py +5 -4
- dataeval/detectors/ood/base.py +2 -2
- dataeval/detectors/ood/metadata_ood_mi.py +4 -6
- dataeval/detectors/ood/mixin.py +1 -1
- dataeval/detectors/ood/vae.py +2 -1
- dataeval/metadata/_distance.py +11 -44
- dataeval/metadata/_ood.py +9 -7
- dataeval/metrics/bias/_balance.py +7 -3
- dataeval/metrics/bias/_diversity.py +3 -0
- dataeval/metrics/bias/_parity.py +2 -0
- dataeval/metrics/stats/_base.py +3 -3
- dataeval/metrics/stats/_boxratiostats.py +1 -1
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/outputs/__init__.py +4 -0
- dataeval/outputs/_base.py +50 -21
- dataeval/outputs/_bias.py +1 -1
- dataeval/outputs/_linters.py +4 -2
- dataeval/outputs/_metadata.py +54 -0
- dataeval/outputs/_stats.py +12 -6
- dataeval/utils/data/_embeddings.py +8 -9
- dataeval/utils/data/_metadata.py +16 -7
- dataeval/utils/data/_selection.py +4 -8
- dataeval/utils/data/_split.py +3 -2
- dataeval/utils/data/selections/_classfilter.py +5 -3
- dataeval/utils/torch/_internal.py +5 -5
- dataeval/utils/torch/trainer.py +8 -8
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/METADATA +1 -1
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/RECORD +37 -36
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.0.dist-info → dataeval-0.82.1.dist-info}/WHEEL +0 -0
dataeval/__init__.py
CHANGED
dataeval/config.py
CHANGED
@@ -4,36 +4,61 @@ Global configuration settings for DataEval.
|
|
4
4
|
|
5
5
|
from __future__ import annotations
|
6
6
|
|
7
|
-
__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes"]
|
7
|
+
__all__ = ["get_device", "set_device", "get_max_processes", "set_max_processes", "DeviceLike"]
|
8
8
|
|
9
|
+
import sys
|
10
|
+
from typing import Union
|
11
|
+
|
12
|
+
if sys.version_info >= (3, 10):
|
13
|
+
from typing import TypeAlias
|
14
|
+
else:
|
15
|
+
from typing_extensions import TypeAlias
|
16
|
+
|
17
|
+
import numpy as np
|
9
18
|
import torch
|
10
|
-
from torch import device
|
11
19
|
|
12
|
-
_device: device | None = None
|
20
|
+
_device: torch.device | None = None
|
13
21
|
_processes: int | None = None
|
22
|
+
_seed: int | None = None
|
23
|
+
|
24
|
+
DeviceLike: TypeAlias = Union[int, str, tuple[str, int], torch.device]
|
25
|
+
"""
|
26
|
+
Type alias for types that are acceptable for specifying a torch.device.
|
27
|
+
|
28
|
+
See Also
|
29
|
+
--------
|
30
|
+
`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
|
31
|
+
"""
|
32
|
+
|
14
33
|
|
34
|
+
def _todevice(device: DeviceLike) -> torch.device:
|
35
|
+
return torch.device(*device) if isinstance(device, tuple) else torch.device(device)
|
15
36
|
|
16
|
-
|
37
|
+
|
38
|
+
def set_device(device: DeviceLike) -> None:
|
17
39
|
"""
|
18
40
|
Sets the default device to use when executing against a PyTorch backend.
|
19
41
|
|
20
42
|
Parameters
|
21
43
|
----------
|
22
|
-
device :
|
23
|
-
The default device to use. See
|
24
|
-
|
44
|
+
device : DeviceLike
|
45
|
+
The default device to use. See documentation for more information.
|
46
|
+
|
47
|
+
See Also
|
48
|
+
--------
|
49
|
+
`torch.device <https://pytorch.org/docs/stable/tensor_attributes.html#torch.device>`_
|
25
50
|
"""
|
26
51
|
global _device
|
27
|
-
_device =
|
52
|
+
_device = _todevice(device)
|
28
53
|
|
29
54
|
|
30
|
-
def get_device(override:
|
55
|
+
def get_device(override: DeviceLike | None = None) -> torch.device:
|
31
56
|
"""
|
32
57
|
Returns the PyTorch device to use.
|
33
58
|
|
34
59
|
Parameters
|
35
60
|
----------
|
36
|
-
override :
|
61
|
+
override : DeviceLike or None, default None
|
37
62
|
The user specified override if provided, otherwise returns the default device.
|
38
63
|
|
39
64
|
Returns
|
@@ -44,7 +69,7 @@ def get_device(override: str | device | int | None = None) -> torch.device:
|
|
44
69
|
global _device
|
45
70
|
return torch.get_default_device() if _device is None else _device
|
46
71
|
else:
|
47
|
-
return
|
72
|
+
return _todevice(override)
|
48
73
|
|
49
74
|
|
50
75
|
def set_max_processes(processes: int | None) -> None:
|
@@ -75,3 +100,35 @@ def get_max_processes() -> int | None:
|
|
75
100
|
"""
|
76
101
|
global _processes
|
77
102
|
return _processes
|
103
|
+
|
104
|
+
|
105
|
+
def set_seed(seed: int | None, all_generators: bool = False) -> None:
|
106
|
+
"""
|
107
|
+
Sets the seed for use by classes that allow for a random state or seed.
|
108
|
+
|
109
|
+
Parameters
|
110
|
+
----------
|
111
|
+
seed : int or None
|
112
|
+
The seed to use.
|
113
|
+
all_generators : bool, default False
|
114
|
+
Whether to set the seed for all generators, including NumPy and PyTorch.
|
115
|
+
"""
|
116
|
+
global _seed
|
117
|
+
_seed = seed
|
118
|
+
|
119
|
+
if all_generators:
|
120
|
+
np.random.seed(seed)
|
121
|
+
torch.manual_seed(seed)
|
122
|
+
|
123
|
+
|
124
|
+
def get_seed() -> int | None:
|
125
|
+
"""
|
126
|
+
Returns the seed for random state or seed.
|
127
|
+
|
128
|
+
Returns
|
129
|
+
-------
|
130
|
+
int or None
|
131
|
+
The seed to use.
|
132
|
+
"""
|
133
|
+
global _seed
|
134
|
+
return _seed
|
dataeval/detectors/drift/_mmd.py
CHANGED
@@ -14,7 +14,7 @@ from typing import Callable
|
|
14
14
|
|
15
15
|
import torch
|
16
16
|
|
17
|
-
from dataeval.config import get_device
|
17
|
+
from dataeval.config import DeviceLike, get_device
|
18
18
|
from dataeval.detectors.drift._base import BaseDrift, UpdateStrategy, preprocess_x, update_x_ref
|
19
19
|
from dataeval.detectors.drift._torch import GaussianRBF, mmd2_from_kernel_matrix
|
20
20
|
from dataeval.outputs import DriftMMDOutput
|
@@ -31,7 +31,7 @@ class DriftMMD(BaseDrift):
|
|
31
31
|
----------
|
32
32
|
x_ref : ArrayLike
|
33
33
|
Data used as reference distribution.
|
34
|
-
p_val : float
|
34
|
+
p_val : float or None, default 0.05
|
35
35
|
:term:`P-value` used for significance of the statistical test for each feature.
|
36
36
|
If the FDR correction method is used, this corresponds to the acceptable
|
37
37
|
q-value.
|
@@ -39,14 +39,14 @@ class DriftMMD(BaseDrift):
|
|
39
39
|
Whether the given reference data ``x_ref`` has been preprocessed yet.
|
40
40
|
If ``True``, only the test data ``x`` will be preprocessed at prediction time.
|
41
41
|
If ``False``, the reference data will also be preprocessed.
|
42
|
-
update_x_ref : UpdateStrategy
|
42
|
+
update_x_ref : UpdateStrategy or None, default None
|
43
43
|
Reference data can optionally be updated using an UpdateStrategy class. Update
|
44
44
|
using the last n instances seen by the detector with LastSeenUpdateStrategy
|
45
45
|
or via reservoir sampling with ReservoirSamplingUpdateStrategy.
|
46
|
-
preprocess_fn : Callable
|
46
|
+
preprocess_fn : Callable or None, default None
|
47
47
|
Function to preprocess the data before computing the data drift metrics.
|
48
48
|
Typically a :term:`dimensionality reduction<Dimensionality Reduction>` technique.
|
49
|
-
sigma : ArrayLike
|
49
|
+
sigma : ArrayLike or None, default None
|
50
50
|
Optionally set the internal GaussianRBF kernel bandwidth. Can also pass multiple
|
51
51
|
bandwidth values as an array. The kernel evaluation is then averaged over
|
52
52
|
those bandwidths.
|
@@ -54,9 +54,9 @@ class DriftMMD(BaseDrift):
|
|
54
54
|
Whether to already configure the kernel bandwidth from the reference data.
|
55
55
|
n_permutations : int, default 100
|
56
56
|
Number of permutations used in the permutation test.
|
57
|
-
device :
|
58
|
-
|
59
|
-
|
57
|
+
device : DeviceLike or None, default None
|
58
|
+
The hardware device to use if specified, otherwise uses the DataEval
|
59
|
+
default or torch default.
|
60
60
|
|
61
61
|
Example
|
62
62
|
-------
|
@@ -84,7 +84,7 @@ class DriftMMD(BaseDrift):
|
|
84
84
|
sigma: ArrayLike | None = None,
|
85
85
|
configure_kernel_from_x_ref: bool = True,
|
86
86
|
n_permutations: int = 100,
|
87
|
-
device:
|
87
|
+
device: DeviceLike | None = None,
|
88
88
|
) -> None:
|
89
89
|
super().__init__(x_ref, p_val, x_ref_preprocessed, update_x_ref, preprocess_fn)
|
90
90
|
|
@@ -17,7 +17,7 @@ import torch
|
|
17
17
|
import torch.nn as nn
|
18
18
|
from numpy.typing import NDArray
|
19
19
|
|
20
|
-
from dataeval.config import get_device
|
20
|
+
from dataeval.config import DeviceLike, get_device
|
21
21
|
from dataeval.utils.torch._internal import predict_batch
|
22
22
|
|
23
23
|
|
@@ -59,7 +59,7 @@ def mmd2_from_kernel_matrix(
|
|
59
59
|
def preprocess_drift(
|
60
60
|
x: NDArray[Any],
|
61
61
|
model: nn.Module,
|
62
|
-
device:
|
62
|
+
device: DeviceLike | None = None,
|
63
63
|
preprocess_batch_fn: Callable | None = None,
|
64
64
|
batch_size: int = int(1e10),
|
65
65
|
dtype: type[np.generic] | torch.dtype = np.float32,
|
@@ -73,15 +73,15 @@ def preprocess_drift(
|
|
73
73
|
Batch of instances.
|
74
74
|
model : nn.Module
|
75
75
|
Model used for preprocessing.
|
76
|
-
device :
|
77
|
-
|
78
|
-
|
79
|
-
preprocess_batch_fn : Callable
|
76
|
+
device : DeviceLike or None, default None
|
77
|
+
The hardware device to use if specified, otherwise uses the DataEval
|
78
|
+
default or torch default.
|
79
|
+
preprocess_batch_fn : Callable or None, default None
|
80
80
|
Optional batch preprocessing function. For example to convert a list of objects
|
81
81
|
to a batch which can be processed by the PyTorch model.
|
82
82
|
batch_size : int, default 1e10
|
83
83
|
Batch size used during prediction.
|
84
|
-
dtype : np.dtype
|
84
|
+
dtype : np.dtype or torch.dtype, default np.float32
|
85
85
|
Model output type, either a :term:`NumPy` or torch dtype, e.g. np.float32 or torch.float32.
|
86
86
|
|
87
87
|
Returns
|
@@ -85,20 +85,20 @@ class DriftUncertainty:
|
|
85
85
|
Whether the given reference data ``x_ref`` has been preprocessed yet.
|
86
86
|
If ``True``, only the test data ``x`` will be preprocessed at prediction time.
|
87
87
|
If ``False``, the reference data will also be preprocessed.
|
88
|
-
update_x_ref : UpdateStrategy
|
88
|
+
update_x_ref : UpdateStrategy or None, default None
|
89
89
|
Reference data can optionally be updated using an UpdateStrategy class. Update
|
90
90
|
using the last n instances seen by the detector with LastSeenUpdateStrategy
|
91
91
|
or via reservoir sampling with ReservoirSamplingUpdateStrategy.
|
92
|
-
preds_type : "probs"
|
92
|
+
preds_type : "probs" or "logits", default "probs"
|
93
93
|
Type of prediction output by the model. Options are 'probs' (in [0,1]) or
|
94
94
|
'logits' (in [-inf,inf]).
|
95
95
|
batch_size : int, default 32
|
96
96
|
Batch size used to evaluate model. Only relevant when backend has been
|
97
97
|
specified for batch prediction.
|
98
|
-
preprocess_batch_fn : Callable
|
98
|
+
preprocess_batch_fn : Callable or None, default None
|
99
99
|
Optional batch preprocessing function. For example to convert a list of
|
100
100
|
objects to a batch which can be processed by the model.
|
101
|
-
device :
|
101
|
+
device : DeviceLike or None, default None
|
102
102
|
Device type used. The default None tries to use the GPU and falls back on
|
103
103
|
CPU if needed. Can be specified by passing either 'cuda' or 'cpu'.
|
104
104
|
|
@@ -88,13 +88,13 @@ class Duplicates:
|
|
88
88
|
"""
|
89
89
|
|
90
90
|
if isinstance(hashes, HashStatsOutput):
|
91
|
-
return DuplicatesOutput(**self._get_duplicates(hashes.
|
91
|
+
return DuplicatesOutput(**self._get_duplicates(hashes.data()))
|
92
92
|
|
93
93
|
if not isinstance(hashes, Sequence):
|
94
94
|
raise TypeError("Invalid stats output type; only use output from hashstats.")
|
95
95
|
|
96
96
|
combined, dataset_steps = combine_stats(hashes)
|
97
|
-
duplicates = self._get_duplicates(combined.
|
97
|
+
duplicates = self._get_duplicates(combined.data())
|
98
98
|
|
99
99
|
# split up results from combined dataset into individual dataset buckets
|
100
100
|
for dup_type, dup_list in duplicates.items():
|
@@ -136,5 +136,5 @@ class Duplicates:
|
|
136
136
|
""" # noqa: E501
|
137
137
|
images = Images(data) if isinstance(data, Dataset) else data
|
138
138
|
self.stats = hashstats(images)
|
139
|
-
duplicates = self._get_duplicates(self.stats.
|
139
|
+
duplicates = self._get_duplicates(self.stats.data())
|
140
140
|
return DuplicatesOutput(**duplicates)
|
@@ -169,7 +169,7 @@ class Outliers:
|
|
169
169
|
{}
|
170
170
|
""" # noqa: E501
|
171
171
|
if isinstance(stats, (ImageStatsOutput, DimensionStatsOutput, PixelStatsOutput, VisualStatsOutput)):
|
172
|
-
return OutliersOutput(self._get_outliers(stats.
|
172
|
+
return OutliersOutput(self._get_outliers(stats.data()))
|
173
173
|
|
174
174
|
if not isinstance(stats, Sequence):
|
175
175
|
raise TypeError(
|
@@ -189,7 +189,7 @@ class Outliers:
|
|
189
189
|
output_list: list[dict[int, dict[str, float]]] = [{} for _ in stats]
|
190
190
|
for _, indices in stats_map.items():
|
191
191
|
substats, dataset_steps = combine_stats([stats[i] for i in indices])
|
192
|
-
outliers = self._get_outliers(substats.
|
192
|
+
outliers = self._get_outliers(substats.data())
|
193
193
|
for idx, issue in outliers.items():
|
194
194
|
k, v = get_dataset_step_from_idx(idx, dataset_steps)
|
195
195
|
output_list[indices[k]][v] = issue
|
@@ -225,5 +225,5 @@ class Outliers:
|
|
225
225
|
"""
|
226
226
|
images = Images(data) if isinstance(data, Dataset) else data
|
227
227
|
self.stats = imagestats(images)
|
228
|
-
outliers = self._get_outliers(self.stats.
|
228
|
+
outliers = self._get_outliers(self.stats.data())
|
229
229
|
return OutliersOutput(outliers)
|
dataeval/detectors/ood/ae.py
CHANGED
@@ -18,6 +18,7 @@ import numpy as np
|
|
18
18
|
import torch
|
19
19
|
from numpy.typing import NDArray
|
20
20
|
|
21
|
+
from dataeval.config import DeviceLike
|
21
22
|
from dataeval.detectors.ood.base import OODBase
|
22
23
|
from dataeval.outputs import OODScoreOutput
|
23
24
|
from dataeval.typing import ArrayLike
|
@@ -33,9 +34,9 @@ class OOD_AE(OODBase):
|
|
33
34
|
model : torch.nn.Module
|
34
35
|
An autoencoder model to use for encoding and reconstruction of images
|
35
36
|
for detection of out-of-distribution samples.
|
36
|
-
device :
|
37
|
-
The device to use
|
38
|
-
|
37
|
+
device : DeviceLike or None, default None
|
38
|
+
The hardware device to use if specified, otherwise uses the DataEval
|
39
|
+
default or torch default.
|
39
40
|
|
40
41
|
Example
|
41
42
|
-------
|
@@ -57,7 +58,7 @@ class OOD_AE(OODBase):
|
|
57
58
|
array([ True, True, False, True, True, True, True, True])
|
58
59
|
"""
|
59
60
|
|
60
|
-
def __init__(self, model: torch.nn.Module, device:
|
61
|
+
def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
|
61
62
|
super().__init__(model, device)
|
62
63
|
|
63
64
|
def fit(
|
dataeval/detectors/ood/base.py
CHANGED
@@ -14,7 +14,7 @@ from typing import Callable, cast
|
|
14
14
|
|
15
15
|
import torch
|
16
16
|
|
17
|
-
from dataeval.config import get_device
|
17
|
+
from dataeval.config import DeviceLike, get_device
|
18
18
|
from dataeval.detectors.ood.mixin import OODBaseMixin, OODFitMixin, OODGMMMixin
|
19
19
|
from dataeval.typing import ArrayLike
|
20
20
|
from dataeval.utils._array import to_numpy
|
@@ -23,7 +23,7 @@ from dataeval.utils.torch._internal import trainer
|
|
23
23
|
|
24
24
|
|
25
25
|
class OODBase(OODBaseMixin[torch.nn.Module], OODFitMixin[Callable[..., torch.nn.Module], torch.optim.Optimizer]):
|
26
|
-
def __init__(self, model: torch.nn.Module, device:
|
26
|
+
def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
|
27
27
|
self.device: torch.device = get_device(device)
|
28
28
|
super().__init__(model)
|
29
29
|
|
@@ -10,6 +10,8 @@ import numpy as np
|
|
10
10
|
from numpy.typing import NDArray
|
11
11
|
from sklearn.feature_selection import mutual_info_classif
|
12
12
|
|
13
|
+
from dataeval.config import get_seed
|
14
|
+
|
13
15
|
# NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
|
14
16
|
# which is what many library functions return, multiply it by NATS2BITS to get it in bits.
|
15
17
|
NATS2BITS = 1.442695
|
@@ -19,7 +21,6 @@ def get_metadata_ood_mi(
|
|
19
21
|
metadata: dict[str, list[Any] | NDArray[Any]],
|
20
22
|
is_ood: NDArray[np.bool_],
|
21
23
|
discrete_features: str | bool | NDArray[np.bool_] = False,
|
22
|
-
random_state: int | None = None,
|
23
24
|
) -> dict[str, float]:
|
24
25
|
"""Computes mutual information between a set of metadata features and an out-of-distribution flag.
|
25
26
|
|
@@ -39,9 +40,6 @@ def get_metadata_ood_mi(
|
|
39
40
|
A boolean array, with one value per example, that indicates which examples are OOD.
|
40
41
|
discrete_features : str | bool | NDArray[np.bool_]
|
41
42
|
Either a boolean array or a single boolean value, indicate which features take on discrete values.
|
42
|
-
random_state : int, optional - default None
|
43
|
-
Determines random number generation for small noise added to continuous variables. Set to a value for
|
44
|
-
reproducible results.
|
45
43
|
|
46
44
|
Returns
|
47
45
|
-------
|
@@ -55,7 +53,7 @@ def get_metadata_ood_mi(
|
|
55
53
|
|
56
54
|
>>> metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
|
57
55
|
>>> is_ood = metadata["altitude"] > 100
|
58
|
-
>>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False
|
56
|
+
>>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False)
|
59
57
|
{'time': 0.9359596758173668, 'altitude': 0.9407686591507002}
|
60
58
|
"""
|
61
59
|
numerical_keys = [k for k, v in metadata.items() if all(isinstance(vi, numbers.Number) for vi in v)]
|
@@ -84,7 +82,7 @@ def get_metadata_ood_mi(
|
|
84
82
|
Xscl,
|
85
83
|
is_ood,
|
86
84
|
discrete_features=discrete_features, # type: ignore
|
87
|
-
random_state=
|
85
|
+
random_state=get_seed(),
|
88
86
|
)
|
89
87
|
* NATS2BITS
|
90
88
|
)
|
dataeval/detectors/ood/mixin.py
CHANGED
@@ -157,4 +157,4 @@ class OODBaseMixin(Generic[TModel], ABC):
|
|
157
157
|
# compute outlier scores
|
158
158
|
score = self.score(X, batch_size=batch_size)
|
159
159
|
ood_pred = score.get(ood_type) > self._threshold_score(ood_type)
|
160
|
-
return OODOutput(is_ood=ood_pred, **score.
|
160
|
+
return OODOutput(is_ood=ood_pred, **score.data())
|
dataeval/detectors/ood/vae.py
CHANGED
@@ -17,6 +17,7 @@ from typing import Callable
|
|
17
17
|
import numpy as np
|
18
18
|
import torch
|
19
19
|
|
20
|
+
from dataeval.config import DeviceLike
|
20
21
|
from dataeval.detectors.ood.base import OODBase
|
21
22
|
from dataeval.outputs import OODScoreOutput
|
22
23
|
from dataeval.typing import ArrayLike
|
@@ -34,7 +35,7 @@ class OOD_VAE(OODBase):
|
|
34
35
|
An Autoencoder model.
|
35
36
|
"""
|
36
37
|
|
37
|
-
def __init__(self, model: torch.nn.Module, device:
|
38
|
+
def __init__(self, model: torch.nn.Module, device: DeviceLike | None = None) -> None:
|
38
39
|
super().__init__(model, device)
|
39
40
|
|
40
41
|
def fit(
|
dataeval/metadata/_distance.py
CHANGED
@@ -10,7 +10,8 @@ from scipy.stats import iqr, ks_2samp
|
|
10
10
|
from scipy.stats import wasserstein_distance as emd
|
11
11
|
|
12
12
|
from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
|
13
|
-
from dataeval.outputs
|
13
|
+
from dataeval.outputs import MetadataDistanceOutput, MetadataDistanceValues
|
14
|
+
from dataeval.outputs._base import set_metadata
|
14
15
|
from dataeval.typing import ArrayLike
|
15
16
|
from dataeval.utils.data import Metadata
|
16
17
|
|
@@ -23,41 +24,6 @@ class KSType(NamedTuple):
|
|
23
24
|
pvalue: float
|
24
25
|
|
25
26
|
|
26
|
-
class MetadataKSResult(NamedTuple):
|
27
|
-
"""
|
28
|
-
Attributes
|
29
|
-
----------
|
30
|
-
statistic : float
|
31
|
-
the KS statistic
|
32
|
-
location : float
|
33
|
-
The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
|
34
|
-
to the median of the reference distribution.
|
35
|
-
dist : float
|
36
|
-
The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
|
37
|
-
pvalue : float
|
38
|
-
The p-value from the KS two-sample test
|
39
|
-
"""
|
40
|
-
|
41
|
-
statistic: float
|
42
|
-
location: float
|
43
|
-
dist: float
|
44
|
-
pvalue: float
|
45
|
-
|
46
|
-
|
47
|
-
class KSOutput(MappingOutput[str, MetadataKSResult]):
|
48
|
-
"""
|
49
|
-
Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
|
50
|
-
|
51
|
-
|
52
|
-
Attributes
|
53
|
-
----------
|
54
|
-
key: str
|
55
|
-
Metadata feature names
|
56
|
-
value: :class:`MetadataKSResult`
|
57
|
-
Output per feature name containing the statistic, statistic location, distance, and pvalue.
|
58
|
-
"""
|
59
|
-
|
60
|
-
|
61
27
|
def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
|
62
28
|
"""Calculates the shift magnitude between x1 and x2 scaled by x1"""
|
63
29
|
|
@@ -74,7 +40,8 @@ def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
|
|
74
40
|
return distance if xmin == xmax else distance / (xmax - xmin)
|
75
41
|
|
76
42
|
|
77
|
-
|
43
|
+
@set_metadata
|
44
|
+
def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDistanceOutput:
|
78
45
|
"""
|
79
46
|
Measures the feature-wise distance between two continuous metadata distributions and
|
80
47
|
computes a p-value to evaluate its significance.
|
@@ -90,8 +57,8 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
|
|
90
57
|
|
91
58
|
Returns
|
92
59
|
-------
|
93
|
-
|
94
|
-
A
|
60
|
+
MetadataDistanceOutput
|
61
|
+
A mapping with keys corresponding to metadata feature names, and values that are KstestResult objects, as
|
95
62
|
defined by scipy.stats.ks_2samp.
|
96
63
|
|
97
64
|
See Also
|
@@ -110,7 +77,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
|
|
110
77
|
>>> list(output)
|
111
78
|
['time', 'altitude']
|
112
79
|
>>> output["time"]
|
113
|
-
|
80
|
+
MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
|
114
81
|
"""
|
115
82
|
|
116
83
|
_compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
|
@@ -134,7 +101,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
|
|
134
101
|
)
|
135
102
|
|
136
103
|
# Set default for statistic, location, and magnitude to zero and pvalue to one
|
137
|
-
results: dict[str,
|
104
|
+
results: dict[str, MetadataDistanceValues] = {}
|
138
105
|
|
139
106
|
# Per factor
|
140
107
|
for i, fname in enumerate(fnames):
|
@@ -147,7 +114,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
|
|
147
114
|
|
148
115
|
# Default case
|
149
116
|
if xmin == xmax:
|
150
|
-
results[fname] =
|
117
|
+
results[fname] = MetadataDistanceValues(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
|
151
118
|
continue
|
152
119
|
|
153
120
|
ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
|
@@ -157,11 +124,11 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
|
|
157
124
|
|
158
125
|
drift = _calculate_drift(fdata1, fdata2)
|
159
126
|
|
160
|
-
results[fname] =
|
127
|
+
results[fname] = MetadataDistanceValues(
|
161
128
|
statistic=ks_result.statistic,
|
162
129
|
location=loc,
|
163
130
|
dist=drift,
|
164
131
|
pvalue=ks_result.pvalue,
|
165
132
|
)
|
166
133
|
|
167
|
-
return
|
134
|
+
return MetadataDistanceOutput(results)
|
dataeval/metadata/_ood.py
CHANGED
@@ -8,7 +8,8 @@ import numpy as np
|
|
8
8
|
from numpy.typing import NDArray
|
9
9
|
|
10
10
|
from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
|
11
|
-
from dataeval.outputs import OODOutput
|
11
|
+
from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput
|
12
|
+
from dataeval.outputs._base import set_metadata
|
12
13
|
from dataeval.utils.data import Metadata
|
13
14
|
|
14
15
|
|
@@ -119,11 +120,12 @@ def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
|
|
119
120
|
return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale)) # (S_t, F)
|
120
121
|
|
121
122
|
|
123
|
+
@set_metadata
|
122
124
|
def most_deviated_factors(
|
123
125
|
metadata_1: Metadata,
|
124
126
|
metadata_2: Metadata,
|
125
127
|
ood: OODOutput,
|
126
|
-
) ->
|
128
|
+
) -> MostDeviatedFactorsOutput:
|
127
129
|
"""
|
128
130
|
Determines greatest deviation in metadata features per out of distribution sample in metadata_2.
|
129
131
|
|
@@ -159,20 +161,20 @@ def most_deviated_factors(
|
|
159
161
|
|
160
162
|
>>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
|
161
163
|
>>> most_deviated_factors(metadata1, metadata2, is_ood)
|
162
|
-
[('time', 2.0), ('time', 2.592), ('time', 3.51)]
|
164
|
+
MostDeviatedFactorsOutput([('time', 2.0), ('time', 2.592), ('time', 3.51)])
|
163
165
|
|
164
166
|
If there are no out-of-distribution samples, a list is returned
|
165
167
|
|
166
168
|
>>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
|
167
169
|
>>> most_deviated_factors(metadata1, metadata2, is_ood)
|
168
|
-
[]
|
170
|
+
MostDeviatedFactorsOutput([])
|
169
171
|
"""
|
170
172
|
|
171
173
|
ood_mask: NDArray[np.bool] = ood.is_ood
|
172
174
|
|
173
175
|
# No metadata correlated with out of distribution data
|
174
176
|
if not any(ood_mask):
|
175
|
-
return []
|
177
|
+
return MostDeviatedFactorsOutput([])
|
176
178
|
|
177
179
|
# Combines reference and test factor names and data if exists and match exactly
|
178
180
|
# shape -> (samples, factors)
|
@@ -190,7 +192,7 @@ def most_deviated_factors(
|
|
190
192
|
f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
|
191
193
|
UserWarning,
|
192
194
|
)
|
193
|
-
return []
|
195
|
+
return MostDeviatedFactorsOutput([])
|
194
196
|
|
195
197
|
if len(metadata_tst) != len(ood_mask):
|
196
198
|
raise ValueError(
|
@@ -214,4 +216,4 @@ def most_deviated_factors(
|
|
214
216
|
|
215
217
|
# List of tuples matching the factor name with its deviation
|
216
218
|
|
217
|
-
return [(factor, dev) for factor, dev in zip(most_ood_factors, deviation)]
|
219
|
+
return MostDeviatedFactorsOutput([(factor, dev) for factor, dev in zip(most_ood_factors, deviation)])
|
@@ -8,6 +8,7 @@ import numpy as np
|
|
8
8
|
import scipy as sp
|
9
9
|
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
|
10
10
|
|
11
|
+
from dataeval.config import get_seed
|
11
12
|
from dataeval.outputs import BalanceOutput
|
12
13
|
from dataeval.outputs._base import set_metadata
|
13
14
|
from dataeval.utils._bin import get_counts
|
@@ -91,6 +92,9 @@ def balance(
|
|
91
92
|
sklearn.feature_selection.mutual_info_regression
|
92
93
|
sklearn.metrics.mutual_info_score
|
93
94
|
"""
|
95
|
+
if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
|
96
|
+
raise ValueError("No factors found in provided metadata.")
|
97
|
+
|
94
98
|
num_neighbors = _validate_num_neighbors(num_neighbors)
|
95
99
|
|
96
100
|
num_factors = metadata.total_num_factors
|
@@ -110,7 +114,7 @@ def balance(
|
|
110
114
|
data[:, idx],
|
111
115
|
discrete_features=is_discrete, # type: ignore
|
112
116
|
n_neighbors=num_neighbors,
|
113
|
-
random_state=
|
117
|
+
random_state=get_seed(),
|
114
118
|
)
|
115
119
|
else:
|
116
120
|
mi[idx, :] = mutual_info_classif(
|
@@ -118,7 +122,7 @@ def balance(
|
|
118
122
|
data[:, idx],
|
119
123
|
discrete_features=is_discrete, # type: ignore
|
120
124
|
n_neighbors=num_neighbors,
|
121
|
-
random_state=
|
125
|
+
random_state=get_seed(),
|
122
126
|
)
|
123
127
|
|
124
128
|
# Normalization via entropy
|
@@ -147,7 +151,7 @@ def balance(
|
|
147
151
|
tgt_bin[:, idx],
|
148
152
|
discrete_features=is_discrete, # type: ignore
|
149
153
|
n_neighbors=num_neighbors,
|
150
|
-
random_state=
|
154
|
+
random_state=get_seed(),
|
151
155
|
)
|
152
156
|
|
153
157
|
# Classwise normalization via entropy
|
@@ -158,6 +158,9 @@ def diversity(
|
|
158
158
|
--------
|
159
159
|
scipy.stats.entropy
|
160
160
|
"""
|
161
|
+
if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
|
162
|
+
raise ValueError("No factors found in provided metadata.")
|
163
|
+
|
161
164
|
diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
|
162
165
|
discretized_data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
|
163
166
|
cnts = get_counts(discretized_data)
|
dataeval/metrics/bias/_parity.py
CHANGED
@@ -241,6 +241,8 @@ def parity(metadata: Metadata) -> ParityOutput:
|
|
241
241
|
>>> parity(metadata)
|
242
242
|
ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
|
243
243
|
""" # noqa: E501
|
244
|
+
if not metadata.discrete_factor_names and not metadata.continuous_factor_names:
|
245
|
+
raise ValueError("No factors found in provided metadata.")
|
244
246
|
|
245
247
|
chi_scores = np.zeros(metadata.discrete_data.shape[1])
|
246
248
|
p_values = np.zeros_like(chi_scores)
|