dataeval 0.72.1__py3-none-any.whl → 0.72.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +4 -4
- dataeval/detectors/__init__.py +4 -3
- dataeval/detectors/drift/__init__.py +9 -10
- dataeval/{_internal/detectors → detectors}/drift/base.py +39 -91
- dataeval/{_internal/detectors → detectors}/drift/cvm.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/ks.py +4 -3
- dataeval/{_internal/detectors → detectors}/drift/mmd.py +23 -25
- dataeval/{_internal/detectors → detectors}/drift/torch.py +13 -11
- dataeval/{_internal/detectors → detectors}/drift/uncertainty.py +7 -5
- dataeval/detectors/drift/updates.py +61 -0
- dataeval/detectors/linters/__init__.py +3 -3
- dataeval/{_internal/detectors → detectors/linters}/clusterer.py +41 -39
- dataeval/{_internal/detectors → detectors/linters}/duplicates.py +19 -9
- dataeval/{_internal/detectors → detectors/linters}/merged_stats.py +3 -1
- dataeval/{_internal/detectors → detectors/linters}/outliers.py +14 -21
- dataeval/detectors/ood/__init__.py +6 -6
- dataeval/{_internal/detectors → detectors}/ood/ae.py +7 -7
- dataeval/{_internal/detectors → detectors}/ood/aegmm.py +9 -29
- dataeval/{_internal/detectors → detectors}/ood/base.py +24 -18
- dataeval/{_internal/detectors → detectors}/ood/llr.py +24 -20
- dataeval/detectors/ood/metadata_ks_compare.py +99 -0
- dataeval/detectors/ood/metadata_least_likely.py +119 -0
- dataeval/detectors/ood/metadata_ood_mi.py +92 -0
- dataeval/{_internal/detectors → detectors}/ood/vae.py +10 -12
- dataeval/{_internal/detectors → detectors}/ood/vaegmm.py +10 -32
- dataeval/{_internal/interop.py → interop.py} +12 -7
- dataeval/metrics/__init__.py +1 -1
- dataeval/metrics/bias/__init__.py +4 -4
- dataeval/{_internal/metrics → metrics/bias}/balance.py +75 -9
- dataeval/{_internal/metrics → metrics/bias}/coverage.py +6 -4
- dataeval/{_internal/metrics → metrics/bias}/diversity.py +48 -14
- dataeval/metrics/bias/metadata.py +275 -0
- dataeval/{_internal/metrics → metrics/bias}/parity.py +12 -10
- dataeval/metrics/estimators/__init__.py +3 -3
- dataeval/{_internal/metrics → metrics/estimators}/ber.py +25 -22
- dataeval/{_internal/metrics → metrics/estimators}/divergence.py +11 -12
- dataeval/{_internal/metrics → metrics/estimators}/uap.py +5 -3
- dataeval/metrics/stats/__init__.py +7 -7
- dataeval/{_internal/metrics → metrics}/stats/base.py +59 -35
- dataeval/{_internal/metrics → metrics}/stats/boxratiostats.py +18 -14
- dataeval/{_internal/metrics → metrics}/stats/datasetstats.py +18 -16
- dataeval/{_internal/metrics → metrics}/stats/dimensionstats.py +9 -7
- dataeval/metrics/stats/hashstats.py +156 -0
- dataeval/{_internal/metrics → metrics}/stats/labelstats.py +5 -3
- dataeval/{_internal/metrics → metrics}/stats/pixelstats.py +9 -8
- dataeval/{_internal/metrics → metrics}/stats/visualstats.py +10 -9
- dataeval/{_internal/output.py → output.py} +26 -6
- dataeval/utils/__init__.py +7 -3
- dataeval/utils/image.py +71 -0
- dataeval/utils/shared.py +151 -0
- dataeval/{_internal → utils}/split_dataset.py +98 -33
- dataeval/utils/tensorflow/__init__.py +7 -6
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/autoencoder.py +60 -64
- dataeval/{_internal/models/tensorflow/losses.py → utils/tensorflow/_internal/loss.py} +9 -8
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/pixelcnn.py +16 -20
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/trainer.py +3 -1
- dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/utils.py +17 -17
- dataeval/utils/tensorflow/loss/__init__.py +6 -2
- dataeval/utils/torch/__init__.py +7 -3
- dataeval/{_internal/models/pytorch → utils/torch}/blocks.py +19 -14
- dataeval/{_internal → utils/torch}/datasets.py +48 -42
- dataeval/utils/torch/models.py +138 -0
- dataeval/{_internal/models/pytorch/autoencoder.py → utils/torch/trainer.py} +7 -136
- dataeval/{_internal → utils/torch}/utils.py +3 -1
- dataeval/workflows/__init__.py +1 -1
- dataeval/{_internal/workflows → workflows}/sufficiency.py +39 -34
- {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/METADATA +2 -1
- dataeval-0.72.2.dist-info/RECORD +72 -0
- dataeval/_internal/detectors/__init__.py +0 -0
- dataeval/_internal/detectors/drift/__init__.py +0 -0
- dataeval/_internal/detectors/ood/__init__.py +0 -0
- dataeval/_internal/metrics/__init__.py +0 -0
- dataeval/_internal/metrics/stats/hashstats.py +0 -75
- dataeval/_internal/metrics/utils.py +0 -447
- dataeval/_internal/models/__init__.py +0 -0
- dataeval/_internal/models/pytorch/__init__.py +0 -0
- dataeval/_internal/models/pytorch/utils.py +0 -67
- dataeval/_internal/models/tensorflow/__init__.py +0 -0
- dataeval/_internal/workflows/__init__.py +0 -0
- dataeval/detectors/drift/kernels/__init__.py +0 -10
- dataeval/detectors/drift/updates/__init__.py +0 -8
- dataeval/utils/tensorflow/models/__init__.py +0 -9
- dataeval/utils/tensorflow/recon/__init__.py +0 -3
- dataeval/utils/torch/datasets/__init__.py +0 -12
- dataeval/utils/torch/models/__init__.py +0 -11
- dataeval/utils/torch/trainer/__init__.py +0 -7
- dataeval-0.72.1.dist-info/RECORD +0 -81
- /dataeval/{_internal/models/tensorflow → utils/tensorflow/_internal}/gmm.py +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.72.1.dist-info → dataeval-0.72.2.dist-info}/WHEEL +0 -0
@@ -1,36 +1,99 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["BalanceOutput", "balance"]
|
4
|
+
|
3
5
|
import warnings
|
4
6
|
from dataclasses import dataclass
|
5
|
-
from typing import Mapping
|
7
|
+
from typing import Any, Mapping
|
6
8
|
|
7
9
|
import numpy as np
|
8
10
|
from numpy.typing import ArrayLike, NDArray
|
9
11
|
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
|
10
12
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
13
|
+
from dataeval.metrics.bias.metadata import entropy, heatmap, preprocess_metadata
|
14
|
+
from dataeval.output import OutputMetadata, set_metadata
|
13
15
|
|
14
16
|
|
15
17
|
@dataclass(frozen=True)
|
16
18
|
class BalanceOutput(OutputMetadata):
|
17
19
|
"""
|
18
|
-
Output class for :func:`balance`
|
20
|
+
Output class for :func:`balance` bias metric
|
19
21
|
|
20
22
|
Attributes
|
21
23
|
----------
|
22
24
|
balance : NDArray[np.float64]
|
23
|
-
Estimate of
|
25
|
+
Estimate of mutual information between metadata factors and class label
|
24
26
|
factors : NDArray[np.float64]
|
25
27
|
Estimate of inter/intra-factor mutual information
|
26
28
|
classwise : NDArray[np.float64]
|
27
29
|
Estimate of mutual information between metadata factors and individual class labels
|
30
|
+
class_list: NDArray[np.int64]
|
31
|
+
Class labels for each value in the dataset
|
32
|
+
metadata_names: list[str]
|
33
|
+
Names of each metadata factor
|
28
34
|
"""
|
29
35
|
|
30
36
|
balance: NDArray[np.float64]
|
31
37
|
factors: NDArray[np.float64]
|
32
38
|
classwise: NDArray[np.float64]
|
33
39
|
|
40
|
+
class_list: NDArray[np.int64]
|
41
|
+
metadata_names: list[str]
|
42
|
+
|
43
|
+
def plot(
    self,
    row_labels: NDArray[Any] | None = None,
    col_labels: NDArray[Any] | None = None,
    plot_classwise: bool = False,
) -> None:
    """
    Plot a heatmap of balance information

    Parameters
    ----------
    row_labels: NDArray | None, default None
        Labels for the rows of the heatmap. When omitted, the unique class labels are used
        for the classwise plot, and ``["class", *metadata_names]`` minus its last entry
        for the global plot.
    col_labels: NDArray | None, default None
        Labels for the columns of the heatmap. When omitted, ``["class", *metadata_names]``
        is used for the classwise plot and ``metadata_names`` for the global plot.
    plot_classwise: bool, default False
        Whether to plot per-class balance instead of global balance

    """
    if plot_classwise:
        if row_labels is None:
            row_labels = np.unique(self.class_list)
        if col_labels is None:
            col_labels = np.concatenate((["class"], self.metadata_names))

        heatmap(
            self.classwise,
            row_labels,
            col_labels,
            xlabel="Factors",
            ylabel="Class",
            cbarlabel="Normalized Mutual Information",
        )
    else:
        # First row: balance of the class label against each factor (class-vs-class entry
        # dropped); remaining rows: factor-vs-factor mutual information.
        data = np.concatenate([self.balance[np.newaxis, 1:], self.factors], axis=0)
        # Blank every cell strictly below the main diagonal of `data` so each pair is shown
        # once. The mask depends only on cell position, never on the values themselves.
        below_diagonal = np.tri(data.shape[0], data.shape[1], k=-1, dtype=bool)
        # Last row is last factor x last factor so it gets dropped
        heat_data = np.where(below_diagonal, np.nan, data)[:-1]
        # Label array for the heat map axes
        heat_labels = np.concatenate((["class"], self.metadata_names))

        if row_labels is None:
            row_labels = heat_labels[:-1]
        if col_labels is None:
            col_labels = heat_labels[1:]

        heatmap(
            heat_data,
            row_labels,
            col_labels,
            cbarlabel="Normalized Mutual Information",
        )
|
96
|
+
|
34
97
|
|
35
98
|
def validate_num_neighbors(num_neighbors: int) -> int:
|
36
99
|
if not isinstance(num_neighbors, (int, float)):
|
@@ -55,7 +118,7 @@ def validate_num_neighbors(num_neighbors: int) -> int:
|
|
55
118
|
@set_metadata("dataeval.metrics")
|
56
119
|
def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neighbors: int = 5) -> BalanceOutput:
|
57
120
|
"""
|
58
|
-
|
121
|
+
Mutual information (MI) between factors (class label, metadata, label/image properties)
|
59
122
|
|
60
123
|
Parameters
|
61
124
|
----------
|
@@ -70,7 +133,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
70
133
|
Returns
|
71
134
|
-------
|
72
135
|
BalanceOutput
|
73
|
-
(num_factors+1) x (num_factors+1) estimate of
|
136
|
+
(num_factors+1) x (num_factors+1) estimate of mutual information
|
74
137
|
between num_factors metadata factors and class label. Symmetry is enforced.
|
75
138
|
|
76
139
|
Note
|
@@ -83,7 +146,7 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
83
146
|
|
84
147
|
Example
|
85
148
|
-------
|
86
|
-
Return
|
149
|
+
Return balance (mutual information) of factors with class_labels
|
87
150
|
|
88
151
|
>>> bal = balance(class_labels, metadata)
|
89
152
|
>>> bal.balance
|
@@ -114,6 +177,9 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
114
177
|
mi = np.empty((num_factors, num_factors))
|
115
178
|
mi[:] = np.nan
|
116
179
|
|
180
|
+
class_idx = names.index("class_label")
|
181
|
+
class_lbl = np.array(data[:, class_idx], dtype=int)
|
182
|
+
|
117
183
|
for idx in range(num_factors):
|
118
184
|
tgt = data[:, idx].astype(int)
|
119
185
|
|
@@ -174,4 +240,4 @@ def balance(class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], num_neig
|
|
174
240
|
norm_factor = 0.5 * np.add.outer(ent_tgt_bin, ent_all) + 1e-6
|
175
241
|
classwise = classwise_mi / norm_factor
|
176
242
|
|
177
|
-
return BalanceOutput(balance, factors, classwise)
|
243
|
+
return BalanceOutput(balance, factors, classwise, class_lbl, list(metadata.keys()))
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["CoverageOutput", "coverage"]
|
4
|
+
|
3
5
|
import math
|
4
6
|
from dataclasses import dataclass
|
5
7
|
from typing import Literal
|
@@ -8,9 +10,9 @@ import numpy as np
|
|
8
10
|
from numpy.typing import ArrayLike, NDArray
|
9
11
|
from scipy.spatial.distance import pdist, squareform
|
10
12
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
13
|
-
from dataeval.
|
13
|
+
from dataeval.interop import to_numpy
|
14
|
+
from dataeval.output import OutputMetadata, set_metadata
|
15
|
+
from dataeval.utils.shared import flatten
|
14
16
|
|
15
17
|
|
16
18
|
@dataclass(frozen=True)
|
@@ -33,7 +35,7 @@ class CoverageOutput(OutputMetadata):
|
|
33
35
|
critical_value: float
|
34
36
|
|
35
37
|
|
36
|
-
@set_metadata(
|
38
|
+
@set_metadata()
|
37
39
|
def coverage(
|
38
40
|
embeddings: ArrayLike,
|
39
41
|
radius_type: Literal["adaptive", "naive"] = "adaptive",
|
@@ -1,13 +1,16 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["DiversityOutput", "diversity"]
|
4
|
+
|
3
5
|
from dataclasses import dataclass
|
4
|
-
from typing import Literal, Mapping
|
6
|
+
from typing import Any, Literal, Mapping
|
5
7
|
|
6
8
|
import numpy as np
|
7
9
|
from numpy.typing import ArrayLike, NDArray
|
8
10
|
|
9
|
-
from dataeval.
|
10
|
-
from dataeval.
|
11
|
+
from dataeval.metrics.bias.metadata import entropy, get_counts, get_num_bins, heatmap, preprocess_metadata
|
12
|
+
from dataeval.output import OutputMetadata, set_metadata
|
13
|
+
from dataeval.utils.shared import get_method
|
11
14
|
|
12
15
|
|
13
16
|
@dataclass(frozen=True)
|
@@ -21,18 +24,52 @@ class DiversityOutput(OutputMetadata):
|
|
21
24
|
:term:`Diversity` index for classes and factors
|
22
25
|
classwise : NDArray[np.float64]
|
23
26
|
Classwise diversity index [n_class x n_factor]
|
27
|
+
class_list: NDArray[np.int64]
|
28
|
+
Class labels for each value in the dataset
|
29
|
+
metadata_names: list[str]
|
30
|
+
Names of each metadata factor
|
24
31
|
"""
|
25
32
|
|
26
33
|
diversity_index: NDArray[np.float64]
|
27
34
|
classwise: NDArray[np.float64]
|
28
35
|
|
36
|
+
class_list: NDArray[np.int64]
|
37
|
+
metadata_names: list[str]
|
38
|
+
|
39
|
+
method: Literal["shannon", "simpson"]
|
40
|
+
|
41
|
+
def plot(self, row_labels: NDArray[Any] | None = None, col_labels: NDArray[Any] | None = None) -> None:
    """
    Draw the classwise diversity values as a heatmap (one row per class, one column per factor)

    Parameters
    ----------
    row_labels: NDArray | None, default None
        Labels for the heatmap rows; the unique values of ``class_list`` when omitted
    col_labels: NDArray | None, default None
        Labels for the heatmap columns; ``metadata_names`` when omitted
    """
    rows = np.unique(self.class_list) if row_labels is None else row_labels
    cols = np.array(self.metadata_names) if col_labels is None else col_labels

    heatmap(
        self.classwise,
        rows,
        cols,
        xlabel="Factors",
        ylabel="Class",
        cbarlabel=f"Normalized {self.method.title()} Index",
    )
|
65
|
+
|
29
66
|
|
30
67
|
def diversity_shannon(
|
31
|
-
data: NDArray,
|
68
|
+
data: NDArray[Any],
|
32
69
|
names: list[str],
|
33
70
|
is_categorical: list[bool],
|
34
71
|
subset_mask: NDArray[np.bool_] | None = None,
|
35
|
-
) -> NDArray:
|
72
|
+
) -> NDArray[np.float64]:
|
36
73
|
"""
|
37
74
|
Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
|
38
75
|
histogram binning, for continuous variables.
|
@@ -79,11 +116,11 @@ def diversity_shannon(
|
|
79
116
|
|
80
117
|
|
81
118
|
def diversity_simpson(
|
82
|
-
data: NDArray,
|
119
|
+
data: NDArray[Any],
|
83
120
|
names: list[str],
|
84
121
|
is_categorical: list[bool],
|
85
122
|
subset_mask: NDArray[np.bool_] | None = None,
|
86
|
-
) -> NDArray:
|
123
|
+
) -> NDArray[np.float64]:
|
87
124
|
"""
|
88
125
|
Compute :term:`diversity<Diversity>` for discrete/categorical variables and, through standard
|
89
126
|
histogram binning, for continuous variables.
|
@@ -139,10 +176,7 @@ def diversity_simpson(
|
|
139
176
|
return ev_index
|
140
177
|
|
141
178
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
@set_metadata("dataeval.metrics")
|
179
|
+
@set_metadata()
|
146
180
|
def diversity(
|
147
181
|
class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], method: Literal["shannon", "simpson"] = "simpson"
|
148
182
|
) -> DiversityOutput:
|
@@ -202,12 +236,12 @@ def diversity(
|
|
202
236
|
--------
|
203
237
|
numpy.histogram
|
204
238
|
"""
|
205
|
-
diversity_fn = get_method(
|
239
|
+
diversity_fn = get_method({"simpson": diversity_simpson, "shannon": diversity_shannon}, method)
|
206
240
|
data, names, is_categorical = preprocess_metadata(class_labels, metadata)
|
207
241
|
diversity_index = diversity_fn(data, names, is_categorical, None).astype(np.float64)
|
208
242
|
|
209
243
|
class_idx = names.index("class_label")
|
210
|
-
class_lbl = data[:, class_idx]
|
244
|
+
class_lbl = np.array(data[:, class_idx], dtype=int)
|
211
245
|
|
212
246
|
u_classes = np.unique(class_lbl)
|
213
247
|
num_factors = len(names)
|
@@ -218,4 +252,4 @@ def diversity(
|
|
218
252
|
diversity[idx, :] = diversity_fn(data, names, is_categorical, subset_mask)
|
219
253
|
div_no_class = np.concatenate((diversity[:, :class_idx], diversity[:, (class_idx + 1) :]), axis=1)
|
220
254
|
|
221
|
-
return DiversityOutput(diversity_index, div_no_class)
|
255
|
+
return DiversityOutput(diversity_index, div_no_class, class_lbl, list(metadata.keys()), method)
|
@@ -0,0 +1,275 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
__all__ = []
|
4
|
+
|
5
|
+
from typing import Any, Mapping
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
from numpy.typing import ArrayLike, NDArray
|
9
|
+
from scipy.stats import entropy as sp_entropy
|
10
|
+
|
11
|
+
from dataeval.interop import to_numpy
|
12
|
+
|
13
|
+
|
14
|
+
def get_counts(
    data: NDArray[Any],
    names: list[str],
    is_categorical: list[bool],
    subset_mask: NDArray[np.bool_] | None = None,
) -> tuple[dict[str, NDArray[Any]], dict[str, NDArray[Any]]]:
    """
    Histogram every metadata factor, treating each categorical value as its own bin.

    Parameters
    ----------
    data: NDArray
        2-D array whose column ``i`` holds the values of factor ``names[i]`` (one row per sample)
    names: list[str]
        Factor name for each column of ``data``; used as the keys of the returned dicts
    is_categorical: list[bool]
        Per-column flag. Categorical columns are binned by their unique values; all other
        columns are binned by ``numpy.histogram`` with ``bins="auto"``.
    subset_mask: NDArray[np.bool_] | None, default None
        Boolean mask of samples to bin (e.g. when computing per class). True -> include in
        histogram counts. ``None`` includes every sample.

    Returns
    -------
    hist_counts: dict[str, NDArray]
        Per-factor bin values. Categorical factors get integer occurrence counts; continuous
        factors get histogram densities (``density=True``), not raw counts.
    hist_bins: dict[str, NDArray]
        Per-factor bins: the sorted unique values for categorical factors, the bin edges for
        continuous factors.

    Note
    ----
    Nothing is cached; the histograms are recomputed on every call.
    """
    # Boolean row selector; None selects every sample
    if subset_mask is None:
        rows = np.ones(data.shape[0], dtype=bool)
    else:
        rows = np.asarray(subset_mask, dtype=bool)

    hist_counts: dict[str, NDArray[Any]] = {}
    hist_bins: dict[str, NDArray[Any]] = {}
    for cdx, fn in enumerate(names):
        col_data = data[rows, cdx]
        if is_categorical[cdx]:
            # Discrete factor: every unique value is a bin
            bins, cnts = np.unique(col_data, return_counts=True)
        else:
            # Continuous factor: bins chosen automatically by numpy
            cnts, bins = np.histogram(col_data, bins="auto", density=True)

        hist_counts[fn] = cnts
        hist_bins[fn] = bins

    return hist_counts, hist_bins
|
53
|
+
|
54
|
+
|
55
|
+
def entropy(
    data: NDArray[Any],
    names: list[str],
    is_categorical: list[bool],
    normalized: bool = False,
    subset_mask: NDArray[np.bool_] | None = None,
) -> NDArray[np.float64]:
    """
    Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
    ClasswiseBalance, and Classwise Diversity.

    Compute the entropy (in nats) of each metadata factor from its histogram: unique values
    serve as bins for categorical factors, automatic histogram bins for continuous ones.

    Parameters
    ----------
    normalized: bool
        When True, divide each entropy by log(number of bins); a factor with a single bin
        reports 0 instead of dividing by zero
    subset_mask: NDArray[np.bool_] | None
        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

    Note
    ----
    For continuous variables, histogram bins are chosen automatically. See
    numpy.histogram for details.

    Returns
    -------
    ent: NDArray[np.float64]
        One entropy value per entry of ``names``

    See Also
    --------
    numpy.histogram
    scipy.stats.entropy
    """
    counts_by_factor, _ = get_counts(data, names, is_categorical, subset_mask)

    result = np.empty(len(names))
    for position, counts in enumerate(counts_by_factor.values()):
        # scipy normalizes the counts to probabilities before computing the entropy
        value = sp_entropy(counts)
        if normalized:
            n_bins = len(counts)
            # log(1) == 0, so a single-bin factor is reported as 0 rather than divided by zero
            value = 0 if n_bins == 1 else value / np.log(n_bins)
        result[position] = value
    return result
|
106
|
+
|
107
|
+
|
108
|
+
def get_num_bins(
    data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
) -> NDArray[np.float64]:
    """
    Count the histogram bins (unique values for categorical factors) of each metadata
    factor, as used to normalize entropy/:term:`diversity<Diversity>`.

    Parameters
    ----------
    subset_mask: NDArray[np.bool_] | None
        Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts

    Returns
    -------
    NDArray[np.float64]
        Number of bins per factor, stored as floats
    """
    # Histograms are recomputed here; get_counts does not cache
    counts_by_factor, _ = get_counts(data, names, is_categorical, subset_mask)
    return np.array([len(counts) for counts in counts_by_factor.values()], dtype=np.float64)
|
131
|
+
|
132
|
+
|
133
|
+
def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[np.bool_]:
    """
    Flag which columns of ``arr`` look categorical, based on their fraction of unique values.

    A column is flagged when ``number_of_unique_values / number_of_samples < threshold``.

    Parameters
    ----------
    arr: NDArray
        1-D array (treated as a single column) or array whose first axis indexes samples
        and second axis indexes columns
    threshold: float, default 0.2
        Strict upper bound on the unique-value fraction for a column to be flagged

    Returns
    -------
    NDArray[np.bool_]
        One flag per column (shape ``(1,)`` for 1-D input); True -> inferred categorical

    Raises
    ------
    ZeroDivisionError
        If ``arr`` has at least one column but zero samples
    """
    if arr.ndim == 1:
        arr = np.expand_dims(arr, axis=1)
    num_samples = arr.shape[0]
    pct_unique = np.empty(arr.shape[1])
    for col in range(arr.shape[1]):
        # axis=0 so that any trailing dimensions are compared as whole entries
        uvals = np.unique(arr[:, col], axis=0)
        pct_unique[col] = len(uvals) / num_samples
    return pct_unique < threshold
|
146
|
+
|
147
|
+
|
148
|
+
def preprocess_metadata(
    class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
) -> tuple[NDArray[Any], list[str], list[bool]]:
    """
    Combine class labels and metadata factors into a single numeric matrix.

    Parameters
    ----------
    class_labels: ArrayLike
        Class label per sample; cast to int and stored as the first factor, ``"class_label"``
    metadata: Mapping[str, ArrayLike]
        Factor name -> per-sample values. Non-numeric factors (e.g. strings) are replaced by
        the index of each value within the sorted unique values of that factor.
    cat_thresh: float, default 0.2
        Threshold passed to ``infer_categorical`` to decide whether each factor is categorical

    Returns
    -------
    data: NDArray
        Factors stacked along the last axis, class labels in column 0
    names: list[str]
        ``"class_label"`` followed by the metadata keys, in column order
    is_categorical: list[bool]
        Inferred categorical flag per column

    Raises
    ------
    ValueError
        If ``metadata`` uses the reserved key ``"class_label"``, which would otherwise
        silently overwrite the class labels
    """
    if "class_label" in metadata:
        raise ValueError(
            "'class_label' is a reserved metadata key used for the class labels; rename that metadata factor."
        )

    # convert class_labels and dict of lists to matrix of metadata values
    preprocessed_metadata: dict[str, NDArray[Any]] = {"class_label": np.asarray(class_labels, dtype=int)}

    # map columns of dict that are not numeric (e.g. string) to numeric values
    # that mutual information and diversity functions can accommodate. Each
    # unique string receives a unique integer value.
    for k, v in metadata.items():
        v = to_numpy(v)
        if not np.issubdtype(v.dtype, np.number):
            _, mapped_vals = np.unique(v, return_inverse=True)
            preprocessed_metadata[k] = mapped_vals
        else:
            preprocessed_metadata[k] = v

    data = np.stack(list(preprocessed_metadata.values()), axis=-1)
    names = list(preprocessed_metadata.keys())
    is_categorical = [bool(infer_categorical(preprocessed_metadata[var], cat_thresh)[0]) for var in names]

    return data, names, is_categorical
|
171
|
+
|
172
|
+
|
173
|
+
def heatmap(
    data: NDArray[Any],
    row_labels: NDArray[Any],
    col_labels: NDArray[Any],
    xlabel: str = "",
    ylabel: str = "",
    cbarlabel: str = "",
) -> None:
    """
    Plot ``data`` as an annotated heatmap on a fixed 0-1 color scale and display it.

    Each cell is annotated with its value as formatted by ``format_text``. The annotation
    is black when the cell's normalized value exceeds 0.5 and white otherwise.

    Parameters
    ----------
    data: NDArray
        2-D array of values to plot; the color scale is fixed to [0, 1]
    row_labels: NDArray
        Tick labels for the heatmap rows (one per row of ``data``)
    col_labels: NDArray
        Tick labels for the heatmap columns (one per column of ``data``)
    xlabel: str, default ""
        X-axis label (omitted when empty)
    ylabel: str, default ""
        Y-axis label (omitted when empty)
    cbarlabel: str, default ""
        Label for the colorbar

    """
    import matplotlib.pyplot as plt
    from matplotlib.ticker import FuncFormatter

    fig, ax = plt.subplots(figsize=(10, 10))

    # Plot the heatmap on a fixed [0, 1] color range
    im = ax.imshow(data, vmin=0, vmax=1.0)

    # Create colorbar
    cbar = fig.colorbar(im, shrink=0.5)
    cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
    cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
    cbar.set_label(cbarlabel, loc="center")

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
    ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)

    ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Turn spines off and draw a white grid between cells using minor ticks.
    ax.spines[:].set_visible(False)

    ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    valfmt = FuncFormatter(format_text)

    # Normalize the threshold to the image's color range.
    threshold = im.norm(1.0) / 2.0

    # Annotate every cell, choosing the text color from the cell's normalized value.
    textcolors = ("white", "black")
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            color = textcolors[int(im.norm(data[i, j]) > threshold)]
            ax.text(
                j,
                i,
                valfmt(data[i, j], None),
                horizontalalignment="center",
                verticalalignment="center",
                color=color,
            )

    fig.tight_layout()
    plt.show()
|
256
|
+
|
257
|
+
|
258
|
+
# Function to define how the text is displayed in the heatmap
def format_text(*args: Any) -> str:
    """
    Helper function to format cell text for heatmap()

    Values are rounded to two decimals. A leading "0" before the decimal point is dropped
    (0.5 -> ".50"), values that round to zero become "0", and NaN becomes an empty string.
    Values whose integer part is not 0 keep all their digits (10.5 -> "10.50").

    Parameters
    ----------
    *args: Tuple (float, Any)
        ``(value, position)`` as passed by matplotlib.ticker.FuncFormatter. Only the
        value is used; the position is a mandatory pass-through argument.

    Returns
    -------
    str
        Formatted text
    """
    text = f"{args[0]:.2f}"
    if "nan" in text:
        return ""

    # Strip the sign first so only a leading "0" is removed, never a "0" inside e.g. "10.50"
    sign, digits = ("-", text[1:]) if text.startswith("-") else ("", text)
    if digits == "0.00":
        return sign + "0"
    if digits.startswith("0."):
        digits = digits[1:]
    return sign + digits
|
@@ -1,15 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
__all__ = ["ParityOutput", "parity", "label_parity"]
|
4
|
+
|
3
5
|
import warnings
|
4
6
|
from dataclasses import dataclass
|
5
|
-
from typing import Generic, Mapping, TypeVar
|
7
|
+
from typing import Any, Generic, Mapping, TypeVar
|
6
8
|
|
7
9
|
import numpy as np
|
8
10
|
from numpy.typing import ArrayLike, NDArray
|
9
11
|
from scipy.stats import chi2_contingency, chisquare
|
10
12
|
|
11
|
-
from dataeval.
|
12
|
-
from dataeval.
|
13
|
+
from dataeval.interop import to_numpy
|
14
|
+
from dataeval.output import OutputMetadata, set_metadata
|
13
15
|
|
14
16
|
TData = TypeVar("TData", np.float64, NDArray[np.float64])
|
15
17
|
|
@@ -31,7 +33,7 @@ class ParityOutput(Generic[TData], OutputMetadata):
|
|
31
33
|
p_value: TData
|
32
34
|
|
33
35
|
|
34
|
-
def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str) -> NDArray:
|
36
|
+
def digitize_factor_bins(continuous_values: NDArray[Any], bins: int, factor_name: str) -> NDArray[np.intp]:
|
35
37
|
"""
|
36
38
|
Digitizes a list of values into a given number of bins.
|
37
39
|
|
@@ -64,8 +66,8 @@ def digitize_factor_bins(continuous_values: NDArray, bins: int, factor_name: str
|
|
64
66
|
|
65
67
|
|
66
68
|
def format_discretize_factors(
|
67
|
-
data_factors: Mapping[str, NDArray], continuous_factor_bincounts: Mapping[str, int]
|
68
|
-
) -> dict[str, NDArray]:
|
69
|
+
data_factors: Mapping[str, NDArray[Any]], continuous_factor_bincounts: Mapping[str, int]
|
70
|
+
) -> dict[str, NDArray[Any]]:
|
69
71
|
"""
|
70
72
|
Sets up the internal list of metadata factors.
|
71
73
|
|
@@ -115,7 +117,7 @@ def format_discretize_factors(
|
|
115
117
|
return metadata_factors
|
116
118
|
|
117
119
|
|
118
|
-
def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> NDArray:
|
120
|
+
def normalize_expected_dist(expected_dist: NDArray[Any], observed_dist: NDArray[Any]) -> NDArray[Any]:
|
119
121
|
"""
|
120
122
|
Normalize the expected label distribution to match the total number of labels in the observed distribution.
|
121
123
|
|
@@ -162,7 +164,7 @@ def normalize_expected_dist(expected_dist: NDArray, observed_dist: NDArray) -> N
|
|
162
164
|
return expected_dist
|
163
165
|
|
164
166
|
|
165
|
-
def validate_dist(label_dist: NDArray, label_name: str):
|
167
|
+
def validate_dist(label_dist: NDArray[Any], label_name: str) -> None:
|
166
168
|
"""
|
167
169
|
Verifies that the given label distribution has labels and checks if
|
168
170
|
any labels have frequencies less than 5.
|
@@ -191,7 +193,7 @@ def validate_dist(label_dist: NDArray, label_name: str):
|
|
191
193
|
)
|
192
194
|
|
193
195
|
|
194
|
-
@set_metadata(
|
196
|
+
@set_metadata()
|
195
197
|
def label_parity(
|
196
198
|
expected_labels: ArrayLike,
|
197
199
|
observed_labels: ArrayLike,
|
@@ -279,7 +281,7 @@ def label_parity(
|
|
279
281
|
return ParityOutput(cs, p)
|
280
282
|
|
281
283
|
|
282
|
-
@set_metadata(
|
284
|
+
@set_metadata()
|
283
285
|
def parity(
|
284
286
|
class_labels: ArrayLike,
|
285
287
|
data_factors: Mapping[str, ArrayLike],
|
@@ -2,8 +2,8 @@
|
|
2
2
|
Estimators calculate performance bounds and the statistical distance between datasets.
|
3
3
|
"""
|
4
4
|
|
5
|
-
from dataeval.
|
6
|
-
from dataeval.
|
7
|
-
from dataeval.
|
5
|
+
from dataeval.metrics.estimators.ber import BEROutput, ber
|
6
|
+
from dataeval.metrics.estimators.divergence import DivergenceOutput, divergence
|
7
|
+
from dataeval.metrics.estimators.uap import UAPOutput, uap
|
8
8
|
|
9
9
|
__all__ = ["ber", "divergence", "uap", "BEROutput", "DivergenceOutput", "UAPOutput"]
|