dataeval 0.76.1__py3-none-any.whl → 0.82.0__py3-none-any.whl
- dataeval/__init__.py +3 -3
- dataeval/config.py +77 -0
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +6 -6
- dataeval/detectors/drift/{base.py → _base.py} +40 -85
- dataeval/detectors/drift/{cvm.py → _cvm.py} +21 -28
- dataeval/detectors/drift/{ks.py → _ks.py} +20 -26
- dataeval/detectors/drift/{mmd.py → _mmd.py} +31 -43
- dataeval/detectors/drift/{torch.py → _torch.py} +2 -1
- dataeval/detectors/drift/{uncertainty.py → _uncertainty.py} +24 -7
- dataeval/detectors/drift/updates.py +20 -3
- dataeval/detectors/linters/__init__.py +3 -5
- dataeval/detectors/linters/duplicates.py +13 -36
- dataeval/detectors/linters/outliers.py +23 -148
- dataeval/detectors/ood/__init__.py +1 -1
- dataeval/detectors/ood/ae.py +30 -9
- dataeval/detectors/ood/base.py +5 -4
- dataeval/detectors/ood/mixin.py +21 -7
- dataeval/detectors/ood/vae.py +73 -0
- dataeval/metadata/__init__.py +6 -0
- dataeval/metadata/_distance.py +167 -0
- dataeval/metadata/_ood.py +217 -0
- dataeval/metadata/_utils.py +44 -0
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +6 -4
- dataeval/metrics/bias/{balance.py → _balance.py} +15 -101
- dataeval/metrics/bias/_coverage.py +98 -0
- dataeval/metrics/bias/{diversity.py → _diversity.py} +18 -111
- dataeval/metrics/bias/{parity.py → _parity.py} +39 -77
- dataeval/metrics/estimators/__init__.py +15 -4
- dataeval/metrics/estimators/{ber.py → _ber.py} +42 -29
- dataeval/metrics/estimators/_clusterer.py +44 -0
- dataeval/metrics/estimators/{divergence.py → _divergence.py} +18 -30
- dataeval/metrics/estimators/{uap.py → _uap.py} +4 -18
- dataeval/metrics/stats/__init__.py +16 -13
- dataeval/metrics/stats/{base.py → _base.py} +82 -133
- dataeval/metrics/stats/{boxratiostats.py → _boxratiostats.py} +15 -18
- dataeval/metrics/stats/_dimensionstats.py +75 -0
- dataeval/metrics/stats/{hashstats.py → _hashstats.py} +21 -37
- dataeval/metrics/stats/_imagestats.py +94 -0
- dataeval/metrics/stats/_labelstats.py +131 -0
- dataeval/metrics/stats/{pixelstats.py → _pixelstats.py} +19 -50
- dataeval/metrics/stats/{visualstats.py → _visualstats.py} +23 -54
- dataeval/outputs/__init__.py +53 -0
- dataeval/{output.py → outputs/_base.py} +55 -25
- dataeval/outputs/_bias.py +381 -0
- dataeval/outputs/_drift.py +83 -0
- dataeval/outputs/_estimators.py +114 -0
- dataeval/outputs/_linters.py +184 -0
- dataeval/{detectors/ood/output.py → outputs/_ood.py} +22 -22
- dataeval/outputs/_stats.py +387 -0
- dataeval/outputs/_utils.py +44 -0
- dataeval/outputs/_workflows.py +364 -0
- dataeval/typing.py +234 -0
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/_array.py +169 -0
- dataeval/utils/_bin.py +199 -0
- dataeval/utils/_clusterer.py +144 -0
- dataeval/utils/_fast_mst.py +189 -0
- dataeval/utils/{image.py → _image.py} +6 -4
- dataeval/utils/_method.py +14 -0
- dataeval/utils/{shared.py → _mst.py} +3 -65
- dataeval/utils/{plot.py → _plot.py} +6 -6
- dataeval/utils/data/__init__.py +26 -0
- dataeval/utils/data/_dataset.py +217 -0
- dataeval/utils/data/_embeddings.py +104 -0
- dataeval/utils/data/_images.py +68 -0
- dataeval/utils/data/_metadata.py +360 -0
- dataeval/utils/data/_selection.py +126 -0
- dataeval/utils/{dataset/split.py → data/_split.py} +12 -38
- dataeval/utils/data/_targets.py +85 -0
- dataeval/utils/data/collate.py +103 -0
- dataeval/utils/data/datasets/__init__.py +17 -0
- dataeval/utils/data/datasets/_base.py +254 -0
- dataeval/utils/data/datasets/_cifar10.py +134 -0
- dataeval/utils/data/datasets/_fileio.py +168 -0
- dataeval/utils/data/datasets/_milco.py +153 -0
- dataeval/utils/data/datasets/_mixin.py +56 -0
- dataeval/utils/data/datasets/_mnist.py +183 -0
- dataeval/utils/data/datasets/_ships.py +123 -0
- dataeval/utils/data/datasets/_types.py +52 -0
- dataeval/utils/data/datasets/_voc.py +352 -0
- dataeval/utils/data/selections/__init__.py +15 -0
- dataeval/utils/data/selections/_classfilter.py +57 -0
- dataeval/utils/data/selections/_indices.py +26 -0
- dataeval/utils/data/selections/_limit.py +26 -0
- dataeval/utils/data/selections/_reverse.py +18 -0
- dataeval/utils/data/selections/_shuffle.py +29 -0
- dataeval/utils/metadata.py +51 -376
- dataeval/utils/torch/{gmm.py → _gmm.py} +4 -2
- dataeval/utils/torch/{internal.py → _internal.py} +21 -51
- dataeval/utils/torch/models.py +43 -2
- dataeval/workflows/__init__.py +2 -1
- dataeval/workflows/sufficiency.py +11 -346
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/METADATA +5 -2
- dataeval-0.82.0.dist-info/RECORD +104 -0
- dataeval/detectors/linters/clusterer.py +0 -512
- dataeval/detectors/linters/merged_stats.py +0 -49
- dataeval/detectors/ood/metadata_ks_compare.py +0 -129
- dataeval/detectors/ood/metadata_least_likely.py +0 -119
- dataeval/interop.py +0 -69
- dataeval/metrics/bias/coverage.py +0 -194
- dataeval/metrics/stats/datasetstats.py +0 -202
- dataeval/metrics/stats/dimensionstats.py +0 -115
- dataeval/metrics/stats/labelstats.py +0 -210
- dataeval/utils/dataset/__init__.py +0 -7
- dataeval/utils/dataset/datasets.py +0 -412
- dataeval/utils/dataset/read.py +0 -63
- dataeval-0.76.1.dist-info/RECORD +0 -67
- /dataeval/{log.py → _log.py} +0 -0
- /dataeval/utils/torch/{blocks.py → _blocks.py} +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.76.1.dist-info → dataeval-0.82.0.dist-info}/WHEEL +0 -0
dataeval/metadata/_distance.py
ADDED
@@ -0,0 +1,167 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+from typing import NamedTuple, cast
+
+import numpy as np
+from scipy.stats import iqr, ks_2samp
+from scipy.stats import wasserstein_distance as emd
+
+from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+from dataeval.outputs._base import MappingOutput
+from dataeval.typing import ArrayLike
+from dataeval.utils.data import Metadata
+
+
+class KSType(NamedTuple):
+    """Used to typehint scipy's internal hidden ks_2samp output"""
+
+    statistic: float
+    statistic_location: float
+    pvalue: float
+
+
+class MetadataKSResult(NamedTuple):
+    """
+    Attributes
+    ----------
+    statistic : float
+        The KS statistic
+    location : float
+        The value at which the KS statistic has its maximum, measured in IQR-normalized units relative
+        to the median of the reference distribution.
+    dist : float
+        The Earth Mover's Distance normalized by the interquartile range (IQR) of the reference
+    pvalue : float
+        The p-value from the KS two-sample test
+    """
+
+    statistic: float
+    location: float
+    dist: float
+    pvalue: float
+
+
+class KSOutput(MappingOutput[str, MetadataKSResult]):
+    """
+    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+
+
+    Attributes
+    ----------
+    key: str
+        Metadata feature names
+    value: :class:`MetadataKSResult`
+        Output per feature name containing the statistic, statistic location, distance, and pvalue.
+    """
+
+
+def _calculate_drift(x1: ArrayLike, x2: ArrayLike) -> float:
+    """Calculates the shift magnitude between x1 and x2 scaled by x1"""
+
+    distance = emd(x1, x2)
+
+    X = iqr(x1)
+
+    # Preferred scaling of x1
+    if X:
+        return distance / X
+
+    # Return if single-valued, else scale
+    xmin, xmax = np.min(x1), np.max(x1)
+    return distance if xmin == xmax else distance / (xmax - xmin)
+
+
+def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> KSOutput:
+    """
+    Measures the feature-wise distance between two continuous metadata distributions and
+    computes a p-value to evaluate its significance.
+
+    Uses the Earth Mover's Distance and the Kolmogorov-Smirnov two-sample test, featurewise.
+
+    Parameters
+    ----------
+    metadata1 : Metadata
+        Class containing continuous factor names and values to be used as reference
+    metadata2 : Metadata
+        Class containing continuous factor names and values to be compared with the reference
+
+    Returns
+    -------
+    KSOutput
+        A mapping with keys corresponding to metadata feature names, and values that are
+        :class:`MetadataKSResult` objects with the statistic, location, distance, and p-value.
+
+    See Also
+    --------
+    Earth mover's distance
+
+    Kolmogorov-Smirnov two-sample test
+
+    Note
+    ----
+    This function applies only to continuous data
+
+    Examples
+    --------
+    >>> output = metadata_distance(metadata1, metadata2)
+    >>> list(output)
+    ['time', 'altitude']
+    >>> output["time"]
+    MetadataKSResult(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
+    """
+
+    _compare_keys(metadata1.continuous_factor_names, metadata2.continuous_factor_names)
+    fnames = metadata1.continuous_factor_names
+
+    cont1 = np.atleast_2d(metadata1.continuous_data)  # (S, F)
+    cont2 = np.atleast_2d(metadata2.continuous_data)  # (S, F)
+
+    _validate_factors_and_data(fnames, cont1)
+    _validate_factors_and_data(fnames, cont2)
+
+    N = len(cont1)
+    M = len(cont2)
+
+    # This is a simplified version of sqrt(N*M / (N+M)) < 4
+    if (N - 16) * (M - 16) < 256:
+        warnings.warn(
+            f"Sample sizes of {N}, {M} will yield unreliable p-values from the KS test. "
+            f"Recommended 32 samples per factor or at least 16 if one set has many more.",
+            UserWarning,
+        )
+
+    # Set default for statistic, location, and magnitude to zero and pvalue to one
+    results: dict[str, MetadataKSResult] = {}
+
+    # Per factor
+    for i, fname in enumerate(fnames):
+        fdata1 = cont1[:, i]  # (S, 1)
+        fdata2 = cont2[:, i]  # (S, 1)
+
+        # Min and max over both distributions
+        xmin = min(np.min(fdata1), np.min(fdata2))
+        xmax = max(np.max(fdata1), np.max(fdata2))
+
+        # Default case
+        if xmin == xmax:
+            results[fname] = MetadataKSResult(statistic=0.0, location=0.0, dist=0.0, pvalue=1.0)
+            continue
+
+        ks_result = cast(KSType, ks_2samp(fdata1, fdata2, method="asymp"))
+
+        # Normalized location
+        loc = float((ks_result.statistic_location - xmin) / (xmax - xmin))
+
+        drift = _calculate_drift(fdata1, fdata2)
+
+        results[fname] = MetadataKSResult(
+            statistic=ks_result.statistic,
+            location=loc,
+            dist=drift,
+            pvalue=ks_result.pvalue,
+        )

+    return KSOutput(results)
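Each feature therefore gets a KS two-sample test plus an IQR-normalized Earth Mover's Distance. The per-feature arithmetic can be reproduced standalone with scipy; the following is a minimal sketch with synthetic arrays (`ref` and `test` are illustrative names, not part of the dataeval API):

```python
import numpy as np
from scipy.stats import iqr, ks_2samp, wasserstein_distance

rng = np.random.default_rng(0)
ref = rng.normal(0.0, 1.0, 64)   # reference samples for one factor
test = rng.normal(0.5, 1.0, 64)  # shifted test samples

# KS two-sample test with the asymptotic method, as in metadata_distance.
# statistic_location is available in recent scipy releases (dataeval relies on it).
ks = ks_2samp(ref, test, method="asymp")

# Earth Mover's Distance scaled by the reference IQR, as in _calculate_drift
dist = wasserstein_distance(ref, test) / iqr(ref)

# Statistic location normalized to [0, 1] over the pooled range
xmin, xmax = min(ref.min(), test.min()), max(ref.max(), test.max())
loc = (ks.statistic_location - xmin) / (xmax - xmin)

print(ks.statistic, loc, dist, ks.pvalue)
```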
dataeval/metadata/_ood.py
ADDED
@@ -0,0 +1,217 @@
+from __future__ import annotations
+
+__all__ = []
+
+import warnings
+
+import numpy as np
+from numpy.typing import NDArray
+
+from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data
+from dataeval.outputs import OODOutput
+from dataeval.utils.data import Metadata
+
+
+def _combine_metadata(metadata_1: Metadata, metadata_2: Metadata) -> tuple[list[str], list[NDArray], list[NDArray]]:
+    """
+    Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
+    match exactly and data has the same number of columns (factors).
+
+    Parameters
+    ----------
+    metadata_1 : Metadata
+        The set of factor names used as reference to determine the correct factor names and length of data
+    metadata_2 : Metadata
+        The compared set of factor names and data that must match metadata_1
+
+    Returns
+    -------
+    list[str]
+        The combined discrete and continuous factor names in that order.
+    list[NDArray]
+        Combined discrete and continuous data of metadata_1
+    list[NDArray]
+        Combined discrete and continuous data of metadata_2
+
+    Raises
+    ------
+    ValueError
+        If keys do not match in metadata_1 and metadata_2
+    ValueError
+        If the length of keys does not match the length of the data
+    """
+    factor_names: list[str] = []
+    m1_data: list[NDArray] = []
+    m2_data: list[NDArray] = []
+
+    # Both metadata must have the same number of factors (cols), but not necessarily samples (rows)
+    if metadata_1.total_num_factors != metadata_2.total_num_factors:
+        raise ValueError(
+            f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
+            f"and metadata_2 ({metadata_2.total_num_factors})"
+        )
+
+    # Validate and attach discrete data
+    if metadata_1.discrete_factor_names:
+        _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
+        _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
+
+        factor_names.extend(metadata_1.discrete_factor_names)
+        m1_data.append(metadata_1.discrete_data)
+        m2_data.append(metadata_2.discrete_data)
+
+    # Validate and attach continuous data
+    if metadata_1.continuous_factor_names:
+        _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
+        _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
+
+        factor_names.extend(metadata_1.continuous_factor_names)
+        m1_data.append(metadata_1.continuous_data)
+        m2_data.append(metadata_2.continuous_data)
+
+    # The caller stacks the discrete and continuous lists into single arrays
+    return factor_names, m1_data, m2_data
+
+
+def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
+    """
+    Calculates deviations of the test data from the median of the reference data
+
+    Parameters
+    ----------
+    reference : NDArray
+        Reference values of shape (samples, factors)
+    test : NDArray
+        Incoming values where each sample's factors will be compared to the median of
+        the reference set's corresponding factors
+
+    Returns
+    -------
+    NDArray
+        Scaled positive and negative deviations of the test data from the reference.
+
+    Note
+    ----
+    All return values are in the range [0, pos_inf]
+    """
+
+    # Take median over samples (rows)
+    ref_median = np.median(reference, axis=0)  # (F, )
+
+    # Shift reference and test distributions by the reference median
+    ref_dev = reference - ref_median  # (S, F) - F
+    test_dev = test - ref_median  # (S_t, F) - F
+
+    # Separate positive and negative distributions
+    # Fills with nans to keep shape in both 1-D and N-D matrices
+    pdev = np.where(ref_dev > 0, ref_dev, np.nan)  # (S, F)
+    ndev = np.where(ref_dev < 0, ref_dev, np.nan)  # (S, F)
+
+    # Calculate middle of positive and negative distributions per feature
+    pscale = np.nanmedian(pdev, axis=0)  # (F, )
+    nscale = np.abs(np.nanmedian(ndev, axis=0))  # (F, )
+
+    # Replace 0's for division. Negatives should not happen
+    pscale = np.where(pscale > 0, pscale, 1.0)  # (F, )
+    nscale = np.where(nscale > 0, nscale, 1.0)  # (F, )
+
+    # Scales positive values by positive scale and negative values by negative
+    return np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))  # (S_t, F)
+
+
+def most_deviated_factors(
+    metadata_1: Metadata,
+    metadata_2: Metadata,
+    ood: OODOutput,
+) -> list[tuple[str, float]]:
+    """
+    Determines the greatest deviation in metadata features per out-of-distribution sample in metadata_2.
+
+    Parameters
+    ----------
+    metadata_1 : Metadata
+        A reference set of Metadata containing factor names and samples
+        with discrete and/or continuous values per factor
+    metadata_2 : Metadata
+        The set of Metadata that is tested against the reference metadata.
+        This set must have the same number of features but does not require the same number of samples.
+    ood : OODOutput
+        A class output by DataEval's OOD functions that contains which examples are OOD.
+
+    Returns
+    -------
+    list[tuple[str, float]]
+        The factor name and deviation of the highest metadata deviation for each OOD example in metadata_2.
+
+    Notes
+    -----
+    1. Both :class:`.Metadata` inputs must have discrete and continuous data in the shape (samples, factors)
+       and have equivalent factor names and lengths
+    2. The flag at index `i` in :attr:`.OODOutput.is_ood` must correspond
+       directly to sample `i` of `metadata_2` being out-of-distribution from `metadata_1`
+
+    Examples
+    --------
+
+    >>> from dataeval.detectors.ood import OODOutput
+
+    All samples are out-of-distribution
+
+    >>> is_ood = OODOutput(np.array([True, True, True]), np.array([]), np.array([]))
+    >>> most_deviated_factors(metadata1, metadata2, is_ood)
+    [('time', 2.0), ('time', 2.592), ('time', 3.51)]
+
+    If there are no out-of-distribution samples, an empty list is returned
+
+    >>> is_ood = OODOutput(np.array([False, False, False]), np.array([]), np.array([]))
+    >>> most_deviated_factors(metadata1, metadata2, is_ood)
+    []
+    """
+
+    ood_mask: NDArray[np.bool] = ood.is_ood
+
+    # No metadata correlated with out of distribution data
+    if not any(ood_mask):
+        return []
+
+    # Combines reference and test factor names and data if they exist and match exactly
+    # shape -> (samples, factors)
+    factor_names, md_1, md_2 = _combine_metadata(
+        metadata_1=metadata_1,
+        metadata_2=metadata_2,
+    )
+
+    # Stack discrete and continuous factors as separate factors. Must have equal sample counts
+    metadata_ref = np.hstack(md_1) if md_1 else np.array([])
+    metadata_tst = np.hstack(md_2) if md_2 else np.array([])
+
+    if len(metadata_ref) < 3:
+        warnings.warn(
+            f"At least 3 reference metadata samples are needed, got {len(metadata_ref)}",
+            UserWarning,
+        )
+        return []
+
+    if len(metadata_tst) != len(ood_mask):
+        raise ValueError(
+            f"ood and test metadata must have the same length, "
+            f"got {len(ood_mask)} and {len(metadata_tst)} respectively."
+        )
+
+    # Calculates deviations of all samples in m2_data
+    # from the median values of the corresponding index in m1_data
+    # Guaranteed for inputs to not be empty
+    deviations = _calc_median_deviations(metadata_ref, metadata_tst)
+
+    # Get most impactful factor deviation of each sample for ood samples only
+    deviation = np.max(deviations, axis=1)[ood_mask].astype(np.float16)
+
+    # Get indices of most impactful factors for ood samples only
+    max_factors = np.argmax(deviations, axis=1)[ood_mask]
+
+    # Get names of most impactful factors TODO: Find better way than np.dtype(<U4)
+    most_ood_factors = np.array(factor_names)[max_factors].tolist()
+
+    # List of tuples matching the factor name with its deviation
+
+    return [(factor, dev) for factor, dev in zip(most_ood_factors, deviation)]
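The asymmetric scaling in `_calc_median_deviations` is easiest to see on a toy example. Below is a standalone sketch that mirrors the arithmetic above with made-up numbers (the arrays are illustrative, not dataeval data):

```python
import numpy as np

reference = np.array([[0.0], [1.0], [2.0], [3.0], [10.0]])  # one factor, five samples
test = np.array([[2.5], [40.0]])                            # two incoming samples

ref_median = np.median(reference, axis=0)  # 2.0
ref_dev = reference - ref_median           # [-2, -1, 0, 1, 8]

# Median of positive deviations (4.5) and of |negative| deviations (1.5)
pscale = np.nanmedian(np.where(ref_dev > 0, ref_dev, np.nan), axis=0)
nscale = np.abs(np.nanmedian(np.where(ref_dev < 0, ref_dev, np.nan), axis=0))

test_dev = test - ref_median
scaled = np.abs(np.where(test_dev >= 0, test_dev / pscale, test_dev / nscale))
print(scaled)  # ~[[0.111], [8.444]] -- 40.0 deviates far more than 2.5
```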
dataeval/metadata/_utils.py
ADDED
@@ -0,0 +1,44 @@
+__all__ = []
+
+from numpy.typing import NDArray
+
+
+def _compare_keys(keys1: list[str], keys2: list[str]) -> None:
+    """
+    Raises an error when two lists are not equivalent, including ordering
+
+    Parameters
+    ----------
+    keys1 : list of strings
+        List of strings to compare
+    keys2 : list of strings
+        List of strings to compare
+
+    Raises
+    ------
+    ValueError
+        If lists do not have the same values, value counts, or ordering
+    """
+
+    if keys1 != keys2:
+        raise ValueError(f"Metadata keys must be identical, got {keys1} and {keys2}")
+
+
+def _validate_factors_and_data(factors: list[str], data: NDArray) -> None:
+    """
+    Raises an error when the number of factors and the number of data columns do not match
+
+    Parameters
+    ----------
+    factors : list of strings
+        List of factor names of size N
+    data : NDArray
+        Array of values with shape (M, N)
+
+    Raises
+    ------
+    ValueError
+        If the length of factors does not equal the length of the transposed data
+    """
+    if len(factors) != len(data.T):
+        raise ValueError(f"Factors and data have mismatched lengths. Got {len(factors)} and {len(data.T)}")
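Both helpers are private, but their contract is simple to exercise in isolation; a minimal sketch (the example values are illustrative):

```python
import numpy as np

from dataeval.metadata._utils import _compare_keys, _validate_factors_and_data

factors = ["time", "altitude"]
data = np.zeros((10, 2))  # 10 samples x 2 factors

_compare_keys(factors, ["time", "altitude"])  # passes: same values, same order
_validate_factors_and_data(factors, data)     # passes: 2 names, 2 columns

_compare_keys(factors, ["altitude", "time"])  # raises ValueError: ordering differs
```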
dataeval/metrics/__init__.py
CHANGED
@@ -7,6 +7,7 @@ __all__ = [
     "BalanceOutput",
     "CoverageOutput",
     "DiversityOutput",
+    "LabelParityOutput",
     "ParityOutput",
     "balance",
     "coverage",
@@ -15,7 +16,8 @@ __all__ = [
     "parity",
 ]
 
-from dataeval.metrics.bias.…
-from dataeval.metrics.bias.…
-from dataeval.metrics.bias.…
-from dataeval.metrics.bias.…
+from dataeval.metrics.bias._balance import balance
+from dataeval.metrics.bias._coverage import coverage
+from dataeval.metrics.bias._diversity import diversity
+from dataeval.metrics.bias._parity import label_parity, parity
+from dataeval.outputs._bias import BalanceOutput, CoverageOutput, DiversityOutput, LabelParityOutput, ParityOutput
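With these re-exports, the public import surface stays at `dataeval.metrics` even though the implementation modules are now underscore-private. A quick sketch of the resulting usage (assumes `metadata` is a prepared dataeval `Metadata` instance):

```python
from dataeval.metrics import BalanceOutput, balance

# `metadata` is assumed to be a prepared dataeval Metadata instance
result: BalanceOutput = balance(metadata)
```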
dataeval/metrics/bias/{balance.py → _balance.py}
CHANGED
@@ -2,99 +2,16 @@ from __future__ import annotations
 
 __all__ = []
 
-import contextlib
 import warnings
-from dataclasses import dataclass
-from typing import Any
 
 import numpy as np
 import scipy as sp
-from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
 
-from dataeval.…
-from dataeval.…
-from dataeval.utils.…
-
-with contextlib.suppress(ImportError):
-    from matplotlib.figure import Figure
-
-
-@dataclass(frozen=True)
-class BalanceOutput(Output):
-    """
-    Output class for :func:`balance` :term:`bias<Bias>` metric.
-
-    Attributes
-    ----------
-    balance : NDArray[np.float64]
-        Estimate of mutual information between metadata factors and class label
-    factors : NDArray[np.float64]
-        Estimate of inter/intra-factor mutual information
-    classwise : NDArray[np.float64]
-        Estimate of mutual information between metadata factors and individual class labels
-    factor_names : list[str]
-        Names of each metadata factor
-    class_list : NDArray
-        Array of the class labels present in the dataset
-    """
-
-    balance: NDArray[np.float64]
-    factors: NDArray[np.float64]
-    classwise: NDArray[np.float64]
-    factor_names: list[str]
-    class_list: NDArray[Any]
-
-    def plot(
-        self,
-        row_labels: list[Any] | NDArray[Any] | None = None,
-        col_labels: list[Any] | NDArray[Any] | None = None,
-        plot_classwise: bool = False,
-    ) -> Figure:
-        """
-        Plot a heatmap of balance information
-
-        Parameters
-        ----------
-        row_labels : ArrayLike or None, default None
-            List/Array containing the labels for rows in the histogram
-        col_labels : ArrayLike or None, default None
-            List/Array containing the labels for columns in the histogram
-        plot_classwise : bool, default False
-            Whether to plot per-class balance instead of global balance
-        """
-        if plot_classwise:
-            if row_labels is None:
-                row_labels = self.class_list
-            if col_labels is None:
-                col_labels = self.factor_names
-
-            fig = heatmap(
-                self.classwise,
-                row_labels,
-                col_labels,
-                xlabel="Factors",
-                ylabel="Class",
-                cbarlabel="Normalized Mutual Information",
-            )
-        else:
-            # Combine balance and factors results
-            data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
-            # Create a mask for the upper triangle of the symmetrical array, ignoring the diagonal
-            mask = np.triu(data + 1, k=0) < 1
-            # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
-            heat_data = np.where(mask, np.nan, data)[:-1]
-            # Creating label array for heat map axes
-            heat_labels = self.factor_names
-
-            if row_labels is None:
-                row_labels = heat_labels[:-1]
-            if col_labels is None:
-                col_labels = heat_labels[1:]
-
-            fig = heatmap(heat_data, row_labels, col_labels, cbarlabel="Normalized Mutual Information")
-
-        return fig
+from dataeval.outputs import BalanceOutput
+from dataeval.outputs._base import set_metadata
+from dataeval.utils._bin import get_counts
+from dataeval.utils.data import Metadata
 
 
 def _validate_num_neighbors(num_neighbors: int) -> int:
@@ -128,7 +45,7 @@ def balance(
     Parameters
     ----------
     metadata : Metadata
-        Preprocessed metadata
+        Preprocessed metadata
     num_neighbors : int, default 5
         Number of points to consider as neighbors
 
@@ -150,25 +67,22 @@ def balance(
 
     >>> bal = balance(metadata)
     >>> bal.balance
-    array([0.…
-           0. ])
+    array([1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ])
 
     Return intra/interfactor balance (mutual information)
 
     >>> bal.factors
-    array([[…
-           [0.…
-           [0.…
-           [0.…
-           [0.…
+    array([[1.   , 0.314, 0.269, 0.852, 0.367],
+           [0.314, 1.   , 0.097, 0.158, 1.98 ],
+           [0.269, 0.097, 1.   , 0.037, 0.015],
+           [0.852, 0.158, 0.037, 0.475, 0.255],
+           [0.367, 1.98 , 0.015, 0.255, 1.063]])
 
     Return classwise balance (mutual information) of factors with individual class_labels
 
     >>> bal.classwise
-    array([[0.…
-           …
-           [0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
-            0.        ]])
+    array([[1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ],
+           [1.   , 0.249, 0.03 , 0.134, 0.   , 0.   ]])
 
 
     See Also
@@ -184,7 +98,7 @@ def balance(
     mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
     data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
     discretized_data = data
-    if metadata.continuous_data…
+    if len(metadata.continuous_data):
         data = np.hstack((data, metadata.continuous_data))
         discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
         discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
@@ -218,7 +132,7 @@ def balance(
     factors = nmi[1:, 1:]
 
     # assume class is a factor
-    num_classes = metadata.class_names…
+    num_classes = len(metadata.class_names)
    classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)
 
     # classwise targets
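As the hunks above show, `balance` builds a factor-by-factor mutual-information matrix (with the class label treated as the first factor) from sklearn's estimators and then normalizes it. A standalone sketch of the underlying sklearn call on synthetic data follows; it illustrates the flavor of the computation, not the package's exact entropy normalization:

```python
import numpy as np
from sklearn.feature_selection import mutual_info_classif

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, 100)           # synthetic class labels
factor = labels ^ (rng.random(100) < 0.1)  # a factor strongly correlated with the label
data = np.stack([labels, factor], axis=1)  # class label prepended as a factor

# MI of each column against the class label, treating all features as discrete
mi = mutual_info_classif(data, labels, discrete_features=True, random_state=0)
print(mi)  # first entry (label vs itself) equals the label entropy; second is lower
```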