dataeval 0.86.0__py3-none-any.whl → 0.86.2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- dataeval/__init__.py +1 -1
- dataeval/_log.py +1 -1
- dataeval/config.py +21 -4
- dataeval/data/_embeddings.py +2 -2
- dataeval/data/_images.py +2 -3
- dataeval/data/_metadata.py +188 -178
- dataeval/data/_selection.py +1 -2
- dataeval/data/_split.py +4 -5
- dataeval/data/_targets.py +17 -13
- dataeval/data/selections/_classfilter.py +2 -5
- dataeval/data/selections/_prioritize.py +6 -9
- dataeval/data/selections/_shuffle.py +3 -1
- dataeval/detectors/drift/_base.py +4 -5
- dataeval/detectors/drift/_mmd.py +3 -6
- dataeval/detectors/drift/_nml/_base.py +4 -2
- dataeval/detectors/drift/_nml/_chunk.py +11 -19
- dataeval/detectors/drift/_nml/_domainclassifier.py +8 -19
- dataeval/detectors/drift/_nml/_result.py +8 -9
- dataeval/detectors/drift/_nml/_thresholds.py +66 -77
- dataeval/detectors/linters/outliers.py +7 -7
- dataeval/metadata/_distance.py +10 -7
- dataeval/metadata/_ood.py +11 -103
- dataeval/metrics/bias/_balance.py +23 -33
- dataeval/metrics/bias/_diversity.py +16 -14
- dataeval/metrics/bias/_parity.py +18 -18
- dataeval/metrics/estimators/_divergence.py +2 -4
- dataeval/metrics/stats/_base.py +103 -42
- dataeval/metrics/stats/_boxratiostats.py +21 -19
- dataeval/metrics/stats/_dimensionstats.py +14 -10
- dataeval/metrics/stats/_hashstats.py +1 -1
- dataeval/metrics/stats/_pixelstats.py +6 -6
- dataeval/metrics/stats/_visualstats.py +3 -3
- dataeval/outputs/_base.py +22 -7
- dataeval/outputs/_bias.py +24 -70
- dataeval/outputs/_drift.py +1 -9
- dataeval/outputs/_linters.py +11 -11
- dataeval/outputs/_stats.py +82 -23
- dataeval/outputs/_workflows.py +2 -2
- dataeval/utils/_array.py +6 -9
- dataeval/utils/_bin.py +1 -2
- dataeval/utils/_clusterer.py +7 -4
- dataeval/utils/_fast_mst.py +27 -13
- dataeval/utils/_image.py +65 -11
- dataeval/utils/_mst.py +1 -3
- dataeval/utils/_plot.py +15 -10
- dataeval/utils/data/_dataset.py +54 -28
- dataeval/utils/data/metadata.py +104 -82
- dataeval/utils/datasets/__init__.py +2 -0
- dataeval/utils/datasets/_antiuav.py +189 -0
- dataeval/utils/datasets/_base.py +11 -8
- dataeval/utils/datasets/_cifar10.py +104 -45
- dataeval/utils/datasets/_fileio.py +21 -47
- dataeval/utils/datasets/_milco.py +22 -12
- dataeval/utils/datasets/_mixin.py +2 -4
- dataeval/utils/datasets/_mnist.py +3 -4
- dataeval/utils/datasets/_ships.py +14 -7
- dataeval/utils/datasets/_voc.py +229 -42
- dataeval/utils/torch/models.py +5 -10
- dataeval/utils/torch/trainer.py +3 -3
- dataeval/workflows/sufficiency.py +2 -2
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/METADATA +2 -1
- dataeval-0.86.2.dist-info/RECORD +114 -0
- dataeval/detectors/ood/vae.py +0 -74
- dataeval-0.86.0.dist-info/RECORD +0 -114
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.86.0.dist-info → dataeval-0.86.2.dist-info}/WHEEL +0 -0
dataeval/detectors/linters/outliers.py CHANGED
@@ -13,31 +13,31 @@ from dataeval.metrics.stats._imagestats import imagestats
 from dataeval.outputs import DimensionStatsOutput, ImageStatsOutput, OutliersOutput, PixelStatsOutput, VisualStatsOutput
 from dataeval.outputs._base import set_metadata
 from dataeval.outputs._linters import IndexIssueMap, OutlierStatsOutput
-from dataeval.outputs._stats import
+from dataeval.outputs._stats import BASE_ATTRS
 from dataeval.typing import ArrayLike, Dataset


 def _get_outlier_mask(
     values: NDArray, method: Literal["zscore", "modzscore", "iqr"], threshold: float | None
 ) -> NDArray:
+    values = values.astype(np.float64)
     if method == "zscore":
         threshold = threshold if threshold else 3.0
         std = np.std(values)
         abs_diff = np.abs(values - np.mean(values))
         return std != 0 and (abs_diff / std) > threshold
-
+    if method == "modzscore":
         threshold = threshold if threshold else 3.5
         abs_diff = np.abs(values - np.median(values))
         med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
         mod_z_score = 0.6745 * abs_diff / med_abs_diff
         return mod_z_score > threshold
-
+    if method == "iqr":
         threshold = threshold if threshold else 1.5
         qrt = np.percentile(values, q=(25, 75), method="midpoint")
         iqr = (qrt[1] - qrt[0]) * threshold
         return (values < (qrt[0] - iqr)) | (values > (qrt[1] + iqr))
-
-    raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")
+    raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")


 class Outliers:
@@ -103,7 +103,7 @@ class Outliers:
         use_visual: bool = True,
         outlier_method: Literal["zscore", "modzscore", "iqr"] = "modzscore",
         outlier_threshold: float | None = None,
-    ):
+    ) -> None:
         self.stats: ImageStatsOutput
         self.use_dimension = use_dimension
         self.use_pixel = use_pixel
@@ -114,7 +114,7 @@ class Outliers:
     def _get_outliers(self, stats: dict) -> dict[int, dict[str, float]]:
         flagged_images: dict[int, dict[str, float]] = {}
         for stat, values in stats.items():
-            if stat in
+            if stat in BASE_ATTRS:
                 continue
             if values.ndim == 1:
                 mask = _get_outlier_mask(values.astype(np.float64), self.outlier_method, self.outlier_threshold)
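The reworked `_get_outlier_mask` can be read as a standalone routine. The sketch below mirrors the thresholds and formulas from the diff; the `outlier_mask` name, the example data, and the explicit all-False handling when the standard deviation is zero are illustrative additions, not part of the package.

```python
import numpy as np
from numpy.typing import NDArray


def outlier_mask(values: NDArray, method: str = "modzscore", threshold: float | None = None) -> NDArray:
    values = values.astype(np.float64)
    if method == "zscore":
        threshold = threshold if threshold else 3.0
        std = np.std(values)
        abs_diff = np.abs(values - np.mean(values))
        # std == 0 means all values are identical, so nothing is an outlier
        return (abs_diff / std) > threshold if std != 0 else np.zeros_like(values, dtype=bool)
    if method == "modzscore":
        threshold = threshold if threshold else 3.5
        abs_diff = np.abs(values - np.median(values))
        med_abs_diff = np.median(abs_diff) if np.median(abs_diff) != 0 else np.mean(abs_diff)
        return (0.6745 * abs_diff / med_abs_diff) > threshold
    if method == "iqr":
        threshold = threshold if threshold else 1.5
        q1, q3 = np.percentile(values, q=(25, 75), method="midpoint")
        iqr = (q3 - q1) * threshold
        return (values < (q1 - iqr)) | (values > (q3 + iqr))
    raise ValueError("Outlier method must be 'zscore' 'modzscore' or 'iqr'.")


print(outlier_mask(np.array([1.0, 1.1, 0.9, 1.0, 12.0])))  # last value flagged as an outlier
```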
dataeval/metadata/_distance.py CHANGED
@@ -80,14 +80,17 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDista
     MetadataDistanceValues(statistic=1.0, location=0.44354838709677413, dist=2.7, pvalue=0.0)
     """

-    _compare_keys(metadata1.
-
+    _compare_keys(metadata1.factor_names, metadata2.factor_names)
+    cont_fnames = metadata1.get_factors_by_type("continuous")

-
-
+    if not cont_fnames:
+        return MetadataDistanceOutput({})

-
-
+    cont1 = np.atleast_2d(metadata1.dataframe[cont_fnames].to_numpy())  # (S, F)
+    cont2 = np.atleast_2d(metadata2.dataframe[cont_fnames].to_numpy())  # (S, F)
+
+    _validate_factors_and_data(cont_fnames, cont1)
+    _validate_factors_and_data(cont_fnames, cont2)

     N = len(cont1)
     M = len(cont2)
@@ -104,7 +107,7 @@ def metadata_distance(metadata1: Metadata, metadata2: Metadata) -> MetadataDista
     results: dict[str, MetadataDistanceValues] = {}

     # Per factor
-    for i, fname in enumerate(
+    for i, fname in enumerate(cont_fnames):
         fdata1 = cont1[:, i]  # (S, 1)
         fdata2 = cont2[:, i]  # (S, 1)

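The new per-factor loop in `metadata_distance` walks the columns of two (S, F) arrays of continuous factors that share factor names. Below is a minimal sketch of that pattern, using `scipy.stats.ks_2samp` as a stand-in two-sample test and invented factor names and data; the library's own statistic (with its statistic/location/dist/pvalue fields) may differ.

```python
import numpy as np
from scipy.stats import ks_2samp

rng = np.random.default_rng(0)
cont_fnames = ["brightness", "contrast"]          # hypothetical continuous factors
cont1 = rng.normal(0.0, 1.0, size=(100, 2))       # (S, F) reference factors
cont2 = rng.normal(0.3, 1.0, size=(120, 2))       # (S, F) test factors

results = {}
for i, fname in enumerate(cont_fnames):
    # Compare the i-th factor of the two datasets with a two-sample KS test
    res = ks_2samp(cont1[:, i], cont2[:, i])
    results[fname] = {"statistic": res.statistic, "pvalue": res.pvalue}
print(results)
```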
dataeval/metadata/_ood.py CHANGED
@@ -15,95 +15,6 @@ from dataeval.outputs import MostDeviatedFactorsOutput, OODOutput, OODPredictorO
 from dataeval.outputs._base import set_metadata


-def _combine_discrete_continuous(metadata: Metadata) -> tuple[list[str], NDArray[np.float64]]:
-    """Combines the discrete and continuous data of a :class:`Metadata` object
-
-    Returns
-    -------
-    Tuple[list[str], NDArray]
-        The combined list of factors names and the combined discrete and continuous data
-
-    Note
-    ----
-    Discrete and continuous data must have the same number of samples
-    """
-    names = []
-    data = []
-
-    if metadata.discrete_factor_names and metadata.discrete_data.size != 0:
-        names.extend(metadata.discrete_factor_names)
-        data.append(metadata.discrete_data)
-
-    if metadata.continuous_factor_names and metadata.continuous_data.size != 0:
-        names.extend(metadata.continuous_factor_names)
-        data.append(metadata.continuous_data)
-
-    return names, np.hstack(data, dtype=np.float64) if data else np.array([], dtype=np.float64)
-
-
-def _combine_metadata(
-    metadata_1: Metadata, metadata_2: Metadata
-) -> tuple[list[str], list[NDArray[np.float64 | np.int64]], list[NDArray[np.int64 | np.float64]]]:
-    """
-    Combines the factor names and data arrays of metadata_1 and metadata_2 when the names
-    match exactly and data has the same number of columns (factors).
-
-    Parameters
-    ----------
-    metadata_1 : Metadata
-        The set of factor names used as reference to determine the correct factor names and length of data
-    metadata_2 : Metadata
-        The compared set of factor names and data that must match metadata_1
-
-    Returns
-    -------
-    list[str]
-        The combined discrete and continuous factor names in that order.
-    list[NDArray]
-        Combined discrete and continuous data of metadata_1
-    list[NDArray]
-        Combined discrete and continuous data of metadata_2
-
-    Raises
-    ------
-    ValueError
-        If keys do not match in metadata_1 and metadata_2
-    ValueError
-        If the length of keys do not match the length of the data
-    """
-    factor_names: list[str] = []
-    m1_data: list[NDArray[np.int64 | np.float64]] = []
-    m2_data: list[NDArray[np.int64 | np.float64]] = []
-
-    # Both metadata must have the same number of factors (cols), but not necessarily samples (row)
-    if metadata_1.total_num_factors != metadata_2.total_num_factors:
-        raise ValueError(
-            f"Number of factors differs between metadata_1 ({metadata_1.total_num_factors}) "
-            f"and metadata_2 ({metadata_2.total_num_factors})"
-        )
-
-    # Validate and attach discrete data
-    if metadata_1.discrete_factor_names:
-        _compare_keys(metadata_1.discrete_factor_names, metadata_2.discrete_factor_names)
-        _validate_factors_and_data(metadata_1.discrete_factor_names, metadata_1.discrete_data)
-
-        factor_names.extend(metadata_1.discrete_factor_names)
-        m1_data.append(metadata_1.discrete_data)
-        m2_data.append(metadata_2.discrete_data)
-
-    # Validate and attach continuous data
-    if metadata_1.continuous_factor_names:
-        _compare_keys(metadata_1.continuous_factor_names, metadata_2.continuous_factor_names)
-        _validate_factors_and_data(metadata_1.continuous_factor_names, metadata_1.continuous_data)
-
-        factor_names.extend(metadata_1.continuous_factor_names)
-        m1_data.append(metadata_1.continuous_data)
-        m2_data.append(metadata_2.continuous_data)
-
-    # Turns list of discrete and continuous into one array
-    return factor_names, m1_data, m2_data
-
-
 def _calc_median_deviations(reference: NDArray, test: NDArray) -> NDArray:
     """
     Calculates deviations of the test data from the median of the reference data
@@ -207,16 +118,13 @@ def find_most_deviated_factors(
     if not any(ood_mask):
         return MostDeviatedFactorsOutput([])

-
-
-
-        metadata_1=metadata_ref,
-        metadata_2=metadata_tst,
-    )
+    factor_names = metadata_ref.factor_names
+    ref_data = metadata_ref.factor_data
+    tst_data = metadata_tst.factor_data

-
-
-
+    _compare_keys(factor_names, metadata_tst.factor_names)
+    _validate_factors_and_data(factor_names, ref_data)
+    _validate_factors_and_data(factor_names, tst_data)

     if len(ref_data) < 3:
         warnings.warn(
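`_calc_median_deviations` is documented as computing deviations of the test data from the median of the reference data. Below is a hedged sketch of one natural reading of that description, scaling by the median absolute deviation; the exact formula is not shown in this diff, and the helper name and data are illustrative only.

```python
import numpy as np


def median_deviations(reference: np.ndarray, test: np.ndarray) -> np.ndarray:
    """Per-factor deviation of test samples from the reference median, MAD-scaled."""
    ref_median = np.median(reference, axis=0)                # (F,)
    mad = np.median(np.abs(reference - ref_median), axis=0)  # (F,)
    mad = np.where(mad == 0, 1.0, mad)                       # avoid division by zero
    return np.abs(test - ref_median) / mad                   # (S, F)


ref = np.random.default_rng(0).normal(size=(100, 3))
tst = np.random.default_rng(1).normal(1.0, 1.0, size=(5, 3))
print(median_deviations(ref, tst).round(2))
```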
@@ -256,6 +164,7 @@ which is what many library functions return, multiply it by _NATS2BITS to get it
     """


+@set_metadata
 def find_ood_predictors(
     metadata: Metadata,
     ood: OODOutput,
@@ -305,8 +214,8 @@ def find_ood_predictors(

     ood_mask: NDArray[np.bool_] = ood.is_ood

-
-
+    factors = metadata.factor_names
+    data = metadata.factor_data

     # No metadata correlated with out of distribution data, return 0.0 for all factors
     if not any(ood_mask):
@@ -320,14 +229,13 @@ def find_ood_predictors(
     # Calculate mean, std of each factor over all samples
     scaled_data = (data - np.mean(data, axis=0)) / np.std(data, axis=0, ddof=1)  # (S, F)

-    discrete_features =
-    discrete_features[:discrete_features_count] = True
+    discrete_features = [info.factor_type != "continuous" for info in metadata.factor_info.values()]

     mutual_info_values = (
         mutual_info_classif(
             X=scaled_data,
             y=ood_mask,
-            discrete_features=discrete_features,  # type: ignore
+            discrete_features=discrete_features,  # type: ignore - sklearn function not typed
             random_state=get_seed(),
         )
         * _NATS2BITS
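`find_ood_predictors` now z-scores the factor matrix, flags non-continuous factors via `factor_info`, and scores each factor's mutual information with the OOD mask, converting nats to bits. Here is a self-contained sketch of that computation; the data, factor names, and discrete/continuous flags are invented for illustration.

```python
import numpy as np
from sklearn.feature_selection import mutual_info_classif

_NATS2BITS = 1.0 / np.log(2)  # sklearn returns MI in nats; multiply to get bits

rng = np.random.default_rng(0)
data = rng.normal(size=(200, 3))               # (S, F) factor data
data[:, 0] = rng.integers(0, 4, size=200)      # pretend factor 0 is discrete
ood_mask = data[:, 1] > 1.5                    # pretend OOD flags correlate with factor 1

# Standardize each factor over all samples, then score MI against the OOD mask
scaled = (data - data.mean(axis=0)) / data.std(axis=0, ddof=1)
discrete_features = np.array([True, False, False])  # factor_type != "continuous" per factor

mi_bits = mutual_info_classif(
    X=scaled, y=ood_mask, discrete_features=discrete_features, random_state=0
) * _NATS2BITS
print(dict(zip(["factor_a", "factor_b", "factor_c"], mi_bits.round(3))))
```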
dataeval/metrics/bias/_balance.py CHANGED
@@ -68,22 +68,20 @@ def balance(

     >>> bal = balance(metadata)
     >>> bal.balance
-    array([1. , 0.
+    array([1. , 0.134, 0. , 0. ])

     Return intra/interfactor balance (mutual information)

     >>> bal.factors
-    array([[1. , 0.
-           [0.
-           [0.
-           [0.852, 0.158, 0.037, 0.475, 0.255],
-           [0.367, 1.98 , 0.015, 0.255, 1.063]])
+    array([[1. , 0.017, 0.015],
+           [0.017, 0.445, 0.245],
+           [0.015, 0.245, 1.063]])

     Return classwise balance (mutual information) of factors with individual class_labels

     >>> bal.classwise
-    array([[1. , 0.
-           [1. , 0.
+    array([[1. , 0.134, 0. , 0. ],
+           [1. , 0.134, 0. , 0. ]])


     See Also
@@ -92,41 +90,39 @@ def balance(
     sklearn.feature_selection.mutual_info_regression
     sklearn.metrics.mutual_info_score
     """
-    if not metadata.
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

     num_neighbors = _validate_num_neighbors(num_neighbors)

-
-
+    data = metadata.discretized_data
+    factor_types = {"class_label": "categorical"} | {k: v.factor_type for k, v in metadata.factor_info.items()}
+    is_discrete = [factor_type != "continuous" for factor_type in factor_types.values()]
+    num_factors = len(factor_types)
+
     mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
-    data = np.hstack((metadata.class_labels[:, np.newaxis],
-
-
-
-
-    discretized_data = np.hstack((discretized_data, metadata.discrete_data[:, discrete_idx]))
-
-    for idx in range(num_factors):
-        if idx >= len(metadata.discrete_factor_names) + 1:
-            mi[idx, :] = mutual_info_regression(
+    data = np.hstack((metadata.class_labels[:, np.newaxis], data))
+
+    for idx, factor_type in enumerate(factor_types.values()):
+        if factor_type != "continuous":
+            mi[idx, :] = mutual_info_classif(
                 data,
                 data[:, idx],
-                discrete_features=is_discrete,  # type: ignore
+                discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                 n_neighbors=num_neighbors,
                 random_state=get_seed(),
             )
         else:
-            mi[idx, :] =
+            mi[idx, :] = mutual_info_regression(
                 data,
                 data[:, idx],
-                discrete_features=is_discrete,  # type: ignore
+                discrete_features=is_discrete,  # type: ignore - sklearn function not typed
                 n_neighbors=num_neighbors,
                 random_state=get_seed(),
             )

     # Normalization via entropy
-    bin_cnts = get_counts(
+    bin_cnts = get_counts(data)
     ent_factor = sp.stats.entropy(bin_cnts, axis=0)
     norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON

@@ -149,7 +145,7 @@ def balance(
         classwise_mi[idx, :] = mutual_info_classif(
             data,
             tgt_bin[:, idx],
-            discrete_features=is_discrete,  # type: ignore
+            discrete_features=is_discrete,  # type: ignore - sklearn function not typed
             n_neighbors=num_neighbors,
             random_state=get_seed(),
         )
@@ -161,12 +157,6 @@ def balance(
     classwise = classwise_mi / norm_factor

     # Grabbing factor names for plotting function
-    factor_names = ["
-    for name in metadata.discrete_factor_names:
-        if name in metadata.continuous_factor_names:
-            name = name + "-discrete"
-        factor_names.append(name)
-    for name in metadata.continuous_factor_names:
-        factor_names.append(name + "-continuous")
+    factor_names = ["class_label"] + metadata.factor_names

     return BalanceOutput(balance, factors, classwise, factor_names, metadata.class_names)
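The balance computation keeps the entropy normalization step: the mutual-information matrix is divided by the average entropy of each factor pair. In the sketch below, a per-column bincount stands in for the internal `get_counts` helper, and `EPSILON` and the data are illustrative, not the package's values.

```python
import numpy as np
import scipy as sp
from sklearn.feature_selection import mutual_info_classif

EPSILON = 1e-10
rng = np.random.default_rng(0)
data = rng.integers(0, 3, size=(200, 3))       # discretized class label plus two factors
num_factors = data.shape[1]

# Pairwise mutual information, one row per target factor
mi = np.full((num_factors, num_factors), np.nan, dtype=np.float32)
for idx in range(num_factors):
    mi[idx, :] = mutual_info_classif(data, data[:, idx], discrete_features=True, random_state=0)

# Stand-in for get_counts: per-column histograms over the discrete bins
bin_cnts = np.stack([np.bincount(col, minlength=3) for col in data.T], axis=1)
ent_factor = sp.stats.entropy(bin_cnts, axis=0)
norm_factor = 0.5 * np.add.outer(ent_factor, ent_factor) + EPSILON
balance = mi / norm_factor
print(balance.round(3))
```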
dataeval/metrics/bias/_diversity.py CHANGED
@@ -138,43 +138,45 @@ def diversity(

     >>> div_simp = diversity(metadata, method="simpson")
     >>> div_simp.diversity_index
-    array([0.6 , 0.809, 1.
+    array([0.6 , 0.8 , 0.809, 1. ])

     >>> div_simp.classwise
-    array([[0.
-           [0.
+    array([[0.8 , 0.5 , 0.8 ],
+           [0.528, 0.63 , 0.976]])

     Compute Shannon diversity index of metadata and class labels

     >>> div_shan = diversity(metadata, method="shannon")
     >>> div_shan.diversity_index
-    array([0.811, 0.943, 1.
+    array([0.811, 0.918, 0.943, 1. ])

     >>> div_shan.classwise
-    array([[0.
-           [0.
+    array([[0.918, 0.683, 0.918],
+           [0.764, 0.814, 0.991]])

     See Also
     --------
     scipy.stats.entropy
     """
-    if not metadata.
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

     diversity_fn = get_method(_DIVERSITY_FN_MAP, method)
-    discretized_data =
-
+    discretized_data = metadata.discretized_data
+    factor_names = metadata.factor_names
+    class_lbl = metadata.class_labels
+
+    class_labels_with_discretized_data = np.hstack((class_lbl[:, np.newaxis], discretized_data))
+    cnts = get_counts(class_labels_with_discretized_data)
     num_bins = np.bincount(np.nonzero(cnts)[1])
     diversity_index = diversity_fn(cnts, num_bins)

-    class_lbl = metadata.class_labels
-
     u_classes = np.unique(class_lbl)
-    num_factors = len(
+    num_factors = len(factor_names)
     classwise_div = np.full((len(u_classes), num_factors), np.nan)
     for idx, cls in enumerate(u_classes):
         subset_mask = class_lbl == cls
-        cls_cnts = get_counts(
+        cls_cnts = get_counts(discretized_data[subset_mask], min_num_bins=cnts.shape[0])
         classwise_div[idx, :] = diversity_fn(cls_cnts, num_bins[1:])

-    return DiversityOutput(diversity_index, classwise_div,
+    return DiversityOutput(diversity_index, classwise_div, factor_names, metadata.class_names)
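`diversity` dispatches to a Simpson or Shannon index computed from per-factor value counts. The sketch below shows the standard normalized forms of those indices for a single factor's counts; the exact normalization inside dataeval's `diversity_fn` is not visible in this diff and may differ in detail.

```python
import numpy as np


def shannon_diversity(counts: np.ndarray) -> float:
    # Shannon entropy normalized by the maximum entropy for the number of occupied bins
    p = counts / counts.sum()
    p = p[p > 0]
    num_bins = p.size
    return float(-(p * np.log(p)).sum() / np.log(num_bins)) if num_bins > 1 else 0.0


def simpson_diversity(counts: np.ndarray) -> float:
    # Inverse Simpson index rescaled to [0, 1] over the number of occupied bins
    p = counts / counts.sum()
    num_bins = np.count_nonzero(counts)
    inv_simpson = 1.0 / np.sum(p ** 2)
    return float((inv_simpson - 1.0) / (num_bins - 1.0)) if num_bins > 1 else 0.0


counts = np.array([25, 25, 50])  # value counts for one discretized factor
print(shannon_diversity(counts), simpson_diversity(counts))
```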
dataeval/metrics/bias/_parity.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 __all__ = []

 import warnings
+from collections import defaultdict
 from typing import Any

 import numpy as np
@@ -241,13 +242,13 @@ def parity(metadata: Metadata) -> ParityOutput:
     >>> parity(metadata)
     ParityOutput(score=array([7.357, 5.467, 0.515]), p_value=array([0.289, 0.243, 0.773]), factor_names=['age', 'income', 'gender'], insufficient_data={'age': {3: {'artist': 4}, 4: {'artist': 4, 'teacher': 3}}, 'income': {1: {'artist': 3}}})
     """  # noqa: E501
-    if not metadata.
+    if not metadata.factor_names:
         raise ValueError("No factors found in provided metadata.")

-    chi_scores = np.zeros(metadata.
+    chi_scores = np.zeros(metadata.discretized_data.shape[1])
     p_values = np.zeros_like(chi_scores)
-    insufficient_data =
-    for i, col_data in enumerate(metadata.
+    insufficient_data: defaultdict[str, defaultdict[int, dict[str, int]]] = defaultdict(lambda: defaultdict(dict))
+    for i, col_data in enumerate(metadata.discretized_data.T):
         # Builds a contingency matrix where entry at index (r,c) represents
         # the frequency of current_factor_name achieving value unique_factor_values[r]
         # at a data point with class c.
@@ -257,30 +258,29 @@ def parity(metadata: Metadata) -> ParityOutput:
         # Determines if any frequencies are too low
         counts = np.nonzero(contingency_matrix < 5)
         unique_factor_values = np.unique(col_data)
-        current_factor_name = metadata.
+        current_factor_name = metadata.factor_names[i]
         for int_factor, int_class in zip(counts[0], counts[1]):
             if contingency_matrix[int_factor, int_class] > 0:
                 factor_category = unique_factor_values[int_factor].item()
-                if current_factor_name not in insufficient_data:
-                    insufficient_data[current_factor_name] = {}
-                if factor_category not in insufficient_data[current_factor_name]:
-                    insufficient_data[current_factor_name][factor_category] = {}
                 class_name = metadata.class_names[int_class]
                 class_count = contingency_matrix[int_factor, int_class].item()
                 insufficient_data[current_factor_name][factor_category][class_name] = class_count

         # This deletes rows containing only zeros,
         # because scipy.stats.chi2_contingency fails when there are rows containing only zeros.
-
-        rowmask = np.nonzero(rowsums)[0]
-        contingency_matrix = contingency_matrix[rowmask]
+        contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]

-
-
-        chi_scores[i] = chi2
-        p_values[i] = p
+        chi_scores[i], p_values[i] = chi2_contingency(contingency_matrix)[:2]

     if insufficient_data:
-        warnings.warn(
+        warnings.warn(
+            f"Factors {list(insufficient_data)} did not meet the recommended "
+            "5 occurrences for each value-label combination."
+        )

-    return ParityOutput(
+    return ParityOutput(
+        score=chi_scores,
+        p_value=p_values,
+        factor_names=metadata.factor_names,
+        insufficient_data={k: dict(v) for k, v in insufficient_data.items()},
+    )
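The refactored body of `parity` builds a factor-value by class contingency table, records value/class combinations with fewer than five occurrences, drops all-zero rows, and reads the chi-square score and p-value from `scipy.stats.chi2_contingency`. Here is a sketch of that flow with invented factor and class data.

```python
from collections import defaultdict

import numpy as np
from scipy.stats import chi2_contingency

col_data = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2, 2])      # one discretized factor
class_labels = np.array([0, 1, 0, 0, 1, 1, 1, 0, 1, 1])  # per-sample class index
class_names = ["cat", "dog"]

values = np.unique(col_data)
contingency_matrix = np.array(
    [[np.sum((col_data == v) & (class_labels == c)) for c in range(len(class_names))] for v in values]
)

# Track value/class combinations below the recommended count of 5
insufficient: defaultdict[int, dict[str, int]] = defaultdict(dict)
for r, c in zip(*np.nonzero(contingency_matrix < 5)):
    if contingency_matrix[r, c] > 0:
        insufficient[values[r].item()][class_names[c]] = contingency_matrix[r, c].item()

# Drop all-zero rows, then read chi-square score and p-value directly
contingency_matrix = contingency_matrix[np.any(contingency_matrix, axis=1)]
chi2, p = chi2_contingency(contingency_matrix)[:2]
print(chi2, p, dict(insufficient))
```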
dataeval/metrics/estimators/_divergence.py CHANGED
@@ -38,8 +38,7 @@ def divergence_mst(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
     """
     mst = minimum_spanning_tree(data).toarray()
     edgelist = np.transpose(np.nonzero(mst))
-
-    return errors
+    return np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])


 def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
@@ -59,8 +58,7 @@ def divergence_fnn(data: NDArray[np.float64], labels: NDArray[np.int_]) -> int:
         Number of label errors when finding nearest neighbors
     """
     nn_indices = compute_neighbors(data, data)
-
-    return errors
+    return np.sum(np.abs(labels[nn_indices] - labels))


 _DIVERGENCE_FN_MAP = {"FNN": divergence_fnn, "MST": divergence_mst}
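The simplified `divergence_mst` counts minimum-spanning-tree edges whose endpoints carry different labels. dataeval's own `minimum_spanning_tree` helper is fed the data directly; the sketch below builds an explicit pairwise distance matrix with scipy to illustrate the same count on made-up two-class data.

```python
import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(50, 4))
b = rng.normal(0.5, 1.0, size=(50, 4))
data = np.vstack([a, b])
labels = np.array([0] * 50 + [1] * 50)

# MST over the pooled samples, then count edges that cross the label boundary
mst = minimum_spanning_tree(squareform(pdist(data))).toarray()
edgelist = np.transpose(np.nonzero(mst))
errors = np.sum(labels[edgelist[:, 0]] != labels[edgelist[:, 1]])
print(int(errors))  # cross-label MST edges, used to estimate divergence between the two sets
```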