dataeval 0.73.1__py3-none-any.whl → 0.74.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +1 -1
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/ood/__init__.py +10 -3
- dataeval/detectors/ood/ae.py +2 -1
- dataeval/detectors/ood/ae_torch.py +70 -0
- dataeval/detectors/ood/aegmm.py +4 -3
- dataeval/detectors/ood/base.py +58 -108
- dataeval/detectors/ood/base_tf.py +109 -0
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/llr.py +2 -2
- dataeval/detectors/ood/metadata_ks_compare.py +53 -14
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/detectors/ood/vaegmm.py +5 -4
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +70 -67
- dataeval/metrics/bias/coverage.py +1 -1
- dataeval/metrics/bias/diversity.py +64 -133
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +47 -157
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/tensorflow/_internal/gmm.py +4 -24
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/RECORD +31 -25
- dataeval/metrics/bias/metadata.py +0 -440
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.1.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
dataeval/detectors/ood/llr.py
CHANGED
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Callable
 import numpy as np
 from numpy.typing import ArrayLike, NDArray

-from dataeval.detectors.ood.base import
+from dataeval.detectors.ood.base import OODBaseMixin, OODScoreOutput
 from dataeval.interop import to_numpy
 from dataeval.utils.lazy import lazyload
 from dataeval.utils.tensorflow._internal.trainer import trainer
@@ -96,7 +96,7 @@ def _mutate_categorical(
     return tf.cast(X, tf.float32)  # type: ignore


-class OOD_LLR(
+class OOD_LLR(OODBaseMixin[tf_models.PixelCNN]):
     """
     Likelihood Ratios based outlier detector.

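Note: the new base class is generic over the wrapped model type, which is what lets OOD_LLR declare itself as OODBaseMixin[tf_models.PixelCNN] above. A minimal sketch of that typing pattern with stand-in classes (only the OODBaseMixin name comes from this diff; the bodies here are illustrative, not dataeval's implementation):

from typing import Generic, TypeVar

TModel = TypeVar("TModel")


class PixelCNN:  # stand-in for dataeval.utils.tensorflow._internal.models.PixelCNN
    ...


class OODBaseMixin(Generic[TModel]):
    # Parameterizing the mixin by model type pins the type of self.model,
    # so each detector exposes its concrete model class to type checkers.
    def __init__(self, model: TModel) -> None:
        self.model = model


class OOD_LLR(OODBaseMixin[PixelCNN]):
    ...


detector = OOD_LLR(PixelCNN())  # detector.model is typed as PixelCNN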
dataeval/detectors/ood/metadata_ks_compare.py
CHANGED
@@ -2,17 +2,45 @@ from __future__ import annotations

 import numbers
 import warnings
-from
+from dataclasses import dataclass
+from typing import Any, Mapping, NamedTuple

 import numpy as np
 from numpy.typing import NDArray
 from scipy.stats import iqr, ks_2samp
 from scipy.stats import wasserstein_distance as emd

+from dataeval.output import OutputMetadata, set_metadata

+
+class MetadataKSResult(NamedTuple):
+    statistic: float
+    statistic_location: float
+    shift_magnitude: float
+    pvalue: float
+
+
+@dataclass(frozen=True)
+class KSOutput(OutputMetadata):
+    """
+    Output class for results of ks_2samp featurewise comparisons of new metadata to reference metadata.
+
+    Attributes
+    ----------
+    mdc : dict[str, dict[str, float]]
+        dict keyed by metadata feature names. Each value contains four floats, which are the KS statistic itself, its
+        location within the range of the reference metadata, the shift of new metadata relative to reference, the
+        p-value from the KS two-sample test.
+
+    """
+
+    mdc: dict[str, MetadataKSResult]
+
+
+@set_metadata()
 def meta_distribution_compare(
     md0: Mapping[str, list[Any] | NDArray[Any]], md1: Mapping[str, list[Any] | NDArray[Any]]
-) ->
+) -> KSOutput:
     """Measures the featurewise distance between two metadata distributions, and computes a p-value to evaluate its
     significance.

@@ -43,27 +71,29 @@ def meta_distribution_compare(
     >>> import numpy
     >>> md0 = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
     >>> md1 = {"time": [7.8, 9.10, 11.12], "altitude": [532, 9876, 211101]}
-    >>> md_out = meta_distribution_compare(md0, md1)
+    >>> md_out = meta_distribution_compare(md0, md1).mdc
     >>> for k, v in md_out.items():
     >>>     print(k)
     >>>     for kv in v:
     >>>         print("\t", f"{kv}: {v[kv]:.3f}")
     time
-
-
-
+        statistic: 1.000
+        statistic_location: 0.444
+        shift_magnitude: 2.700
+        pvalue: 0.000
     altitude
-
-
-
+        statistic: 0.333
+        statistic_location: 0.478
+        shift_magnitude: 0.749
+        pvalue: 0.944
     """

     if (metadata_keys := md0.keys()) != md1.keys():
         raise ValueError(f"Both sets of metadata keys must be identical: {list(md0)}, {list(md1)}")

-
+    mdc = {}  # output dict
     for k in metadata_keys:
-
+        mdc.update({k: {}})

         x0, x1 = list(md0[k]), list(md1[k])

@@ -81,7 +111,9 @@ def meta_distribution_compare(

         xmin, xmax = min(allx), max(allx)
         if xmin == xmax:  # only one value in this feature, so fill in the obvious results for feature k
-
+            mdc[k] = MetadataKSResult(
+                **{"statistic": 0.0, "statistic_location": 0.0, "shift_magnitude": 0.0, "pvalue": 1.0}
+            )
             continue

         ks_result = ks_2samp(x0, x1, method="asymp")
@@ -94,6 +126,13 @@ def meta_distribution_compare(

         drift = emd(x0, x1) / dX

-
+        mdc[k] = MetadataKSResult(
+            **{
+                "statistic": ks_result.statistic,  # pyright: ignore
+                "statistic_location": loc,
+                "shift_magnitude": drift,
+                "pvalue": ks_result.pvalue,  # pyright: ignore
+            }
+        )

-    return
+    return KSOutput(mdc)
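Note: meta_distribution_compare previously returned a plain dict; 0.74.0 wraps the per-feature results in the frozen KSOutput dataclass, whose mdc field maps feature names to MetadataKSResult named tuples. A short sketch of consuming the new return type (assumes dataeval 0.74.0 is installed; the data values are illustrative):

from dataeval.detectors.ood.metadata_ks_compare import meta_distribution_compare

# reference metadata (md0) vs. new metadata (md1), one sequence per feature
md0 = {"time": [1.2, 3.4, 5.6], "altitude": [235, 6789, 101112]}
md1 = {"time": [7.8, 9.10, 11.12], "altitude": [532, 9876, 211101]}

out = meta_distribution_compare(md0, md1)

# results are now addressable by field name instead of by dict key
for feature, result in out.mdc.items():
    print(f"{feature}: KS={result.statistic:.3f}, p={result.pvalue:.3f}, shift={result.shift_magnitude:.3f}")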
dataeval/detectors/ood/vae.py
CHANGED
@@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Callable
 import numpy as np
 from numpy.typing import ArrayLike

-from dataeval.detectors.ood.base import
+from dataeval.detectors.ood.base import OODScoreOutput
+from dataeval.detectors.ood.base_tf import OODBase
 from dataeval.interop import to_numpy
 from dataeval.utils.lazy import lazyload
 from dataeval.utils.tensorflow._internal.loss import Elbo
@@ -67,7 +68,7 @@ class OOD_VAE(OODBase):
         self,
         x_ref: ArrayLike,
         threshold_perc: float = 100.0,
-        loss_fn: Callable[..., tf.Tensor] = Elbo(0.05),
+        loss_fn: Callable[..., tf.Tensor] | None = Elbo(0.05),
         optimizer: keras.optimizers.Optimizer | None = None,
         epochs: int = 20,
         batch_size: int = 64,
dataeval/detectors/ood/vaegmm.py
CHANGED
@@ -15,7 +15,8 @@ from typing import TYPE_CHECKING, Callable
 import numpy as np
 from numpy.typing import ArrayLike

-from dataeval.detectors.ood.base import
+from dataeval.detectors.ood.base import OODScoreOutput
+from dataeval.detectors.ood.base_tf import OODBaseGMM
 from dataeval.interop import to_numpy
 from dataeval.utils.lazy import lazyload
 from dataeval.utils.tensorflow._internal.gmm import gmm_energy
@@ -33,7 +34,7 @@ else:
     tf_models = lazyload("dataeval.utils.tensorflow._internal.models")


-class OOD_VAEGMM(
+class OOD_VAEGMM(OODBaseGMM):
     """
     VAE with Gaussian Mixture Model based outlier detector.

@@ -53,7 +54,7 @@ class OOD_VAEGMM(OODGMMBase):
         self,
         x_ref: ArrayLike,
         threshold_perc: float = 100.0,
-        loss_fn: Callable[..., tf.Tensor] = LossGMM(elbo=Elbo(0.05)),
+        loss_fn: Callable[..., tf.Tensor] | None = LossGMM(elbo=Elbo(0.05)),
         optimizer: keras.optimizers.Optimizer | None = None,
         epochs: int = 20,
         batch_size: int = 64,
@@ -69,7 +70,7 @@ class OOD_VAEGMM(OODGMMBase):
         _, z, _ = predict_batch(X_samples, self.model, batch_size=batch_size)

         # compute average energy for samples
-        energy, _ = gmm_energy(z, self.
+        energy, _ = gmm_energy(z, self._gmm_params, return_mean=False)
         energy_samples = energy.numpy().reshape((-1, self.samples))  # type: ignore
         iscore = np.mean(energy_samples, axis=-1)
         return OODScoreOutput(iscore)
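Note: across the OOD detectors, 0.74.0 splits the backend-specific base classes out of dataeval.detectors.ood.base, which now keeps the shared pieces (OODScoreOutput, OODBaseMixin); the TensorFlow bases move to the new base_tf module, and a torch counterpart appears in base_torch. A sketch of the import migration these hunks imply (the base_torch contents are an assumption from the file list, not shown in this diff):

# dataeval 0.73.1
# from dataeval.detectors.ood.base import OODBase, OODGMMBase, OODScoreOutput

# dataeval 0.74.0
from dataeval.detectors.ood.base import OODBaseMixin, OODScoreOutput  # shared types
from dataeval.detectors.ood.base_tf import OODBase, OODBaseGMM        # TensorFlow detectors
# from dataeval.detectors.ood.base_torch import ...                   # torch detectors (new file, contents not shown here)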
dataeval/metrics/bias/__init__.py
CHANGED
@@ -6,6 +6,7 @@ representation which may impact model performance.
 from dataeval.metrics.bias.balance import BalanceOutput, balance
 from dataeval.metrics.bias.coverage import CoverageOutput, coverage
 from dataeval.metrics.bias.diversity import DiversityOutput, diversity
+from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput, metadata_preprocessing
 from dataeval.metrics.bias.parity import ParityOutput, label_parity, parity

 __all__ = [
@@ -14,8 +15,10 @@ __all__ = [
     "diversity",
     "label_parity",
     "parity",
+    "metadata_preprocessing",
     "BalanceOutput",
     "CoverageOutput",
     "DiversityOutput",
     "ParityOutput",
+    "MetadataOutput",
 ]
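Note: the bias metrics now consume the output of the new metadata_preprocessing step rather than raw metadata mappings (see the balance diff below, where metadata: Mapping[str, ArrayLike] becomes metadata: MetadataOutput). A hedged sketch of the new workflow; the exact metadata_preprocessing signature is not shown in this diff, so the arguments here are assumptions:

import numpy as np

from dataeval.metrics.bias import balance, metadata_preprocessing

# raw per-image metadata and integer class labels (assumed input format)
raw_metadata = {"time": np.random.rand(100), "altitude": np.random.randint(0, 5, 100)}
class_labels = np.random.randint(0, 2, 100)

# assumed call: preprocessing discretizes continuous factors and bundles
# class_labels, discrete/continuous data, and factor names into a MetadataOutput
metadata = metadata_preprocessing(raw_metadata, class_labels)

bal = balance(metadata)   # bias metrics take the MetadataOutput directly
print(bal.balance)        # MI of each factor with the class label
print(bal.factor_names)   # ["class", ...] labels used by bal.plot()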
dataeval/metrics/bias/balance.py
CHANGED
@@ -5,13 +5,15 @@ __all__ = ["BalanceOutput", "balance"]
 import contextlib
 import warnings
 from dataclasses import dataclass
-from typing import Any
+from typing import Any

 import numpy as np
-
+import scipy as sp
+from numpy.typing import NDArray
 from sklearn.feature_selection import mutual_info_classif, mutual_info_regression

-from dataeval.metrics.bias.
+from dataeval.metrics.bias.metadata_preprocessing import MetadataOutput
+from dataeval.metrics.bias.metadata_utils import get_counts, heatmap
 from dataeval.output import OutputMetadata, set_metadata

 with contextlib.suppress(ImportError):
@@ -31,17 +33,17 @@ class BalanceOutput(OutputMetadata):
         Estimate of inter/intra-factor mutual information
     classwise : NDArray[np.float64]
         Estimate of mutual information between metadata factors and individual class labels
+    factor_names : list[str]
+        Names of each metadata factor
     class_list : NDArray
         Array of the class labels present in the dataset
-    metadata_names : list[str]
-        Names of each metadata factor
     """

     balance: NDArray[np.float64]
     factors: NDArray[np.float64]
     classwise: NDArray[np.float64]
+    factor_names: list[str]
     class_list: NDArray[Any]
-    metadata_names: list[str]

     def plot(
         self,
@@ -65,7 +67,7 @@ class BalanceOutput(OutputMetadata):
         if row_labels is None:
             row_labels = self.class_list
         if col_labels is None:
-            col_labels =
+            col_labels = self.factor_names

         fig = heatmap(
             self.classwise,
@@ -83,7 +85,7 @@ class BalanceOutput(OutputMetadata):
         # Finalize the data for the plot, last row is last factor x last factor so it gets dropped
         heat_data = np.where(mask, np.nan, data)[:-1]
         # Creating label array for heat map axes
-        heat_labels =
+        heat_labels = self.factor_names

         if row_labels is None:
             row_labels = heat_labels[:-1]
@@ -95,7 +97,7 @@ class BalanceOutput(OutputMetadata):
         return fig


-def
+def _validate_num_neighbors(num_neighbors: int) -> int:
     if not isinstance(num_neighbors, (int, float)):
         raise TypeError(
             f"Variable {num_neighbors} is not real-valued numeric type."
@@ -117,28 +119,16 @@ def validate_num_neighbors(num_neighbors: int) -> int:

 @set_metadata("dataeval.metrics")
 def balance(
-
-    metadata: Mapping[str, ArrayLike],
+    metadata: MetadataOutput,
     num_neighbors: int = 5,
-    continuous_factor_bincounts: Mapping[str, int] | None = None,
 ) -> BalanceOutput:
     """
     Mutual information (MI) between factors (class label, metadata, label/image properties)

     Parameters
     ----------
-
-
-    metadata : Mapping[str, ArrayLike]
-        Dict of lists of metadata factors for each image
-    num_neighbors : int, default 5
-        Number of nearest neighbors to use for computing MI between discrete
-        and continuous variables.
-    continuous_factor_bincounts : Mapping[str, int] or None, default None
-        The factors in metadata that have continuous values and the array of bin counts to
-        discretize values into. All factors are treated as having discrete values unless they
-        are specified as keys in this dictionary. Each element of this array must occur as a key
-        in metadata.
+    metadata : MetadataOutput
+        Output after running `metadata_preprocessing`

     Returns
     -------
@@ -150,30 +140,33 @@ def balance(
     ----
     We use `mutual_info_classif` from sklearn since class label is categorical.
     `mutual_info_classif` outputs are consistent up to O(1e-4) and depend on a random
-    seed. MI is computed differently for categorical and continuous variables
-    we attempt to infer whether a variable is categorical by the fraction of unique
-    values in the dataset.
+    seed. MI is computed differently for categorical and continuous variables.

     Example
     -------
     Return balance (mutual information) of factors with class_labels

-    >>> bal = balance(
+    >>> bal = balance(metadata)
     >>> bal.balance
-    array([0.
+    array([0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
+           0.        ])

     Return intra/interfactor balance (mutual information)

     >>> bal.factors
-    array([[0.
-           [0.
-           [0.09725766, 0.
+    array([[0.99999935, 0.31360499, 0.26925848, 0.85201924, 0.36653548],
+           [0.31360499, 0.99999856, 0.09725766, 0.15836905, 1.98031993],
+           [0.26925848, 0.09725766, 0.99999846, 0.03713108, 0.01544656],
+           [0.85201924, 0.15836905, 0.03713108, 0.47450653, 0.25509664],
+           [0.36653548, 1.98031993, 0.01544656, 0.25509664, 1.06260686]])

     Return classwise balance (mutual information) of factors with individual class_labels

     >>> bal.classwise
-    array([[0.
-
+    array([[0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
+            0.        ],
+           [0.9999982 , 0.2494567 , 0.02994455, 0.13363788, 0.        ,
+            0.        ]])


     See Also
@@ -182,68 +175,78 @@ def balance(
     sklearn.feature_selection.mutual_info_regression
     sklearn.metrics.mutual_info_score
     """
-    num_neighbors =
-
-    num_factors =
-
-    mi
+    num_neighbors = _validate_num_neighbors(num_neighbors)
+
+    num_factors = metadata.total_num_factors
+    is_discrete = [True] * (len(metadata.discrete_factor_names) + 1) + [False] * len(metadata.continuous_factor_names)
+    mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
+    data = np.hstack((metadata.class_labels[:, np.newaxis], metadata.discrete_data))
+    discretized_data = data
+    if metadata.continuous_data is not None:
+        data = np.hstack((data, metadata.continuous_data))
+        discrete_idx = [metadata.discrete_factor_names.index(name) for name in metadata.continuous_factor_names]
+        discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))

     for idx in range(num_factors):
-
-
-        if continuous_factor_bincounts and names[idx] not in continuous_factor_bincounts:
-            mi[idx, :] = mutual_info_classif(
+        if idx >= len(metadata.discrete_factor_names) + 1:
+            mi[idx, :] = mutual_info_regression(
                 data,
-
-                discrete_features=
+                data[:, idx],
+                discrete_features=is_discrete,  # type: ignore
                 n_neighbors=num_neighbors,
                 random_state=0,
             )
         else:
-            mi[idx, :] =
+            mi[idx, :] = mutual_info_classif(
                 data,
-
-                discrete_features=
+                data[:, idx],
+                discrete_features=is_discrete,  # type: ignore
                 n_neighbors=num_neighbors,
                 random_state=0,
             )

-
-
+    # Normalization via entropy
+    bin_cnts = get_counts(discretized_data)
+    ent_factor = sp.stats.entropy(bin_cnts, axis=0)
+    norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + 1e-6
+
     # in principle MI should be symmetric, but it is not in practice.
     nmi = 0.5 * (mi + mi.T) / norm_factor
     balance = nmi[0]
     factors = nmi[1:, 1:]

-    # unique class labels
-    class_idx = names.index(CLASS_LABEL)
-    u_cls = np.unique(data[:, class_idx])
-    num_classes = len(u_cls)
-
     # assume class is a factor
-
-    classwise_mi
+    num_classes = metadata.class_names.size
+    classwise_mi = np.full((num_classes, num_factors), np.nan, dtype=np.float32)

-    #
-
-
-    tgt_bin = np.stack([data[:, class_idx] == cls for cls in u_cls]).T.astype(np.intp)
-    names = [str(idx) for idx in range(num_classes)]
-    ent_tgt_bin = entropy(tgt_bin, names, continuous_factor_bincounts)
+    # classwise targets
+    classes = np.unique(metadata.class_labels)
+    tgt_bin = data[:, 0][:, None] == classes

     # classification MI for discrete/categorical features
     for idx in range(num_classes):
-        # tgt = class_data == cls
         # units: nat
         classwise_mi[idx, :] = mutual_info_classif(
             data,
             tgt_bin[:, idx],
-            discrete_features=
+            discrete_features=is_discrete,  # type: ignore
             n_neighbors=num_neighbors,
             random_state=0,
         )

-
+    # Classwise normalization via entropy
+    classwise_bin_cnts = get_counts(tgt_bin)
+    ent_tgt_bin = sp.stats.entropy(classwise_bin_cnts, axis=0)
+    norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_factor) + 1e-6
     classwise = classwise_mi / norm_factor

-
+    # Grabbing factor names for plotting function
+    factor_names = ["class"]
+    for name in metadata.discrete_factor_names:
+        if name in metadata.continuous_factor_names:
+            name = name + "-discrete"
+        factor_names.append(name)
+    for name in metadata.continuous_factor_names:
+        factor_names.append(name + "-continuous")
+
+    return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
dataeval/metrics/bias/coverage.py
CHANGED
@@ -12,7 +12,7 @@ from numpy.typing import ArrayLike, NDArray
 from scipy.spatial.distance import pdist, squareform

 from dataeval.interop import to_numpy
-from dataeval.metrics.bias.
+from dataeval.metrics.bias.metadata_utils import coverage_plot
 from dataeval.output import OutputMetadata, set_metadata
 from dataeval.utils.shared import flatten
