dataeval 0.73.0__py3-none-any.whl → 0.74.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. dataeval/__init__.py +3 -3
  2. dataeval/detectors/__init__.py +1 -1
  3. dataeval/detectors/drift/__init__.py +1 -1
  4. dataeval/detectors/drift/base.py +2 -2
  5. dataeval/detectors/drift/torch.py +1 -101
  6. dataeval/detectors/linters/clusterer.py +1 -1
  7. dataeval/detectors/ood/__init__.py +11 -4
  8. dataeval/detectors/ood/ae.py +2 -1
  9. dataeval/detectors/ood/ae_torch.py +70 -0
  10. dataeval/detectors/ood/aegmm.py +4 -3
  11. dataeval/detectors/ood/base.py +58 -108
  12. dataeval/detectors/ood/base_tf.py +109 -0
  13. dataeval/detectors/ood/base_torch.py +109 -0
  14. dataeval/detectors/ood/llr.py +2 -2
  15. dataeval/detectors/ood/metadata_ks_compare.py +53 -14
  16. dataeval/detectors/ood/vae.py +3 -2
  17. dataeval/detectors/ood/vaegmm.py +5 -4
  18. dataeval/metrics/bias/__init__.py +3 -0
  19. dataeval/metrics/bias/balance.py +77 -64
  20. dataeval/metrics/bias/coverage.py +12 -12
  21. dataeval/metrics/bias/diversity.py +74 -114
  22. dataeval/metrics/bias/metadata_preprocessing.py +285 -0
  23. dataeval/metrics/bias/metadata_utils.py +229 -0
  24. dataeval/metrics/bias/parity.py +54 -158
  25. dataeval/utils/__init__.py +2 -2
  26. dataeval/utils/gmm.py +26 -0
  27. dataeval/utils/metadata.py +29 -9
  28. dataeval/utils/shared.py +1 -1
  29. dataeval/utils/split_dataset.py +12 -6
  30. dataeval/utils/tensorflow/_internal/gmm.py +4 -24
  31. dataeval/utils/torch/datasets.py +2 -2
  32. dataeval/utils/torch/gmm.py +98 -0
  33. dataeval/utils/torch/models.py +192 -0
  34. dataeval/utils/torch/trainer.py +84 -5
  35. dataeval/utils/torch/utils.py +107 -1
  36. dataeval/workflows/__init__.py +1 -1
  37. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/METADATA +1 -2
  38. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/RECORD +40 -34
  39. dataeval/metrics/bias/metadata.py +0 -358
  40. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/LICENSE.txt +0 -0
  41. {dataeval-0.73.0.dist-info → dataeval-0.74.0.dist-info}/WHEEL +0 -0
@@ -1,36 +1,40 @@
1
- dataeval/__init__.py,sha256=cAgMAbawI3EC6HdLfV_g_mMpH5Y-zy-n2qzrRKBH_6s,641
2
- dataeval/detectors/__init__.py,sha256=xdp8LYOFjV5tVbAwu0Y03KU9EajHkSFy_M3raqbxpDc,383
3
- dataeval/detectors/drift/__init__.py,sha256=MRPWFOaoVoqAHW36nA5F3wk7QXJU4oecND2RbtgG9oY,757
4
- dataeval/detectors/drift/base.py,sha256=0S-0MFpIFaJ4_8IGreFKSmyna2L50FBn7DVaoNWmw8E,14509
1
+ dataeval/__init__.py,sha256=bwKFegCsdGFydqDvza_wSvJgRGr-0pQ59UpcePQ1mNs,601
2
+ dataeval/detectors/__init__.py,sha256=mwAyY54Hvp6N4D57cde3_besOinK8jVF43k0Mw4XZi8,363
3
+ dataeval/detectors/drift/__init__.py,sha256=BSXm21y7cAawHep-ZldCJ5HOvzYjPzYGKGrmoEs3i0E,737
4
+ dataeval/detectors/drift/base.py,sha256=xwI6C-PEH0ZjpSqP6No6WDZp42DnE16OHi_mXe2JSvI,14499
5
5
  dataeval/detectors/drift/cvm.py,sha256=kc59w2_wtxFGNnLcaJRvX5v_38gPXiebSGNiFVdunEQ,4142
6
6
  dataeval/detectors/drift/ks.py,sha256=gcpe1WIQeNeZdLYkdMZCFLXUp1bHMQUxwJE6-RLVOXs,4229
7
7
  dataeval/detectors/drift/mmd.py,sha256=TqGOnUNYKwpS0GQPV3dSl-_qRa0g2flmoQ-dxzW_JfY,7586
8
- dataeval/detectors/drift/torch.py,sha256=D46J72OPW8-PpP3w9ODMBfcDSdailIgVjgHVFpbYfws,11649
8
+ dataeval/detectors/drift/torch.py,sha256=igEQ2DV9JmcpTdUKCOHBi5LxtoNeCAslJS2Ldulg1hw,7585
9
9
  dataeval/detectors/drift/uncertainty.py,sha256=Xz2yzJjtJfw1vLag234jwRvaa_HK36nMajGx8bQaNRs,5322
10
10
  dataeval/detectors/drift/updates.py,sha256=UJ0z5hlunRi7twnkLABfdJG3tT2EqX4y9IGx8_USYvo,1780
11
11
  dataeval/detectors/linters/__init__.py,sha256=BvpaB1RUpkEhhXk3Mqi5NYoOcJKZRFSBOJCmQOIfYRU,483
12
- dataeval/detectors/linters/clusterer.py,sha256=OtBE5rglAGdTTQRmKUHP6J-uWmnh2E3lZxeqJCnc87U,21014
12
+ dataeval/detectors/linters/clusterer.py,sha256=sau5A9YcQ6VDjbZGOIaCaRHW_63opaA31pqHo5Rm-hQ,21018
13
13
  dataeval/detectors/linters/duplicates.py,sha256=tOD43rJkvheIA3mznbUqHhft2yD3xRZQdCt61daIca4,5665
14
14
  dataeval/detectors/linters/merged_stats.py,sha256=X-bDTwjyR8RuVmzxLaHZmQ5nI3oOWvsqVlitdSncapk,1355
15
15
  dataeval/detectors/linters/outliers.py,sha256=BUVvtbKHo04KnRmrgb84MBr0l1gtcY3-xNCHjetFrEQ,10117
16
- dataeval/detectors/ood/__init__.py,sha256=FVyVuaxVKAOgSTaaBf-j2OXXDarSBFcJ7CTlMV6w88s,661
17
- dataeval/detectors/ood/ae.py,sha256=XQ_rCsf0VWg_2YXt33XGe6ZgxEud1PfIl7TmBVP1GkM,2347
18
- dataeval/detectors/ood/aegmm.py,sha256=6UKv0uJYWAzu1F-cITFGly4w9y_t7wqg3OmVyCN365o,2041
19
- dataeval/detectors/ood/base.py,sha256=a_d52pJMWVmduSt8OvUWYwHE8mpCaI6pIAE4_ib_GOs,8841
20
- dataeval/detectors/ood/llr.py,sha256=TwUk1RsZhnM5tUssGVMBhWggCW2izs_Asy9QPHkTJaU,10615
21
- dataeval/detectors/ood/metadata_ks_compare.py,sha256=jH7uDwyyBIIcTrRhQEdnLAdrwf7LfNczKBw0CpJyF5c,4282
16
+ dataeval/detectors/ood/__init__.py,sha256=XckkWVhYbbg9iWVsCPEQN-t7FFSt2a4jmCwAAempkM4,793
17
+ dataeval/detectors/ood/ae.py,sha256=km7buF8LbMmwsyfu1xMOI5CJDnQX1x8_-c04zTGMXRI,2389
18
+ dataeval/detectors/ood/ae_torch.py,sha256=pO9w5221bXR9lEBkE7oakXeE7PXUUR--xcTpmHvOCSk,2142
19
+ dataeval/detectors/ood/aegmm.py,sha256=CI2HEkRMJSEFTVLZEhz4CStkaS7i66yTPtnbkbCqTes,2084
20
+ dataeval/detectors/ood/base.py,sha256=u9S7z7zJ8wuPqrtn63ePdAa8DdI579EbCy8Tn0M3XI8,6983
21
+ dataeval/detectors/ood/base_tf.py,sha256=ppj8rAjXjHEab2oGfQO2olXyN4aGZH8_QHIEghOoeFQ,3297
22
+ dataeval/detectors/ood/base_torch.py,sha256=yFbSfQsBMwZeVf8mrixmkZYBGChhV5oAHtkgzWnMzsA,3405
23
+ dataeval/detectors/ood/llr.py,sha256=IrOam-kqUU4bftolR3MvhcEq-NNj2euyI-lYvMuXYn8,10645
24
+ dataeval/detectors/ood/metadata_ks_compare.py,sha256=Ka6MABdJH5ZlHF66mENpSOLCE8H9xdQ_wWNwMYVO_Q0,5352
22
25
  dataeval/detectors/ood/metadata_least_likely.py,sha256=nxMCXUOjOfWHDTGT2SLE7OYBCydRq8zHLd8t17k7hMM,5193
23
26
  dataeval/detectors/ood/metadata_ood_mi.py,sha256=KLay2BmgHrStBV92VpIs_B1yEfQKllsMTgzOQEng01I,4065
24
- dataeval/detectors/ood/vae.py,sha256=UKrQNFdHcnxAY0fAFbLrXasY8Z6qg138BXxqwc1hlts,3154
25
- dataeval/detectors/ood/vaegmm.py,sha256=_wwmT37URs0MyhbORk91XJExClv-4e15LH_Bj60Pw1w,2409
27
+ dataeval/detectors/ood/vae.py,sha256=yjK4p-XYhnH3wWPiwAclb3eyZE0wpTazLLuKhzurcWY,3203
28
+ dataeval/detectors/ood/vaegmm.py,sha256=FhPJBzs7wyEPQUUMxOMsdPpCdAZwN82vztjt05cSrds,2459
26
29
  dataeval/interop.py,sha256=TZCkZo844DvzHoxuRo-YsBhT6GvKmyQTHtUEQZPly1M,1728
27
30
  dataeval/metrics/__init__.py,sha256=fPBNLd-T6mCErZBBJrxWmXIL0jCk7fNUYIcNEBkMa80,238
28
- dataeval/metrics/bias/__init__.py,sha256=puf645-hAO5hFHNHlZ239TPopqWIoN-uLGXFB8-hA_o,599
29
- dataeval/metrics/bias/balance.py,sha256=Uz7RHf3UuiAxfYlZpKMg4jMzXwXcEfYj7BUnUjzgkw0,8579
30
- dataeval/metrics/bias/coverage.py,sha256=eB8PacN_uJ19pMd5SVI3N98NC2KJMgE3tgI-DJFNHYs,4497
31
- dataeval/metrics/bias/diversity.py,sha256=v9fiuySovMajW9Re0EH_FdbuJryAAdVKkvOuNngO5nc,9618
32
- dataeval/metrics/bias/metadata.py,sha256=OZB9BzPW6JMq2kTp_a9ucqRNcPpfqOexINax1jH5vVQ,11318
33
- dataeval/metrics/bias/parity.py,sha256=vfGnt_GoGMjMfWgY1FjqNV-gjqVq13tsTTmVkNtRfDM,17120
31
+ dataeval/metrics/bias/__init__.py,sha256=dYiPHenS8J7pgRMMW2jNkTBmTbPoYTxT04fZu9PFats,747
32
+ dataeval/metrics/bias/balance.py,sha256=BH644D_xN7rRUdJMNgVcGHWq3TTnehYjSBhSMhmAFyY,9154
33
+ dataeval/metrics/bias/coverage.py,sha256=LBrNG6GIrvMJjZckr72heyCTMCke_p5BT8NJWi-noEY,4546
34
+ dataeval/metrics/bias/diversity.py,sha256=__7I934sVoymXqgHoneXglJhIU5iHRIuklFwC2ks84w,7719
35
+ dataeval/metrics/bias/metadata_preprocessing.py,sha256=DbtzsiHjkCxs411okb6s2B_H2TqfvwJ4xyt9m_OsqJo,12266
36
+ dataeval/metrics/bias/metadata_utils.py,sha256=HmTjlRRTdM9566oKUDDdVMJ8luss4DYykFOiS2FQzhM,6558
37
+ dataeval/metrics/bias/parity.py,sha256=lLa2zN0AK-zWzlXmvLCbMxTZFodAKLs8wSGl_YZdNFo,12765
34
38
  dataeval/metrics/estimators/__init__.py,sha256=O6ocxJq8XDkfJWwXeJnnnzbOyRnFPKF4kTIVTTZYOA8,380
35
39
  dataeval/metrics/estimators/ber.py,sha256=SVT-BIC_GLs0l2l2NhWu4OpRbgn96w-OwTSoPHTnQbE,5037
36
40
  dataeval/metrics/estimators/divergence.py,sha256=pImaa216-YYTgGWDCSTcpJrC-dfl7150yVrPfW_TyGc,4293
@@ -46,14 +50,15 @@ dataeval/metrics/stats/pixelstats.py,sha256=x90O10IqVjEORtYwueFLvJnVYTxhPBOOx5HM
46
50
  dataeval/metrics/stats/visualstats.py,sha256=y0xIvst7epcajk8vz2jngiAiz0T7DZC-M97Rs1-vV9I,4950
47
51
  dataeval/output.py,sha256=jWXXNxFNBEaY1rN7Z-6LZl6bQT-I7z_wqr91Rhrdt_0,3061
48
52
  dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
49
- dataeval/utils/__init__.py,sha256=Qr-D0yHnDE8qit0-Wf6xmdMX9Wle2p_mXKgTueTy5GA,753
53
+ dataeval/utils/__init__.py,sha256=FZLWDA7nMbHOcdg3701cVJpQmUp1Wxxk8h_qIrUQQjY,713
54
+ dataeval/utils/gmm.py,sha256=YuLsJKsVWgH_wHr1u_hSRH5Yeexdj8exht8h99L7bLo,561
50
55
  dataeval/utils/image.py,sha256=KgC_1nW__nGN5q6bVZNvG4U_qIBdjcPATz9qe8f2XuA,1928
51
56
  dataeval/utils/lazy.py,sha256=M0iBHuJh4UPrSJPHZ0jhFwRSZhyjHJQx_KEf1OCkHD8,588
52
- dataeval/utils/metadata.py,sha256=A6VN7KbdiOA6rUQvUGKwDcvtOyjBer8bRW_wFxNhmW0,8556
53
- dataeval/utils/shared.py,sha256=BvEeYPMNQTmx4LSaImGeC0VkvcbEY3Byqtxa-jQ3xgc,3623
54
- dataeval/utils/split_dataset.py,sha256=IopyxwC3FaZwgVriW4OXze-mDMpOlvRr83OADA5Jydk,19454
57
+ dataeval/utils/metadata.py,sha256=0A--iru0zEmi044mKz5P35q69KrI30yoiRSlvs7TSdQ,9418
58
+ dataeval/utils/shared.py,sha256=xvF3VLfyheVwJtdtDrneOobkKf7t-JTmf_w91FWXmqo,3616
59
+ dataeval/utils/split_dataset.py,sha256=Ot1ZJhbIhVfcShYXF9MkWXak5odBXyuBdRh-noXh-MI,19555
55
60
  dataeval/utils/tensorflow/__init__.py,sha256=l4OjIA75JJXeNWDCkST1xtDMVYsw97lZ-9JXFBlyuYg,539
56
- dataeval/utils/tensorflow/_internal/gmm.py,sha256=RIFx8asEpi2kMf8JVzq9M3aAvNe9fjpJPf3BzWE-aeE,3787
61
+ dataeval/utils/tensorflow/_internal/gmm.py,sha256=XvjhWM3ppP-R9nCZGs80WphmQR3u7wb-VtoCQYeXZlQ,3404
57
62
  dataeval/utils/tensorflow/_internal/loss.py,sha256=TFhoNPgqeJtdpIHYobZPyzMpeWjzlFqzu5LCtthEUi4,4463
58
63
  dataeval/utils/tensorflow/_internal/models.py,sha256=TzQYRrFe5XomhnPw05v-HBODQdFIqWg21WH1xS0XBlg,59868
59
64
  dataeval/utils/tensorflow/_internal/trainer.py,sha256=uBFTnAy9o2T_FoT3RSX-AA7T-2FScyOdYEg9_7Dpd28,4314
@@ -61,13 +66,14 @@ dataeval/utils/tensorflow/_internal/utils.py,sha256=lr5hKkAPbjMCUNIzMUIqbEddwbWQ
61
66
  dataeval/utils/tensorflow/loss/__init__.py,sha256=Q-66vt91Oe1ByYfo28tW32zXDq2MqQ2gngWgmIVmof8,227
62
67
  dataeval/utils/torch/__init__.py,sha256=lpkqfgyARUxgrV94cZESQv8PIP2p-UnwItZ_wIr0XzQ,675
63
68
  dataeval/utils/torch/blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
64
- dataeval/utils/torch/datasets.py,sha256=9YV9-Uhq6NCMuu1hPhMnQXjmeI-Ld8ve1z_haxre88o,15023
65
- dataeval/utils/torch/models.py,sha256=0BsXmLK8W1OZ8nnEGb1f9LzIeCgtevQC37dvKS1v1vA,3236
66
- dataeval/utils/torch/trainer.py,sha256=EraOKiXxiMNiycStZNMR5yRz3ehgp87d9ewR9a9dV4w,5559
67
- dataeval/utils/torch/utils.py,sha256=FI4LJ6DvXFQJVff8fxSCP7LRkp8H9BIUgYX0kk7_Cuo,1537
68
- dataeval/workflows/__init__.py,sha256=x2JnOoKmLUCZOsB6RNPqMdVvxEb6Hpda5GPJnD_k0v0,310
69
+ dataeval/utils/torch/datasets.py,sha256=10elNgLuH_FDX_CHE3y2Z215JN4-PQovQm5brcIJOeM,15021
70
+ dataeval/utils/torch/gmm.py,sha256=VbLlUQohwToApT493_tjQBWy2UM5R-3ppS9Dp-eP7BA,3240
71
+ dataeval/utils/torch/models.py,sha256=sdGeo7a8vshCTGA4lYyVxxb_aDWUlxdtIVxrddS-_ls,8542
72
+ dataeval/utils/torch/trainer.py,sha256=8BEXr6xtk-CHJTcNxOBnWgkFWfJUAiBy28cEdBhLMRU,7883
73
+ dataeval/utils/torch/utils.py,sha256=nWRcT6z6DbFVrL1RyxCOX3DPoCrv9G0B-VI_9LdGCQQ,5784
74
+ dataeval/workflows/__init__.py,sha256=ef1MiVL5IuhlDXXbwsiAfafhnr7tD3TXF9GRusy9_O8,290
69
75
  dataeval/workflows/sufficiency.py,sha256=1jSYhH9i4oesmJYs5PZvWS1LGXf8ekOgNhpFtMPLPXk,18552
70
- dataeval-0.73.0.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
71
- dataeval-0.73.0.dist-info/METADATA,sha256=YVw0z5C5BZs-9gCxCmbo4aNIN7Ph3rZsel7FofFrMKY,4714
72
- dataeval-0.73.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
73
- dataeval-0.73.0.dist-info/RECORD,,
76
+ dataeval-0.74.0.dist-info/LICENSE.txt,sha256=Kpzcfobf1HlqafF-EX6dQLw9TlJiaJzfgvLQFukyXYw,1060
77
+ dataeval-0.74.0.dist-info/METADATA,sha256=OPnkHZTm8R1LHqLxcSnOHjqj5GuHmjUVI3dddTVsBAc,4680
78
+ dataeval-0.74.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
79
+ dataeval-0.74.0.dist-info/RECORD,,
@@ -1,358 +0,0 @@
1
- from __future__ import annotations
2
-
3
- __all__ = []
4
-
5
- import contextlib
6
- from typing import Any, Mapping
7
-
8
- import numpy as np
9
- from numpy.typing import ArrayLike, NDArray
10
- from scipy.stats import entropy as sp_entropy
11
-
12
- from dataeval.interop import to_numpy
13
-
14
- with contextlib.suppress(ImportError):
15
- from matplotlib.figure import Figure
16
-
17
- CLASS_LABEL = "class_label"
18
-
19
-
20
- def get_counts(
21
- data: NDArray[np.int_], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
22
- ) -> tuple[dict[str, NDArray[np.int_]], dict[str, NDArray[np.int_]]]:
23
- """
24
- Initialize dictionary of histogram counts --- treat categorical values
25
- as histogram bins.
26
-
27
- Parameters
28
- ----------
29
- subset_mask: NDArray[np.bool_] | None
30
- Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
31
-
32
- Returns
33
- -------
34
- counts: Dict
35
- histogram counts per metadata factor in `factors`. Each
36
- factor will have a different number of bins. Counts get reused
37
- across metrics, so hist_counts are cached but only if computed
38
- globally, i.e. without masked samples.
39
- """
40
-
41
- hist_counts, hist_bins = {}, {}
42
- # np.where needed to satisfy linter
43
- mask = np.where(subset_mask if subset_mask is not None else np.ones(data.shape[0], dtype=bool))
44
-
45
- for cdx, fn in enumerate(names):
46
- # linter doesn't like double indexing
47
- col_data = data[mask, cdx].squeeze()
48
- if is_categorical[cdx]:
49
- # if discrete, use unique values as bins
50
- bins, cnts = np.unique(col_data, return_counts=True)
51
- else:
52
- bins = hist_bins.get(fn, "auto")
53
- cnts, bins = np.histogram(col_data, bins=bins, density=True)
54
-
55
- hist_counts[fn] = cnts
56
- hist_bins[fn] = bins
57
-
58
- return hist_counts, hist_bins
59
-
60
-
61
- def entropy(
62
- data: NDArray[Any],
63
- names: list[str],
64
- is_categorical: list[bool],
65
- normalized: bool = False,
66
- subset_mask: NDArray[np.bool_] | None = None,
67
- ) -> NDArray[np.float64]:
68
- """
69
- Meant for use with :term:`bias<Bias>` metrics, :term:`balance<Balance>`, :term:`diversity<Diversity>`,
70
- ClasswiseBalance, and Classwise Diversity.
71
-
72
- Compute entropy for discrete/categorical variables and for continuous variables through standard
73
- histogram binning.
74
-
75
- Parameters
76
- ----------
77
- normalized: bool
78
- Flag that determines whether or not to normalize entropy by log(num_bins)
79
- subset_mask: NDArray[np.bool_] | None
80
- Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
81
-
82
- Note
83
- ----
84
- For continuous variables, histogram bins are chosen automatically. See
85
- numpy.histogram for details.
86
-
87
- Returns
88
- -------
89
- ent: NDArray[np.float64]
90
- Entropy estimate per column of X
91
-
92
- See Also
93
- --------
94
- numpy.histogram
95
- scipy.stats.entropy
96
- """
97
-
98
- num_factors = len(names)
99
- hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
100
-
101
- ev_index = np.empty(num_factors)
102
- for col, cnts in enumerate(hist_counts.values()):
103
- # entropy in nats, normalizes counts
104
- ev_index[col] = sp_entropy(cnts)
105
- if normalized:
106
- if len(cnts) == 1:
107
- # log(0)
108
- ev_index[col] = 0
109
- else:
110
- ev_index[col] /= np.log(len(cnts))
111
- return ev_index
112
-
113
-
114
- def get_num_bins(
115
- data: NDArray[Any], names: list[str], is_categorical: list[bool], subset_mask: NDArray[np.bool_] | None = None
116
- ) -> NDArray[np.float64]:
117
- """
118
- Number of bins or unique values for each metadata factor, used to
119
- normalize entropy/:term:`diversity<Diversity>`.
120
-
121
- Parameters
122
- ----------
123
- subset_mask: NDArray[np.bool_] | None
124
- Boolean mask of samples to bin (e.g. when computing per class). True -> include in histogram counts
125
-
126
- Returns
127
- -------
128
- NDArray[np.float64]
129
- """
130
- # likely cached
131
- hist_counts, _ = get_counts(data, names, is_categorical, subset_mask)
132
- num_bins = np.empty(len(hist_counts))
133
- for idx, cnts in enumerate(hist_counts.values()):
134
- num_bins[idx] = len(cnts)
135
-
136
- return num_bins
137
-
138
-
139
- def infer_categorical(arr: NDArray[Any], threshold: float = 0.2) -> NDArray[Any]:
140
- """
141
- Compute fraction of feature values that are unique --- intended to be used
142
- for inferring whether variables are categorical.
143
- """
144
- if arr.ndim == 1:
145
- arr = np.expand_dims(arr, axis=1)
146
- num_samples = arr.shape[0]
147
- pct_unique = np.empty(arr.shape[1])
148
- for col in range(arr.shape[1]): # type: ignore
149
- uvals = np.unique(arr[:, col], axis=0)
150
- pct_unique[col] = len(uvals) / num_samples
151
- return pct_unique < threshold
152
-
153
-
154
- def preprocess_metadata(
155
- class_labels: ArrayLike, metadata: Mapping[str, ArrayLike], cat_thresh: float = 0.2
156
- ) -> tuple[NDArray[Any], list[str], list[bool], NDArray[np.str_]]:
157
- # if class_labels is not numeric
158
- class_array = to_numpy(class_labels)
159
- if not np.issubdtype(class_array.dtype, np.number):
160
- unique_classes, numerical_labels = np.unique(class_array, return_inverse=True)
161
- else:
162
- numerical_labels = np.asarray(class_array, dtype=int)
163
- unique_classes = np.unique(class_array)
164
-
165
- # convert class_labels and dict of lists to matrix of metadata values
166
- preprocessed_metadata = {CLASS_LABEL: numerical_labels}
167
-
168
- # map columns of dict that are not numeric (e.g. string) to numeric values
169
- # that mutual information and diversity functions can accommodate. Each
170
- # unique string receives a unique integer value.
171
- for k, v in metadata.items():
172
- if k == CLASS_LABEL:
173
- k = "label_class"
174
- # if not numeric
175
- v = to_numpy(v)
176
- if not np.issubdtype(v.dtype, np.number):
177
- _, mapped_vals = np.unique(v, return_inverse=True)
178
- preprocessed_metadata[k] = mapped_vals
179
- else:
180
- preprocessed_metadata[k] = v
181
-
182
- data = np.stack(list(preprocessed_metadata.values()), axis=-1)
183
- names = list(preprocessed_metadata.keys())
184
- is_categorical = [infer_categorical(preprocessed_metadata[var], cat_thresh)[0] for var in names]
185
-
186
- return data, names, is_categorical, unique_classes
187
-
188
-
189
- def heatmap(
190
- data: NDArray[Any],
191
- row_labels: list[str] | NDArray[Any],
192
- col_labels: list[str] | NDArray[Any],
193
- xlabel: str = "",
194
- ylabel: str = "",
195
- cbarlabel: str = "",
196
- ) -> Figure:
197
- """
198
- Plots a formatted heatmap
199
-
200
- Parameters
201
- ----------
202
- data : NDArray
203
- Array containing numerical values for factors to plot
204
- row_labels : ArrayLike
205
- List/Array containing the labels for rows in the histogram
206
- col_labels : ArrayLike
207
- List/Array containing the labels for columns in the histogram
208
- xlabel : str, default ""
209
- X-axis label
210
- ylabel : str, default ""
211
- Y-axis label
212
- cbarlabel : str, default ""
213
- Label for the colorbar
214
- """
215
- import matplotlib
216
- import matplotlib.pyplot as plt
217
-
218
- fig, ax = plt.subplots(figsize=(10, 10))
219
-
220
- # Plot the heatmap
221
- im = ax.imshow(data, vmin=0, vmax=1.0)
222
-
223
- # Create colorbar
224
- cbar = fig.colorbar(im, shrink=0.5)
225
- cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])
226
- cbar.set_ticklabels(["0.0", "0.25", "0.5", "0.75", "1.0"])
227
- cbar.set_label(cbarlabel, loc="center")
228
-
229
- # Show all ticks and label them with the respective list entries.
230
- ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
231
- ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
232
-
233
- ax.tick_params(top=False, bottom=True, labeltop=False, labelbottom=True)
234
- # Rotate the tick labels and set their alignment.
235
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
236
-
237
- # Turn spines off and create white grid.
238
- ax.spines[:].set_visible(False)
239
-
240
- ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
241
- ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
242
- ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
243
- ax.tick_params(which="minor", bottom=False, left=False)
244
-
245
- if xlabel:
246
- ax.set_xlabel(xlabel)
247
- if ylabel:
248
- ax.set_ylabel(ylabel)
249
-
250
- valfmt = matplotlib.ticker.FuncFormatter(format_text) # type: ignore
251
-
252
- # Normalize the threshold to the images color range.
253
- threshold = im.norm(1.0) / 2.0
254
-
255
- # Set default alignment to center, but allow it to be
256
- # overwritten by textkw.
257
- kw = {"horizontalalignment": "center", "verticalalignment": "center"}
258
-
259
- # Loop over the data and create a `Text` for each "pixel".
260
- # Change the text's color depending on the data.
261
- textcolors = ("white", "black")
262
- texts = []
263
- for i in range(data.shape[0]):
264
- for j in range(data.shape[1]):
265
- kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
266
- text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) # type: ignore
267
- texts.append(text)
268
-
269
- fig.tight_layout()
270
- return fig
271
-
272
-
273
- # Function to define how the text is displayed in the heatmap
274
- def format_text(*args: str) -> str:
275
- """
276
- Helper function to format text for heatmap()
277
-
278
- Parameters
279
- ----------
280
- *args: Tuple (str, str)
281
- Text to be formatted. Second element is ignored, but is a
282
- mandatory pass-through argument as per matplotlib.ticket.FuncFormatter
283
-
284
- Returns
285
- -------
286
- str
287
- Formatted text
288
- """
289
- x = args[0]
290
- return f"{x:.2f}".replace("0.00", "0").replace("0.", ".").replace("nan", "")
291
-
292
-
293
- def diversity_bar_plot(labels: NDArray[Any], bar_heights: NDArray[Any]) -> Figure:
294
- """
295
- Plots a formatted bar plot
296
-
297
- Parameters
298
- ----------
299
- labels : NDArray
300
- Array containing the labels for each bar
301
- bar_heights : NDArray
302
- Array containing the values for each bar
303
- """
304
- import matplotlib.pyplot as plt
305
-
306
- fig, ax = plt.subplots(figsize=(10, 10))
307
-
308
- ax.bar(labels, bar_heights)
309
- ax.set_xlabel("Factors")
310
-
311
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
312
-
313
- fig.tight_layout()
314
- return fig
315
-
316
-
317
- def coverage_plot(images: NDArray[Any], num_images: int) -> Figure:
318
- """
319
- Creates a single plot of all of the provided images
320
-
321
- Parameters
322
- ----------
323
- images : NDArray
324
- Array containing only the desired images to plot
325
- """
326
- import matplotlib.pyplot as plt
327
-
328
- num_images = min(num_images, len(images))
329
-
330
- if images.ndim == 4:
331
- images = np.moveaxis(images, 1, -1)
332
- elif images.ndim == 3:
333
- images = np.repeat(images[:, :, :, np.newaxis], 3, axis=-1)
334
- else:
335
- raise ValueError(
336
- f"Expected a (N,C,H,W) or a (N, H, W) set of images, but got a {images.ndim}-dimensional set of images."
337
- )
338
-
339
- rows = np.ceil(num_images / 3).astype(int)
340
- fig, axs = plt.subplots(rows, 3, figsize=(9, 3 * rows))
341
-
342
- if rows == 1:
343
- for j in range(3):
344
- if j >= len(images):
345
- continue
346
- axs[j].imshow(images[j])
347
- axs[j].axis("off")
348
- else:
349
- for i in range(rows):
350
- for j in range(3):
351
- i_j = i * 3 + j
352
- if i_j >= len(images):
353
- continue
354
- axs[i, j].imshow(images[i_j])
355
- axs[i, j].axis("off")
356
-
357
- fig.tight_layout()
358
- return fig