dataeval 0.72.0__py3-none-any.whl → 0.72.2__py3-none-any.whl
This diff shows the published contents of two versions of the package as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +10 -11
- dataeval/{_internal/detectors → detectors}/drift/base.py +51 -102
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +9 -8
- dataeval/{_internal/detectors → detectors}/drift/ks.py +11 -10
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +33 -34
- dataeval/{_internal/detectors → detectors}/drift/torch.py +15 -13
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +12 -9
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +47 -45
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +20 -10
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +19 -26
- dataeval/detectors/ood/__init__.py +8 -16
- dataeval/{_internal/detectors → detectors}/ood/ae.py +9 -9
- dataeval/{_internal/detectors → detectors}/ood/aegmm.py +10 -30
- dataeval/{_internal/detectors → detectors}/ood/base.py +27 -21
- dataeval/{_internal/detectors → detectors}/ood/llr.py +27 -23
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +11 -13
- dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
- dataeval/{_internal/interop.py → interop.py} +12 -7
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +70 -4
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +10 -8
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +54 -20
- dataeval/metrics/bias/metadata.py +275 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +21 -17
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +31 -28
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +15 -16
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +8 -6
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +66 -40
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +19 -15
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +19 -17
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +12 -10
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +8 -6
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +12 -11
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +14 -13
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +8 -4
- dataeval/utils/image.py +71 -0
- dataeval/utils/shared.py +151 -0
- dataeval/utils/split_dataset.py +486 -0
- dataeval/utils/tensorflow/__init__.py +9 -7
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +64 -68
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +10 -9
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +18 -22
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +18 -18
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +49 -43
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +12 -141
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +42 -37
- {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/METADATA +7 -5
- dataeval-0.72.2.dist-info/RECORD +72 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -7
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.0.dist-info/RECORD +0 -80
- /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
- {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.0.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
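Most of this release is a restructure: modules move out of the private `dataeval._internal` namespace into public packages (`dataeval.detectors`, `dataeval.metrics`, `dataeval.utils`, `dataeval.workflows`). A minimal migration sketch for downstream code, using only names confirmed by the `__all__` lists in the diffs below (the old private path is illustrative):

```python
# 0.72.0-style imports against private modules break in 0.72.2 (illustrative path):
# from dataeval._internal.metrics.ber import ber

# 0.72.2 locates the same functionality under public modules:
from dataeval.metrics.estimators import ber, divergence, uap
from dataeval.metrics.bias.diversity import DiversityOutput, diversity
from dataeval.metrics.bias.parity import label_parity, parity
```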
dataeval/{_internal/metrics → metrics/bias}/diversity.py

```diff
@@ -1,40 +1,77 @@
 from __future__ import annotations
 
+__all__ = ["DiversityOutput", "diversity"]
+
 from dataclasses import dataclass
-from typing import Literal, Mapping
+from typing import Any, Literal, Mapping
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 
-from dataeval.
-from dataeval.
+from dataeval.metrics.bias.metadata import entropy, get_counts, get_num_bins, heatmap, preprocess_metadata
+from dataeval.output import OutputMetadata, set_metadata
+from dataeval.utils.shared import get_method
 
 
 @dataclass(frozen=True)
 class DiversityOutput(OutputMetadata):
     """
-    Output class for :func:`diversity` bias metric
+    Output class for :func:`diversity` :term:`bias<Bias>` metric
 
     Attributes
     ----------
     diversity_index : NDArray[np.float64]
-        Diversity index for classes and factors
+        :term:`Diversity` index for classes and factors
     classwise : NDArray[np.float64]
         Classwise diversity index [n_class x n_factor]
+    class_list: NDArray[np.int64]
+        Class labels for each value in the dataset
+    metadata_names: list[str]
+        Names of each metadata factor
     """
 
     diversity_index: NDArray[np.float64]
     classwise: NDArray[np.float64]
 
+    class_list: NDArray[np.int64]
+    metadata_names: list[str]
+
+    method: Literal["shannon", "simpson"]
+
+    def plot(self, row_labels: NDArray[Any] | None = None, col_labels: NDArray[Any] | None = None) -> None:
+        """
+        Plot a heatmap of diversity information
+
+        Parameters
+        ----------
+        row_labels: NDArray | None, default None
+            Array containing the labels for rows in the histogram
+        col_labels: NDArray | None, default None
+            Array containing the labels for columns in the histogram
+        """
+        if row_labels is None:
+            row_labels = np.unique(self.class_list)
+        if col_labels is None:
+            col_labels = np.array(self.metadata_names)
+
+        heatmap(
+            self.classwise,
+            row_labels,
+            col_labels,
+            xlabel="Factors",
+            ylabel="Class",
+            cbarlabel=f"Normalized {self.method.title()} Index",
+        )
+
 
 def diversity_shannon(
-    data: NDArray,
+    data: NDArray[Any],
     names: list[str],
     is_categorical: list[bool],
     subset_mask: NDArray[np.bool_] | None = None,
-) -> NDArray:
+) -> NDArray[np.float64]:
     """
-    Compute diversity for discrete/categorical variables and, through standard
+    Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
     histogram binning, for continuous variables.
 
     We define diversity as a normalized form of the Shannon entropy.
@@ -79,13 +116,13 @@ def diversity_shannon(
 
 
 def diversity_simpson(
-    data: NDArray,
+    data: NDArray[Any],
     names: list[str],
     is_categorical: list[bool],
     subset_mask: NDArray[np.bool_] | None = None,
-) -> NDArray:
+) -> NDArray[np.float64]:
     """
-    Compute diversity for discrete/categorical variables and, through standard
+    Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
     histogram binning, for continuous variables.
 
     We define diversity as the inverse Simpson diversity index linearly rescaled to the unit interval.
@@ -139,16 +176,13 @@ def diversity_simpson(
     return ev_index
 
 
-
-
-
-@set_metadata("dataeval.metrics")
+@set_metadata()
 def diversity(
     class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
 ) -> DiversityOutput:
     """
-    Compute diversity and classwise diversity for discrete/categorical variables and,
-    histogram binning, for continuous variables.
+    Compute :term:`diversity<Diversity>` and classwise diversity for discrete/categorical variables and,
+    through standard histogram binning, for continuous variables.
 
     We define diversity as a normalized form of the inverse Simpson diversity index.
 
@@ -202,12 +236,12 @@ def diversity(
     --------
     numpy.histogram
     """
-    diversity_fn = get_method(
+    diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
     data, names, is_categorical = preprocess_metadata(class_labels, metadata)
     diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
 
     class_idx = names.index("class_label")
-    class_lbl = data[:, class_idx]
+    class_lbl = np.array(data[:, class_idx], dtype=int)
 
     u_classes = np.unique(class_lbl)
     num_factors = len(names)
@@ -218,4 +252,4 @@ def diversity(
         diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
     div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
 
-    return DiversityOutput(diversity_index, div_no_class)
+    return DiversityOutput(diversity_index, div_no_class, class_lbl, list(metadata.keys()), method)
```
dataeval/metrics/bias/metadata.py

```diff
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+__all__ = []
+
+from typing import Any, Mapping
+
+import numpy as np
+from numpy.typing import ArrayLike, NDArray
+from scipy.stats import entropy as sp_entropy
+
+from dataeval.interop import to_numpy
+
+
+def get_counts(
+    data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
+    """
+    Initialize dictionary of histogram counts --- treat categorical values
+    as histogram bins.
+
+    Parameters
+    ----------
+    subset_mask: NDArray[np.bool_] | None
+        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Returns
+    -------
+    counts: Dict
+        histogram counts per metadata factor in `factors`. Each
+        factor will have a different number of bins. Counts get reused
+        across metrics, so hist_counts are cached but only if computed
+        globally, i.e. without masked samples.
+    """
+
+    hist_counts, hist_bins = {}, {}
+    # np.where needed to satisfy linter
+    mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
+
+    for cdx, fn in enumerate(names):
+        # linter doesn't like double indexing
+        col_data = data[mask, cdx].squeeze()
+        if is_categorical[cdx]:
+            # if discrete, use unique values as bins
+            bins, cnts = np.unique(col_data, return_counts=True)
+        else:
+            bins = hist_bins.get(fn, "auto")
+            cnts, bins = np.histogram(col_data, bins=bins, density=True)
+
+        hist_counts[fn] = cnts
+        hist_bins[fn] = bins
+
+    return hist_counts, hist_bins
+
+
+def entropy(
+    data: NDArray[Any],
+    names: list[str],
+    is_categorical: list[bool],
+    normalized: bool = False,
+    subset_mask: NDArray[np.bool_] | None = None,
+) -> NDArray[np.float64]:
+    """
+    Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
+    ClasswiseBalance, and Classwise Diversity.
+
+    Compute entropy for discrete/categorical variables and for continuous variables through standard
+    histogram binning.
+
+    Parameters
+    ----------
+    normalized: bool
+        Flag that determines whether or not to normalize entropy by log(num_bins)
+    subset_mask: NDArray[np.bool_] | None
+        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Note
+    ----
+    For continuous variables, histogram bins are chosen automatically. See
+    numpy.histogram for details.
+
+    Returns
+    -------
+    ent: NDArray[np.float64]
+        Entropy estimate per column of X
+
+    See Also
+    --------
+    numpy.histogram
+    scipy.stats.entropy
+    """
+
+    num_factors = len(names)
+    hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+
+    ev_index = np.empty(num_factors)
+    for col, cnts in enumerate(hist_counts.values()):
+        # entropy in nats, normalizes counts
+        ev_index[col] = sp_entropy(cnts)
+        if normalized:
+            if len(cnts) == 1:
+                # log(0)
+                ev_index[col] = 0
+            else:
+                ev_index[col] /= np.log(len(cnts))
+    return ev_index
+
+
+def get_num_bins(
+    data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
+) -> NDArray[np.float64]:
+    """
+    Number of bins or unique values for each metadata factor, used to
+    normalize entropy/:term:`diversity<Diversity>`.
+
+    Parameters
+    ----------
+    subset_mask: NDArray[np.bool_] | None
+        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
+
+    Returns
+    -------
+    NDArray[np.float64]
+    """
+    # likely cached
+    hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
+    num_bins = np.empty(len(hist_counts))
+    for idx, cnts in enumerate(hist_counts.values()):
+        num_bins[idx] = len(cnts)
+
+    return num_bins
+
+
+def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
+    """
+    Compute fraction of feature values that are unique --- intended to be used
+    for inferring whether variables are categorical.
+    """
+    if arr.ndim == 1:
+        arr = np.expand_dims(arr, axis=1)
+    num_samples = arr.shape[0]
+    pct_unique = np.empty(arr.shape[1])
+    for col in range(arr.shape[1]):  # type: ignore
+        uvals = np.unique(arr[:, col], axis=0)
+        pct_unique[col] = len(uvals) / num_samples
+    return pct_unique < threshold
+
+
+def preprocess_metadata(
+    class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
+) -> tuple[NDArray[Any], list[str], list[bool]]:
+    # convert class_labels and dict of lists to matrix of metadata values
+    preprocessed_metadata = {"class_label": np.asarray(class_labels, dtype=int)}
+
+    # map columns of dict that are not numeric (e.g. string) to numeric values
+    # that mutual information and diversity functions can accommodate. Each
+    # unique string receives a unique integer value.
+    for k, v in metadata.items():
+        # if not numeric
+        v = to_numpy(v)
+        if not np.issubdtype(v.dtype, np.number):
+            _, mapped_vals = np.unique(v, return_inverse=True)
+            preprocessed_metadata[k] = mapped_vals
+        else:
+            preprocessed_metadata[k] = v
+
+    data = np.stack(list(preprocessed_metadata.values()), axis=-1)
+    names = list(preprocessed_metadata.keys())
+    is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
+
+    return data, names, is_categorical
+
+
+def heatmap(
+    data: NDArray[Any],
+    row_labels: NDArray[Any],
+    col_labels: NDArray[Any],
+    xlabel: str = "",
+    ylabel: str = "",
+    cbarlabel: str = "",
+) -> None:
+    """
+    Plots a formatted heatmap
+
+    Parameters
+    ----------
+    data: NDArray
+        Array containing numerical values for factors to plot
+    row_labels: NDArray
+        Array containing the labels for rows in the histogram
+    col_labels: NDArray
+        Array containing the labels for columns in the histogram
+    xlabel: str, default ""
+        X-axis label
+    ylabel: str, default ""
+        Y-axis label
+    cbarlabel: str, default ""
+        Label for the colorbar
+
+    """
+    import matplotlib
+    import matplotlib.pyplot as plt
+
+    fig, ax = plt.subplots(figsize=(10, 10))
+
+    # Plot the heatmap
+    im = ax.imshow(data, vmin=0, vmax=1.0)
+
+    # Create colorbar
+    cbar = fig.colorbar(im, shrink=0.5)
+    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
+    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
+    cbar.set_label(cbarlabel, loc="center")
+
+    # Show all ticks and label them with the respective list entries.
+    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
+    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
+
+    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
+    # Rotate the tick labels and set their alignment.
+    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
+
+    # Turn spines off and create white grid.
+    ax.spines[:].set_visible(False)
+
+    ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
+    ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
+    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
+    ax.tick_params(which="minor", bottom=False, left=False)
+
+    if xlabel:
+        ax.set_xlabel(xlabel)
+    if ylabel:
+        ax.set_ylabel(ylabel)
+
+    valfmt = matplotlib.ticker.FuncFormatter(format_text)  # type: ignore
+
+    # Normalize the threshold to the images color range.
+    threshold = im.norm(1.0) / 2.0
+
+    # Set default alignment to center, but allow it to be
+    # overwritten by textkw.
+    kw = {"horizontalalignment": "center", "verticalalignment": "center"}
+
+    # Loop over the data and create a `Text` for each "pixel".
+    # Change the text's color depending on the data.
+    textcolors = ("white", "black")
+    texts = []
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
+            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)  # type: ignore
+            texts.append(text)
+
+    fig.tight_layout()
+    plt.show()
+
+
+# Function to define how the text is displayed in the heatmap
+def format_text(*args: str) -> str:
+    """
+    Helper function to format text for heatmap()
+
+    Parameters
+    ----------
+    *args: Tuple (str, str)
+        Text to be formatted. Second element is ignored, but is a
+        mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
+
+    Returns
+    -------
+    str
+        Formatted text
+    """
+    x = args[0]
+    return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
```
dataeval/{_internal/metrics → metrics/bias}/parity.py

```diff
@@ -1,15 +1,17 @@
 from __future__ import annotations
 
+__all__ = ["ParityOutput", "parity", "label_parity"]
+
 import warnings
 from dataclasses import dataclass
-from typing import Generic, Mapping, TypeVar
+from typing import Any, Generic, Mapping, TypeVar
 
 import numpy as np
 from numpy.typing import ArrayLike, NDArray
 from scipy.stats import chi2_contingency, chisquare
 
-from dataeval.
-from dataeval.
+from dataeval.interop import to_numpy
+from dataeval.output import OutputMetadata, set_metadata
 
 TData = TypeVar("TData", np.float64, NDArray[np.float64])
 
@@ -17,7 +19,7 @@ TData = TypeVar("TData", np.float64, NDArray[np.float64])
 @dataclass(frozen=True)
 class ParityOutput(Generic[TData], OutputMetadata):
     """
-    Output class for :func:`parity` and :func:`label_parity` bias metrics
+    Output class for :func:`parity` and :func:`label_parity` :term:`bias<Bias>` metrics
 
     Attributes
     ----------
@@ -31,7 +33,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
     p_value: TData
 
 
-def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
+def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
     """
     Digitizes a list of values into a given number of bins.
 
@@ -64,8 +66,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
 
 
 def format_discretize_factors(
-    data_factors: Mapping[str, NDArray], continuous_factor_bincounts: Mapping[str, int]
-) -> dict[str, NDArray]:
+    data_factors: Mapping[str, NDArray[Any]], continuous_factor_bincounts: Mapping[str, int]
+) -> dict[str, NDArray[Any]]:
     """
     Sets up the internal list of metadata factors.
 
@@ -115,7 +117,7 @@ def format_discretize_factors(
     return metadata_factors
 
 
-def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
+def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
     """
     Normalize the expected label distribution to match the total number of labels in the observed distribution.
 
@@ -162,7 +164,7 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
     return expected_dist
 
 
-def validate_dist(label_dist: NDArray, label_name: str):
+def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
     """
     Verifies that the given label distribution has labels and checks if
     any labels have frequencies less than 5.
@@ -191,14 +193,15 @@ def validate_dist(label_dist: NDArray, label_name: str):
     )
 
 
-@set_metadata(
+@set_metadata()
 def label_parity(
     expected_labels: ArrayLike,
     observed_labels: ArrayLike,
     num_classes: int | None = None,
 ) -> ParityOutput[np.float64]:
     """
-    Calculate the chi-square statistic to assess the parity between expected and
+    Calculate the chi-square statistic to assess the :term:`parity<Parity>` between expected and
+    observed label distributions.
 
     This function computes the frequency distribution of classes in both expected and observed labels, normalizes
     the expected distribution to match the total number of observed labels, and then calculates the chi-square
@@ -217,7 +220,7 @@ def label_parity(
     Returns
     -------
     ParityOutput[np.float64]
-        chi-squared score and
+        chi-squared score and :term`P-Value` of the test
 
     Raises
     ------
@@ -231,8 +234,8 @@ def label_parity(
     - Providing ``num_classes`` can be helpful if there are classes with zero instances in one of the distributions.
     - The function first validates the observed distribution and normalizes the expected distribution so that it
       has the same total number of labels as the observed distribution.
-    - It then performs a
-      the observed and expected label distributions.
+    - It then performs a :term:`Chi-Square Test of Independence` to determine if there is a statistically significant
+      difference between the observed and expected label distributions.
    - This function acts as an interface to the scipy.stats.chisquare method, which is documented at
       https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.chisquare.html
 
@@ -278,14 +281,15 @@ def label_parity(
     return ParityOutput(cs, p)
 
 
-@set_metadata(
+@set_metadata()
 def parity(
     class_labels: ArrayLike,
     data_factors: Mapping[str, ArrayLike],
     continuous_factor_bincounts: Mapping[str, int] | None = None,
 ) -> ParityOutput[NDArray[np.float64]]:
     """
-    Calculate chi-square statistics to assess the relationship between multiple factors
+    Calculate chi-square statistics to assess the relationship between multiple factors
+    and class labels.
 
     This function computes the chi-square statistic for each metadata factor to determine if there is
     a significant relationship between the factor values and class labels. The function handles both categorical
@@ -308,7 +312,7 @@ def parity(
     -------
     ParityOutput[NDArray[np.float64]]
         Arrays of length (num_factors) whose (i)th element corresponds to the
-        chi-square score and p-value for the relationship between factor i and
+        chi-square score and :term:`p-value<P-Value>` for the relationship between factor i and
         the class labels in the dataset.
 
     Raises
```
dataeval/metrics/estimators/__init__.py

```diff
@@ -2,8 +2,8 @@
 Estimators calculate performance bounds and the statistical distance between datasets.
 """
 
-from dataeval.
-from dataeval.
-from dataeval.
+from dataeval.metrics.estimators.ber import BEROutput, ber
+from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
+from dataeval.metrics.estimators.uap import UAPOutput, uap
 
 __all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]
```
|