dataeval 0.73.0__py3-none-any.whl → 0.74.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +3 -3
- dataeval/detectors/__init__.py +1 -1
- dataeval/detectors/drift/__init__.py +1 -1
- dataeval/detectors/drift/base.py +2 -2
- dataeval/detectors/drift/torch.py +1 -101
- dataeval/detectors/linters/clusterer.py +1 -1
- dataeval/detectors/ood/__init__.py +11 -4
- dataeval/detectors/ood/ae.py +2 -1
- dataeval/detectors/ood/ae_torch.py +70 -0
- dataeval/detectors/ood/aegmm.py +4 -3
- dataeval/detectors/ood/base.py +58 -108
- dataeval/detectors/ood/base_tf.py +109 -0
- dataeval/detectors/ood/base_torch.py +109 -0
- dataeval/detectors/ood/llr.py +2 -2
- dataeval/detectors/ood/metadata_ks_compare.py +53 -14
- dataeval/detectors/ood/vae.py +3 -2
- dataeval/detectors/ood/vaegmm.py +5 -4
- dataeval/metrics/bias/__init__.py +3 -0
- dataeval/metrics/bias/balance.py +77 -64
- dataeval/metrics/bias/coverage.py +12 -12
- dataeval/metrics/bias/diversity.py +74 -114
- dataeval/metrics/bias/metadata_preprocessing.py +285 -0
- dataeval/metrics/bias/metadata_utils.py +229 -0
- dataeval/metrics/bias/parity.py +54 -158
- dataeval/utils/__init__.py +2 -2
- dataeval/utils/gmm.py +26 -0
- dataeval/utils/metadata.py +29 -9
- dataeval/utils/shared.py +1 -1
- dataeval/utils/split_dataset.py +12 -6
- dataeval/utils/tensorflow/_internal/gmm.py +4 -24
- dataeval/utils/torch/datasets.py +2 -2
- dataeval/utils/torch/gmm.py +98 -0
- dataeval/utils/torch/models.py +192 -0
- dataeval/utils/torch/trainer.py +84 -5
- dataeval/utils/torch/utils.py +107 -1
- dataeval/workflows/__init__.py +1 -1
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/RECORD +40 -34
- dataeval/metrics/bias/metadata.py +0 -358
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
@@ -1,36 +1,40 @@
|
|
1
|
-
dataeval/__init__.py,sha256=
|
2
|
-
dataeval/detectors/__init__.py,sha256=
|
3
|
-
dataeval/detectors/drift/__init__.py,sha256=
|
4
|
-
dataeval/detectors/drift/base.py,sha256=
|
1
|
+
dataeval/__init__.py,sha256=bwKFegCsdGFydqDvza_wSvJgRGr-0pQ59UpcePQ1mNs,601
|
2
|
+
dataeval/detectors/__init__.py,sha256=mwAyY54Hvp6N4D57cde3_besOinK8jVF43k0Mw4XZi8,363
|
3
|
+
dataeval/detectors/drift/__init__.py,sha256=BSXm21y7cAawHep-ZldCJ5HOvzYjPzYGKGrmoEs3i0E,737
|
4
|
+
dataeval/detectors/drift/base.py,sha256=xwI6C-PEH0ZjpSqP6No6WDZp42DnE16OHi_mXe2JSvI,14499
|
5
5
|
dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
|
6
6
|
dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
|
7
7
|
dataeval/detectors/drift/mmd.py,sha256=TqGOnUNYKwpS0GQPV3dSl-_qRa0g2flmoQ-dxzW_JfY,7586
|
8
|
-
dataeval/detectors/drift/torch.py,sha256=
|
8
|
+
dataeval/detectors/drift/torch.py,sha256=igEQ2DV9JmcpTdUKCOHBi5LxtoNeCAslJS2Ldulg1hw,7585
|
9
9
|
dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
|
10
10
|
dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
|
11
11
|
dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
|
12
|
-
dataeval/detectors/linters/clusterer.py,sha256=
|
12
|
+
dataeval/detectors/linters/clusterer.py,sha256=sau5A9YcQ6VDjbZGOIaCaRHW_63opaA31pqHo5Rm-hQ,21018
|
13
13
|
dataeval/detectors/linters/duplicates.py,sha256=tOD43rJkvheIA3mznbUqHhft2yD3xRZQdCt61daIca4,5665
|
14
14
|
dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
|
15
15
|
dataeval/detectors/linters/outliers.py,sha256=BUVvtbKHo04KnRmrgb84MBr0l1gtcY3-xNCHjetFrEQ,10117
|
16
|
-
dataeval/detectors/ood/__init__.py,sha256=
|
17
|
-
dataeval/detectors/ood/ae.py,sha256=
|
18
|
-
dataeval/detectors/ood/
|
19
|
-
dataeval/detectors/ood/
|
20
|
-
dataeval/detectors/ood/
|
21
|
-
dataeval/detectors/ood/
|
16
|
+
dataeval/detectors/ood/__init__.py,sha256=XckkWVhYbbg9iWVsCPEQN-t7FFSt2a4jmCwAAempkM4,793
|
17
|
+
dataeval/detectors/ood/ae.py,sha256=km7buF8LbMmwsyfu1xMOI5CJDnQX1x8_-c04zTGMXRI,2389
|
18
|
+
dataeval/detectors/ood/ae_torch.py,sha256=pO9w5221bXR9lEBkE7oakXeE7PXUUR--xcTpmHvOCSk,2142
|
19
|
+
dataeval/detectors/ood/aegmm.py,sha256=CI2HEkRMJSEFTVLZEhz4CStkaS7i66yTPtnbkbCqTes,2084
|
20
|
+
dataeval/detectors/ood/base.py,sha256=u9S7z7zJ8wuPqrtn63ePdAa8DdI579EbCy8Tn0M3XI8,6983
|
21
|
+
dataeval/detectors/ood/base_tf.py,sha256=ppj8rAjXjHEab2oGfQO2olXyN4aGZH8_QHIEghOoeFQ,3297
|
22
|
+
dataeval/detectors/ood/base_torch.py,sha256=yFbSfQsBMwZeVf8mrixmkZYBGChhV5oAHtkgzWnMzsA,3405
|
23
|
+
dataeval/detectors/ood/llr.py,sha256=IrOam-kqUU4bftolR3MvhcEq-NNj2euyI-lYvMuXYn8,10645
|
24
|
+
dataeval/detectors/ood/metadata_ks_compare.py,sha256=Ka6MABdJH5ZlHF66mENpSOLCE8H9xdQ_wWNwMYVO_Q0,5352
|
22
25
|
dataeval/detectors/ood/metadata_least_likely.py,sha256=nxMCXUOjOfWHDTGT2SLE7OYBCydRq8zHLd8t17k7hMM,5193
|
23
26
|
dataeval/detectors/ood/metadata_ood_mi.py,sha256=KLay2BmgHrStBV92VpIs_B1yEfQKllsMTgzOQEng01I,4065
|
24
|
-
dataeval/detectors/ood/vae.py,sha256=
|
25
|
-
dataeval/detectors/ood/vaegmm.py,sha256=
|
27
|
+
dataeval/detectors/ood/vae.py,sha256=yjK4p-XYhnH3wWPiwAclb3eyZE0wpTazLLuKhzurcWY,3203
|
28
|
+
dataeval/detectors/ood/vaegmm.py,sha256=FhPJBzs7wyEPQUUMxOMsdPpCdAZwN82vztjt05cSrds,2459
|
26
29
|
dataeval/interop.py,sha256=TZCkZo844DvzHoxuRo-YsBhT6GvKmyQTHtUEQZPly1M,1728
|
27
30
|
dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
|
28
|
-
dataeval/metrics/bias/__init__.py,sha256=
|
29
|
-
dataeval/metrics/bias/balance.py,sha256=
|
30
|
-
dataeval/metrics/bias/coverage.py,sha256=
|
31
|
-
dataeval/metrics/bias/diversity.py,sha256=
|
32
|
-
dataeval/metrics/bias/
|
33
|
-
dataeval/metrics/bias/
|
31
|
+
dataeval/metrics/bias/__init__.py,sha256=dYiPHenS8J7pgRMMW2jNkTBmTbPoYTxT04fZu9PFats,747
|
32
|
+
dataeval/metrics/bias/balance.py,sha256=BH644D_xN7rRUdJMNgVcGHWq3TTnehYjSBhSMhmAFyY,9154
|
33
|
+
dataeval/metrics/bias/coverage.py,sha256=LBrNG6GIrvMJjZckr72heyCTMCke_p5BT8NJWi-noEY,4546
|
34
|
+
dataeval/metrics/bias/diversity.py,sha256=__7I934sVoymXqgHoneXglJhIU5iHRIuklFwC2ks84w,7719
|
35
|
+
dataeval/metrics/bias/metadata_preprocessing.py,sha256=DbtzsiHjkCxs411okb6s2B_H2TqfvwJ4xyt9m_OsqJo,12266
|
36
|
+
dataeval/metrics/bias/metadata_utils.py,sha256=HmTjlRRTdM9566oKUDDdVMJ8luss4DYykFOiS2FQzhM,6558
|
37
|
+
dataeval/metrics/bias/parity.py,sha256=lLa2zN0AK-zWzlXmvLCbMxTZFodAKLs8wSGl_YZdNFo,12765
|
34
38
|
dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
|
35
39
|
dataeval/metrics/estimators/ber.py,sha256=SVT-BIC_GLs0l2l2NhWu4OpRbgn96w-OwTSoPHTnQbE,5037
|
36
40
|
dataeval/metrics/estimators/divergence.py,sha256=pImaa216-YYTgGWDCSTcpJrC-dfl7150yVrPfW_TyGc,4293
|
@@ -46,14 +50,15 @@ dataeval/metrics/stats/pixelstats.py,sha256=x90O10IqVjEORtYwueFLvJnVYTxhPBOOx5HM
|
|
46
50
|
dataeval/metrics/stats/visualstats.py,sha256=y0xIvst7epcajk8vz2jngiAiz0T7DZC-M97Rs1-vV9I,4950
|
47
51
|
dataeval/output.py,sha256=jWXXNxFNBEaY1rN7Z-6LZl6bQT-I7z_wqr91Rhrdt_0,3061
|
48
52
|
dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
49
|
-
dataeval/utils/__init__.py,sha256=
|
53
|
+
dataeval/utils/__init__.py,sha256=FZLWDA7nMbHOcdg3701cVJpQmUp1Wxxk8h_qIrUQQjY,713
|
54
|
+
dataeval/utils/gmm.py,sha256=YuLsJKsVWgH_wHr1u_hSRH5Yeexdj8exht8h99L7bLo,561
|
50
55
|
dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
|
51
56
|
dataeval/utils/lazy.py,sha256=M0iBHuJh4UPrSJPHZ0jhFwRSZhyjHJQx_KEf1OCkHD8,588
|
52
|
-
dataeval/utils/metadata.py,sha256=
|
53
|
-
dataeval/utils/shared.py,sha256=
|
54
|
-
dataeval/utils/split_dataset.py,sha256=
|
57
|
+
dataeval/utils/metadata.py,sha256=0A--iru0zEmi044mKz5P35q69KrI30yoiRSlvs7TSdQ,9418
|
58
|
+
dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
|
59
|
+
dataeval/utils/split_dataset.py,sha256=Ot1ZJhbIhVfcShYXF9MkWXak5odBXyuBdRh-noXh-MI,19555
|
55
60
|
dataeval/utils/tensorflow/__init__.py,sha256=l4OjIA75JJXeNWDCkST1xtDMVYsw97lZ-9JXFBlyuYg,539
|
56
|
-
dataeval/utils/tensorflow/_internal/gmm.py,sha256=
|
61
|
+
dataeval/utils/tensorflow/_internal/gmm.py,sha256=XvjhWM3ppP-R9nCZGs80WphmQR3u7wb-VtoCQYeXZlQ,3404
|
57
62
|
dataeval/utils/tensorflow/_internal/loss.py,sha256=TFhoNPgqeJtdpIHYobZPyzMpeWjzlFqzu5LCtthEUi4,4463
|
58
63
|
dataeval/utils/tensorflow/_internal/models.py,sha256=TzQYRrFe5XomhnPw05v-HBODQdFIqWg21WH1xS0XBlg,59868
|
59
64
|
dataeval/utils/tensorflow/_internal/trainer.py,sha256=uBFTnAy9o2T_FoT3RSX-AA7T-2FScyOdYEg9_7Dpd28,4314
|
@@ -61,13 +66,14 @@ dataeval/utils/tensorflow/_internal/utils.py,sha256=lr5hKkAPbjMCUNIzMUIqbEddwbWQ
|
|
61
66
|
dataeval/utils/tensorflow/loss/__init__.py,sha256=Q-66vt91Oe1ByYfo28tW32zXDq2MqQ2gngWgmIVmof8,227
|
62
67
|
dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
|
63
68
|
dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
|
64
|
-
dataeval/utils/torch/datasets.py,sha256=
|
65
|
-
dataeval/utils/torch/
|
66
|
-
dataeval/utils/torch/
|
67
|
-
dataeval/utils/torch/
|
68
|
-
dataeval/
|
69
|
+
dataeval/utils/torch/datasets.py,sha256=10elNgLuH_FDX_CHE3y2Z215JN4-PQovQm5brcIJOeM,15021
|
70
|
+
dataeval/utils/torch/gmm.py,sha256=VbLlUQohwToApT493_tjQBWy2UM5R-3ppS9Dp-eP7BA,3240
|
71
|
+
dataeval/utils/torch/models.py,sha256=sdGeo7a8vshCTGA4lYyVxxb_aDWUlxdtIVxrddS-_ls,8542
|
72
|
+
dataeval/utils/torch/trainer.py,sha256=8BEXr6xtk-CHJTcNxOBnWgkFWfJUAiBy28cEdBhLMRU,7883
|
73
|
+
dataeval/utils/torch/utils.py,sha256=nWRcT6z6DbFVrL1RyxCOX3DPoCrv9G0B-VI_9LdGCQQ,5784
|
74
|
+
dataeval/workflows/__init__.py,sha256=ef1MiVL5IuhlDXXbwsiAfafhnr7tD3TXF9GRusy9_O8,290
|
69
75
|
dataeval/workflows/sufficiency.py,sha256=1jSYhH9i4oesmJYs5PZvWS1LGXf8ekOgNhpFtMPLPXk,18552
|
70
|
-
dataeval-0.
|
71
|
-
dataeval-0.
|
72
|
-
dataeval-0.
|
73
|
-
dataeval-0.
|
76
|
+
dataeval-0.74.0.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
|
77
|
+
dataeval-0.74.0.dist-info/METADATA,sha256=OPnkHZTm8R1LHqLxcSnOHjqj5GuHmjUVI3dddTVsBAc,4680
|
78
|
+
dataeval-0.74.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
79
|
+
dataeval-0.74.0.dist-info/RECORD,,
|
@@ -1,358 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
import contextlib
|
6
|
-
from typing import Any, Mapping
|
7
|
-
|
8
|
-
import numpy as np
|
9
|
-
from numpy.typing import ArrayLike, NDArray
|
10
|
-
from scipy.stats import entropy as sp_entropy
|
11
|
-
|
12
|
-
from dataeval.interop import to_numpy
|
13
|
-
|
14
|
-
with contextlib.suppress(ImportError):
|
15
|
-
from matplotlib.figure import Figure
|
16
|
-
|
17
|
-
CLASS_LABEL = "class_label"
|
18
|
-
|
19
|
-
|
20
|
-
def get_counts(
|
21
|
-
data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
|
22
|
-
) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
|
23
|
-
"""
|
24
|
-
Initialize dictionary of histogram counts --- treat categorical values
|
25
|
-
as histogram bins.
|
26
|
-
|
27
|
-
Parameters
|
28
|
-
----------
|
29
|
-
subset_mask: NDArray[np.bool_] | None
|
30
|
-
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
31
|
-
|
32
|
-
Returns
|
33
|
-
-------
|
34
|
-
counts: Dict
|
35
|
-
histogram counts per metadata factor in `factors`. Each
|
36
|
-
factor will have a different number of bins. Counts get reused
|
37
|
-
across metrics, so hist_counts are cached but only if computed
|
38
|
-
globally, i.e. without masked samples.
|
39
|
-
"""
|
40
|
-
|
41
|
-
hist_counts, hist_bins = {}, {}
|
42
|
-
# np.where needed to satisfy linter
|
43
|
-
mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
|
44
|
-
|
45
|
-
for cdx, fn in enumerate(names):
|
46
|
-
# linter doesn't like double indexing
|
47
|
-
col_data = data[mask, cdx].squeeze()
|
48
|
-
if is_categorical[cdx]:
|
49
|
-
# if discrete, use unique values as bins
|
50
|
-
bins, cnts = np.unique(col_data, return_counts=True)
|
51
|
-
else:
|
52
|
-
bins = hist_bins.get(fn, "auto")
|
53
|
-
cnts, bins = np.histogram(col_data, bins=bins, density=True)
|
54
|
-
|
55
|
-
hist_counts[fn] = cnts
|
56
|
-
hist_bins[fn] = bins
|
57
|
-
|
58
|
-
return hist_counts, hist_bins
|
59
|
-
|
60
|
-
|
61
|
-
def entropy(
|
62
|
-
data: NDArray[Any],
|
63
|
-
names: list[str],
|
64
|
-
is_categorical: list[bool],
|
65
|
-
normalized: bool = False,
|
66
|
-
subset_mask: NDArray[np.bool_] | None = None,
|
67
|
-
) -> NDArray[np.float64]:
|
68
|
-
"""
|
69
|
-
Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
|
70
|
-
ClasswiseBalance, and Classwise Diversity.
|
71
|
-
|
72
|
-
Compute entropy for discrete/categorical variables and for continuous variables through standard
|
73
|
-
histogram binning.
|
74
|
-
|
75
|
-
Parameters
|
76
|
-
----------
|
77
|
-
normalized: bool
|
78
|
-
Flag that determines whether or not to normalize entropy by log(num_bins)
|
79
|
-
subset_mask: NDArray[np.bool_] | None
|
80
|
-
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
81
|
-
|
82
|
-
Note
|
83
|
-
----
|
84
|
-
For continuous variables, histogram bins are chosen automatically. See
|
85
|
-
numpy.histogram for details.
|
86
|
-
|
87
|
-
Returns
|
88
|
-
-------
|
89
|
-
ent: NDArray[np.float64]
|
90
|
-
Entropy estimate per column of X
|
91
|
-
|
92
|
-
See Also
|
93
|
-
--------
|
94
|
-
numpy.histogram
|
95
|
-
scipy.stats.entropy
|
96
|
-
"""
|
97
|
-
|
98
|
-
num_factors = len(names)
|
99
|
-
hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
|
100
|
-
|
101
|
-
ev_index = np.empty(num_factors)
|
102
|
-
for col, cnts in enumerate(hist_counts.values()):
|
103
|
-
# entropy in nats, normalizes counts
|
104
|
-
ev_index[col] = sp_entropy(cnts)
|
105
|
-
if normalized:
|
106
|
-
if len(cnts) == 1:
|
107
|
-
# log(0)
|
108
|
-
ev_index[col] = 0
|
109
|
-
else:
|
110
|
-
ev_index[col] /= np.log(len(cnts))
|
111
|
-
return ev_index
|
112
|
-
|
113
|
-
|
114
|
-
def get_num_bins(
|
115
|
-
data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
|
116
|
-
) -> NDArray[np.float64]:
|
117
|
-
"""
|
118
|
-
Number of bins or unique values for each metadata factor, used to
|
119
|
-
normalize entropy/:term:`diversity<Diversity>`.
|
120
|
-
|
121
|
-
Parameters
|
122
|
-
----------
|
123
|
-
subset_mask: NDArray[np.bool_] | None
|
124
|
-
Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
|
125
|
-
|
126
|
-
Returns
|
127
|
-
-------
|
128
|
-
NDArray[np.float64]
|
129
|
-
"""
|
130
|
-
# likely cached
|
131
|
-
hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
|
132
|
-
num_bins = np.empty(len(hist_counts))
|
133
|
-
for idx, cnts in enumerate(hist_counts.values()):
|
134
|
-
num_bins[idx] = len(cnts)
|
135
|
-
|
136
|
-
return num_bins
|
137
|
-
|
138
|
-
|
139
|
-
def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
|
140
|
-
"""
|
141
|
-
Compute fraction of feature values that are unique --- intended to be used
|
142
|
-
for inferring whether variables are categorical.
|
143
|
-
"""
|
144
|
-
if arr.ndim == 1:
|
145
|
-
arr = np.expand_dims(arr, axis=1)
|
146
|
-
num_samples = arr.shape[0]
|
147
|
-
pct_unique = np.empty(arr.shape[1])
|
148
|
-
for col in range(arr.shape[1]): # type: ignore
|
149
|
-
uvals = np.unique(arr[:, col], axis=0)
|
150
|
-
pct_unique[col] = len(uvals) / num_samples
|
151
|
-
return pct_unique < threshold
|
152
|
-
|
153
|
-
|
154
|
-
def preprocess_metadata(
|
155
|
-
class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
|
156
|
-
) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
|
157
|
-
# if class_labels is not numeric
|
158
|
-
class_array = to_numpy(class_labels)
|
159
|
-
if not np.issubdtype(class_array.dtype, np.number):
|
160
|
-
unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
|
161
|
-
else:
|
162
|
-
numerical_labels = np.asarray(class_array, dtype=int)
|
163
|
-
unique_classes = np.unique(class_array)
|
164
|
-
|
165
|
-
# convert class_labels and dict of lists to matrix of metadata values
|
166
|
-
preprocessed_metadata = {CLASS_LABEL: numerical_labels}
|
167
|
-
|
168
|
-
# map columns of dict that are not numeric (e.g. string) to numeric values
|
169
|
-
# that mutual information and diversity functions can accommodate. Each
|
170
|
-
# unique string receives a unique integer value.
|
171
|
-
for k, v in metadata.items():
|
172
|
-
if k == CLASS_LABEL:
|
173
|
-
k = "label_class"
|
174
|
-
# if not numeric
|
175
|
-
v = to_numpy(v)
|
176
|
-
if not np.issubdtype(v.dtype, np.number):
|
177
|
-
_, mapped_vals = np.unique(v, return_inverse=True)
|
178
|
-
preprocessed_metadata[k] = mapped_vals
|
179
|
-
else:
|
180
|
-
preprocessed_metadata[k] = v
|
181
|
-
|
182
|
-
data = np.stack(list(preprocessed_metadata.values()), axis=-1)
|
183
|
-
names = list(preprocessed_metadata.keys())
|
184
|
-
is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
|
185
|
-
|
186
|
-
return data, names, is_categorical, unique_classes
|
187
|
-
|
188
|
-
|
189
|
-
def heatmap(
|
190
|
-
data: NDArray[Any],
|
191
|
-
row_labels: list[str] | NDArray[Any],
|
192
|
-
col_labels: list[str] | NDArray[Any],
|
193
|
-
xlabel: str = "",
|
194
|
-
ylabel: str = "",
|
195
|
-
cbarlabel: str = "",
|
196
|
-
) -> Figure:
|
197
|
-
"""
|
198
|
-
Plots a formatted heatmap
|
199
|
-
|
200
|
-
Parameters
|
201
|
-
----------
|
202
|
-
data : NDArray
|
203
|
-
Array containing numerical values for factors to plot
|
204
|
-
row_labels : ArrayLike
|
205
|
-
List/Array containing the labels for rows in the histogram
|
206
|
-
col_labels : ArrayLike
|
207
|
-
List/Array containing the labels for columns in the histogram
|
208
|
-
xlabel : str, default ""
|
209
|
-
X-axis label
|
210
|
-
ylabel : str, default ""
|
211
|
-
Y-axis label
|
212
|
-
cbarlabel : str, default ""
|
213
|
-
Label for the colorbar
|
214
|
-
"""
|
215
|
-
import matplotlib
|
216
|
-
import matplotlib.pyplot as plt
|
217
|
-
|
218
|
-
fig, ax = plt.subplots(figsize=(10, 10))
|
219
|
-
|
220
|
-
# Plot the heatmap
|
221
|
-
im = ax.imshow(data, vmin=0, vmax=1.0)
|
222
|
-
|
223
|
-
# Create colorbar
|
224
|
-
cbar = fig.colorbar(im, shrink=0.5)
|
225
|
-
cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
|
226
|
-
cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
|
227
|
-
cbar.set_label(cbarlabel, loc="center")
|
228
|
-
|
229
|
-
# Show all ticks and label them with the respective list entries.
|
230
|
-
ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
|
231
|
-
ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
|
232
|
-
|
233
|
-
ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
|
234
|
-
# Rotate the tick labels and set their alignment.
|
235
|
-
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
236
|
-
|
237
|
-
# Turn spines off and create white grid.
|
238
|
-
ax.spines[:].set_visible(False)
|
239
|
-
|
240
|
-
ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
|
241
|
-
ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
|
242
|
-
ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
|
243
|
-
ax.tick_params(which="minor", bottom=False, left=False)
|
244
|
-
|
245
|
-
if xlabel:
|
246
|
-
ax.set_xlabel(xlabel)
|
247
|
-
if ylabel:
|
248
|
-
ax.set_ylabel(ylabel)
|
249
|
-
|
250
|
-
valfmt = matplotlib.ticker.FuncFormatter(format_text) # type: ignore
|
251
|
-
|
252
|
-
# Normalize the threshold to the images color range.
|
253
|
-
threshold = im.norm(1.0) / 2.0
|
254
|
-
|
255
|
-
# Set default alignment to center, but allow it to be
|
256
|
-
# overwritten by textkw.
|
257
|
-
kw = {"horizontalalignment": "center", "verticalalignment": "center"}
|
258
|
-
|
259
|
-
# Loop over the data and create a `Text` for each "pixel".
|
260
|
-
# Change the text's color depending on the data.
|
261
|
-
textcolors = ("white", "black")
|
262
|
-
texts = []
|
263
|
-
for i in range(data.shape[0]):
|
264
|
-
for j in range(data.shape[1]):
|
265
|
-
kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
|
266
|
-
text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) # type: ignore
|
267
|
-
texts.append(text)
|
268
|
-
|
269
|
-
fig.tight_layout()
|
270
|
-
return fig
|
271
|
-
|
272
|
-
|
273
|
-
# Function to define how the text is displayed in the heatmap
|
274
|
-
def format_text(*args: str) -> str:
|
275
|
-
"""
|
276
|
-
Helper function to format text for heatmap()
|
277
|
-
|
278
|
-
Parameters
|
279
|
-
----------
|
280
|
-
*args: Tuple (str, str)
|
281
|
-
Text to be formatted. Second element is ignored, but is a
|
282
|
-
mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
|
283
|
-
|
284
|
-
Returns
|
285
|
-
-------
|
286
|
-
str
|
287
|
-
Formatted text
|
288
|
-
"""
|
289
|
-
x = args[0]
|
290
|
-
return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
|
291
|
-
|
292
|
-
|
293
|
-
def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
|
294
|
-
"""
|
295
|
-
Plots a formatted bar plot
|
296
|
-
|
297
|
-
Parameters
|
298
|
-
----------
|
299
|
-
labels : NDArray
|
300
|
-
Array containing the labels for each bar
|
301
|
-
bar_heights : NDArray
|
302
|
-
Array containing the values for each bar
|
303
|
-
"""
|
304
|
-
import matplotlib.pyplot as plt
|
305
|
-
|
306
|
-
fig, ax = plt.subplots(figsize=(10, 10))
|
307
|
-
|
308
|
-
ax.bar(labels, bar_heights)
|
309
|
-
ax.set_xlabel("Factors")
|
310
|
-
|
311
|
-
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
|
312
|
-
|
313
|
-
fig.tight_layout()
|
314
|
-
return fig
|
315
|
-
|
316
|
-
|
317
|
-
def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
|
318
|
-
"""
|
319
|
-
Creates a single plot of all of the provided images
|
320
|
-
|
321
|
-
Parameters
|
322
|
-
----------
|
323
|
-
images : NDArray
|
324
|
-
Array containing only the desired images to plot
|
325
|
-
"""
|
326
|
-
import matplotlib.pyplot as plt
|
327
|
-
|
328
|
-
num_images = min(num_images, len(images))
|
329
|
-
|
330
|
-
if images.ndim == 4:
|
331
|
-
images = np.moveaxis(images, 1, -1)
|
332
|
-
elif images.ndim == 3:
|
333
|
-
images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
|
334
|
-
else:
|
335
|
-
raise ValueError(
|
336
|
-
f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
|
337
|
-
)
|
338
|
-
|
339
|
-
rows = np.ceil(num_images / 3).astype(int)
|
340
|
-
fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
|
341
|
-
|
342
|
-
if rows == 1:
|
343
|
-
for j in range(3):
|
344
|
-
if j >= len(images):
|
345
|
-
continue
|
346
|
-
axs[j].imshow(images[j])
|
347
|
-
axs[j].axis("off")
|
348
|
-
else:
|
349
|
-
for i in range(rows):
|
350
|
-
for j in range(3):
|
351
|
-
i_j = i * 3 + j
|
352
|
-
if i_j >= len(images):
|
353
|
-
continue
|
354
|
-
axs[i, j].imshow(images[i_j])
|
355
|
-
axs[i, j].axis("off")
|
356
|
-
|
357
|
-
fig.tight_layout()
|
358
|
-
return fig
|
File without changes
|
File without changes
|