dataeval 0.82.1__py3-none-any.whl → 0.83.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dataeval/__init__.py +7 -2
- dataeval/config.py +10 -0
- dataeval/metadata/__init__.py +2 -2
- dataeval/metadata/_ood.py +144 -27
- dataeval/metrics/bias/_balance.py +3 -3
- dataeval/metrics/estimators/_ber.py +2 -1
- dataeval/metrics/stats/_base.py +17 -18
- dataeval/metrics/stats/_dimensionstats.py +2 -2
- dataeval/metrics/stats/_hashstats.py +2 -2
- dataeval/metrics/stats/_imagestats.py +4 -4
- dataeval/metrics/stats/_pixelstats.py +2 -2
- dataeval/metrics/stats/_visualstats.py +2 -2
- dataeval/outputs/__init__.py +2 -1
- dataeval/outputs/_metadata.py +7 -0
- dataeval/typing.py +40 -9
- dataeval/utils/_mst.py +1 -2
- dataeval/utils/data/_embeddings.py +15 -10
- dataeval/utils/data/_selection.py +22 -11
- dataeval/utils/data/datasets/_base.py +4 -2
- dataeval/utils/data/datasets/_cifar10.py +17 -9
- dataeval/utils/data/datasets/_milco.py +18 -12
- dataeval/utils/data/datasets/_mnist.py +24 -8
- dataeval/utils/data/datasets/_ships.py +18 -8
- dataeval/utils/data/datasets/_types.py +1 -5
- dataeval/utils/data/datasets/_voc.py +47 -24
- dataeval/utils/data/selections/__init__.py +2 -0
- dataeval/utils/data/selections/_classfilter.py +1 -1
- dataeval/utils/data/selections/_prioritize.py +296 -0
- dataeval/utils/data/selections/_shuffle.py +13 -4
- dataeval/utils/torch/_gmm.py +3 -2
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/METADATA +4 -4
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/RECORD +34 -34
- dataeval/detectors/ood/metadata_ood_mi.py +0 -91
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/LICENSE.txt +0 -0
- {dataeval-0.82.1.dist-info → dataeval-0.83.0.dist-info}/WHEEL +0 -0
@@ -1,6 +1,6 @@
|
|
1
|
-
dataeval/__init__.py,sha256=
|
1
|
+
dataeval/__init__.py,sha256=uL-JSd_dKVJpGx4H8f6aOiQVpli46zeTLFqjb4Pa69c,1636
|
2
2
|
dataeval/_log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
|
3
|
-
dataeval/config.py,sha256=
|
3
|
+
dataeval/config.py,sha256=oQ0XQsgIF4_z4n1j0Di6B-JCRUFzzPgJgpQUm3ZlYhs,3539
|
4
4
|
dataeval/detectors/__init__.py,sha256=3Sg-XWlwr75zEEH3hZKA4nWMtGvaRlnfzTWvZG_Ak6U,189
|
5
5
|
dataeval/detectors/drift/__init__.py,sha256=6is_XBtG1d-vUbhHvqXGOdnAwxJ7NA5yRfURn7pCeIw,651
|
6
6
|
dataeval/detectors/drift/_base.py,sha256=mJdKvyROgWvz-p1VlAIJqUI6BAj9ss8riUvR5An5wIw,13459
|
@@ -16,46 +16,45 @@ dataeval/detectors/linters/outliers.py,sha256=Hln2dPQZjF_uV2QYptA_o6ZF3ugyCImVT-
|
|
16
16
|
dataeval/detectors/ood/__init__.py,sha256=juCYBDs7CQEAtMhnEpPqF6uTrOIH9kTBSuQ_GRw6a8o,283
|
17
17
|
dataeval/detectors/ood/ae.py,sha256=YQfhB1ShQLjM1V4uCz9Oo2tCZpOfAZ_-SBCAl4Ac67Y,2921
|
18
18
|
dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
|
19
|
-
dataeval/detectors/ood/metadata_ood_mi.py,sha256=aMSP3zh5EwIWqM7w135ZAuTVnpqYI4dN3tEOrx41lsk,3837
|
20
19
|
dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
|
21
20
|
dataeval/detectors/ood/vae.py,sha256=Fcq0-WbLhzYCgYOAJPBklHm7yuXmFJuEpBkhgwM5kiA,2291
|
22
|
-
dataeval/metadata/__init__.py,sha256=
|
21
|
+
dataeval/metadata/__init__.py,sha256=XDDmJbOZBNM6pL0r6Nbu6oMRoyAh22IDkPYGndNlkZU,316
|
23
22
|
dataeval/metadata/_distance.py,sha256=xsXMMg1pJkHcEZ-KIlqv9YOGYVID3ELjt3-fr1QVnOs,4082
|
24
|
-
dataeval/metadata/_ood.py,sha256=
|
23
|
+
dataeval/metadata/_ood.py,sha256=HbS5MusWl62hjixUAd-xaaT0KXkYY1M-MlnUaAI_-8M,12751
|
25
24
|
dataeval/metadata/_utils.py,sha256=r8qBJT83RblobD5W5zyTVi6vYi51Dwkqswizdbzss-M,1169
|
26
25
|
dataeval/metrics/__init__.py,sha256=8VC8q3HuJN3o_WN51Ae2_wXznl3RMXIvA5GYVcy7vr8,225
|
27
26
|
dataeval/metrics/bias/__init__.py,sha256=1yTLmgiu1kwT_7ZWcjOUbj8R0NJ0DjGoCuWdA0_T7kc,683
|
28
|
-
dataeval/metrics/bias/_balance.py,sha256=
|
27
|
+
dataeval/metrics/bias/_balance.py,sha256=UnUgbPk2ybFfS5qxv8e_uim7RxamWj0UQP71x3omGs0,6158
|
29
28
|
dataeval/metrics/bias/_coverage.py,sha256=PeUoOiaghUEdn6Ov8z2-am7-fnBVIPcFbJK7Ty5JObA,3647
|
30
29
|
dataeval/metrics/bias/_diversity.py,sha256=U_l4oYjH39rON2Io0BdCIwJxxob0cKTW8bZNufG0CWs,5820
|
31
30
|
dataeval/metrics/bias/_parity.py,sha256=8JRZv4wLpxN9zTvMDlcpKgz-2nO-9eVjqccODcf2nbw,11535
|
32
31
|
dataeval/metrics/estimators/__init__.py,sha256=Pnds8uIyAovt2fKqZjiHCIP_kVoBWlVllekYuK5UmmU,568
|
33
|
-
dataeval/metrics/estimators/_ber.py,sha256=
|
32
|
+
dataeval/metrics/estimators/_ber.py,sha256=C30E5LiGGTAfo31zWFYDptDg0R7CTJGJ-a60YgzSkYY,5382
|
34
33
|
dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1RxKmoKT5MOgt0,1434
|
35
34
|
dataeval/metrics/estimators/_divergence.py,sha256=QDWl1lyAYoO9D3Ho7qOHSk6ud8Gi2MGuXEsYwO1HxvA,4043
|
36
35
|
dataeval/metrics/estimators/_uap.py,sha256=BULEBbJ9BQ1IcTeZf0x7iI60QHAWCccBOM97FIu9VXA,1928
|
37
36
|
dataeval/metrics/stats/__init__.py,sha256=6tA_9nbbM5ObJ6cds8Y1VBtTQiTOxrpGQSFLu_lWGGA,1098
|
38
|
-
dataeval/metrics/stats/_base.py,sha256=
|
37
|
+
dataeval/metrics/stats/_base.py,sha256=rn0CrRCvVh3QLDEi_JlOFVUoQ-xtclnOoHt_o1E26J4,10656
|
39
38
|
dataeval/metrics/stats/_boxratiostats.py,sha256=8Kd2FTZ5PLNYZfdAjU_R385gb0Z16JY0L9H_d5ZhgQs,6341
|
40
|
-
dataeval/metrics/stats/_dimensionstats.py,sha256=
|
41
|
-
dataeval/metrics/stats/_hashstats.py,sha256=
|
42
|
-
dataeval/metrics/stats/_imagestats.py,sha256=
|
39
|
+
dataeval/metrics/stats/_dimensionstats.py,sha256=h2wCLn4UuW7-GV6tM5E1SqSeGa_-4ie9oaEXpSC7EKI,2690
|
40
|
+
dataeval/metrics/stats/_hashstats.py,sha256=yD6cXMvOo10-xtwUr7ftBRbCqMhReNfQJMInEWV_8Mk,4757
|
41
|
+
dataeval/metrics/stats/_imagestats.py,sha256=hyjijPXAfUIJ1lwWiIyYK9VSLiq7Vg2-YhJ5Q8s1rkY,2979
|
43
42
|
dataeval/metrics/stats/_labelstats.py,sha256=PtGyqj4RHw0cyLAWAR9FzZGqgA81AtxLGHZiuMAL2h0,4100
|
44
|
-
dataeval/metrics/stats/_pixelstats.py,sha256=
|
45
|
-
dataeval/metrics/stats/_visualstats.py,sha256=
|
46
|
-
dataeval/outputs/__init__.py,sha256=
|
43
|
+
dataeval/metrics/stats/_pixelstats.py,sha256=Q0-ldG-znDYBP_qTqm6S4qYm0ZV5FTTHf8MlyGHSYEc,3235
|
44
|
+
dataeval/metrics/stats/_visualstats.py,sha256=ZxBDTerZ8ixibY2pGl7mwwcIz3DWl-k_Jb4YwBjHLNw,3686
|
45
|
+
dataeval/outputs/__init__.py,sha256=uxTAr1Kn0QNwC7zn1U_5WBAgwZxupM3JGgD25DyO6yI,1655
|
47
46
|
dataeval/outputs/_base.py,sha256=aZFbgybnZSQ3ws7QYRLTbDFqUfBFRVtIwX2LZfeGFUA,5703
|
48
47
|
dataeval/outputs/_bias.py,sha256=O5RHbTUJDwkwJfz2-YoOfRb4eDl5Tg1UFVtvs025wfA,12173
|
49
48
|
dataeval/outputs/_drift.py,sha256=gOiu2C-ERTWiRqlP0auMYxPBGdm9HecWPqWfg7I4tZg,2015
|
50
49
|
dataeval/outputs/_estimators.py,sha256=a2oAIxxEDZ9WLGfMWH8KD-BVUS_SnULRPR-iI9hFPoQ,3047
|
51
50
|
dataeval/outputs/_linters.py,sha256=YOdjrfm8ypdRrqYOaPM9nc6wVJI3-ita3Haj7LHDNaw,6416
|
52
|
-
dataeval/outputs/_metadata.py,sha256=
|
51
|
+
dataeval/outputs/_metadata.py,sha256=ffZgpX8KWURPHXpOWjbvJ2KRqWQkS2nWuIjKUzoHhMI,1710
|
53
52
|
dataeval/outputs/_ood.py,sha256=suLKVXULGtXH0rq9eXHI1d3d2jhGmItJtz4QiQd47A4,1718
|
54
53
|
dataeval/outputs/_stats.py,sha256=PhRdyWWZxewzenFx0MxK9y9ZLE2MnMA-a4-JeSJ_Bs8,13180
|
55
54
|
dataeval/outputs/_utils.py,sha256=HHlGC7sk416m_3Bgn075Qdblz_aPup_UOafJpB0RuXY,893
|
56
55
|
dataeval/outputs/_workflows.py,sha256=MkRD6ubI4NCBXb9v3kjXy64cUGs3G-JKkBdOpRD9XVE,10750
|
57
56
|
dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
58
|
-
dataeval/typing.py,sha256=
|
57
|
+
dataeval/typing.py,sha256=YQ1KteeK1zf2mcWwngWwQP8EC3pI4WsvAzp_x179b4g,6568
|
59
58
|
dataeval/utils/__init__.py,sha256=T8F8zJh4ZAeu0wDzfpld92I2zJg9mWBmkGCHrDPU7gk,264
|
60
59
|
dataeval/utils/_array.py,sha256=fc04sYShIdsRS4qtG1UCnlGGk-yVRxlOHTNAmW7NpDY,4990
|
61
60
|
dataeval/utils/_bin.py,sha256=nylthmsC3vzLHLhlUMACvZs--h7xvAh9Pt75InaQJW8,7322
|
@@ -63,43 +62,44 @@ dataeval/utils/_clusterer.py,sha256=fw5x-2QN0TIbiodDKHZxRgxKHINedpPcOklzce0Rbjg,
|
|
63
62
|
dataeval/utils/_fast_mst.py,sha256=4_7ykVihCL5jWtxcGnrecIsDQo65kUml9SZ1JxgBZYY,7172
|
64
63
|
dataeval/utils/_image.py,sha256=capzF_X5H0jy0PmTP3Hf52GFgLqrnfU6gS4tiwck9jo,1939
|
65
64
|
dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
|
66
|
-
dataeval/utils/_mst.py,sha256=
|
65
|
+
dataeval/utils/_mst.py,sha256=f0vXytTUjlOS6AyL7c6PkXmaHuuGUK-vMLpq-5xMgxk,2183
|
67
66
|
dataeval/utils/_plot.py,sha256=mTRQNbJsA42QMiOwZbJaH8sNYgP996QFDEGVVE9HSgY,7076
|
68
67
|
dataeval/utils/data/__init__.py,sha256=vldQ2ZXl8gnI3s4vAGqUUVi6dc_R58F3JMSpbCOyFRI,820
|
69
68
|
dataeval/utils/data/_dataset.py,sha256=tjZUJnxj9IY71GKqdKltrwufkn0EC0S3a6ylrW5Bc2s,7756
|
70
|
-
dataeval/utils/data/_embeddings.py,sha256=
|
69
|
+
dataeval/utils/data/_embeddings.py,sha256=fKGFJXhb4ajnBE3jrKxIvBAhBQ6HpcYYkpO_sAk3jTE,3669
|
71
70
|
dataeval/utils/data/_images.py,sha256=pv_vvpH8hWxPgLvjeVC2mZiyZivZFNLARNIOXam5ceY,1984
|
72
71
|
dataeval/utils/data/_metadata.py,sha256=VqeePp7NtoFFWzmIhH4fn-cjrnATpgzgzs-d73cnBXM,14370
|
73
|
-
dataeval/utils/data/_selection.py,sha256=
|
72
|
+
dataeval/utils/data/_selection.py,sha256=nlslafwAfoZ5d5K_v9bIIvij-UP0NcalKqH4Nw7A-S4,4553
|
74
73
|
dataeval/utils/data/_split.py,sha256=YdsqTRjKbdSfg8w0f4XgX7j0uOSdtfzvvyObAzyqgI0,18433
|
75
74
|
dataeval/utils/data/_targets.py,sha256=ws5d9wRiDkIuOV7GSAKNxzgSm6AWTgb0BFroQK5nAmM,3057
|
76
75
|
dataeval/utils/data/collate.py,sha256=Z5nmBnWV_IoJzMp_tj8RCKjMJA9sSCY_zZITqISGixc,3865
|
77
76
|
dataeval/utils/data/datasets/__init__.py,sha256=jBrswiERrvBx4pJQJZIq_B5UE-Wy8a2_SBfM2crG8R8,511
|
78
|
-
dataeval/utils/data/datasets/_base.py,sha256=
|
79
|
-
dataeval/utils/data/datasets/_cifar10.py,sha256=
|
77
|
+
dataeval/utils/data/datasets/_base.py,sha256=CZ-hb-yWPLdnTQ3pURJMcityQ42ZNYj_Lbb1P5Junn4,8793
|
78
|
+
dataeval/utils/data/datasets/_cifar10.py,sha256=I6HKksE2escos1aTdiZJObtiVXChBlez5BDa0eBfJ_Y,5449
|
80
79
|
dataeval/utils/data/datasets/_fileio.py,sha256=SixIk5nIlIwJdX9zjNXS10vHA3hL8aaYbqHsDg1xSpY,6447
|
81
|
-
dataeval/utils/data/datasets/_milco.py,sha256=
|
80
|
+
dataeval/utils/data/datasets/_milco.py,sha256=ScBe7Ux-J9Kxs33jeKffhWKeSb8GCrWznTyEUt95Vt4,6369
|
82
81
|
dataeval/utils/data/datasets/_mixin.py,sha256=FJgZP_cpJkgAHA3j3ai_j3Wt7aFSEjIMVmt9NpvVXzg,1757
|
83
|
-
dataeval/utils/data/datasets/_mnist.py,sha256=
|
84
|
-
dataeval/utils/data/datasets/_ships.py,sha256=
|
85
|
-
dataeval/utils/data/datasets/_types.py,sha256=
|
86
|
-
dataeval/utils/data/datasets/_voc.py,sha256=
|
87
|
-
dataeval/utils/data/selections/__init__.py,sha256=
|
88
|
-
dataeval/utils/data/selections/_classfilter.py,sha256=
|
82
|
+
dataeval/utils/data/datasets/_mnist.py,sha256=iWWI9mq6TbZm7eTL9btzqjCNMhgXrLHQeMKENr7USsk,7988
|
83
|
+
dataeval/utils/data/datasets/_ships.py,sha256=p3fScYLW2f1wUEPOroCX5nOFti0vMOSjeYltj6ox53U,4777
|
84
|
+
dataeval/utils/data/datasets/_types.py,sha256=iSKyHXRlGuomXs0FHK6md8lXLQrQQ4fxgVOwr4o81bo,1089
|
85
|
+
dataeval/utils/data/datasets/_voc.py,sha256=4poEer_G_mUBcz6eAro0Tc29CjdgjEAlms0Eu0tLBzE,14842
|
86
|
+
dataeval/utils/data/selections/__init__.py,sha256=k86OpqGPkjT1MrOir5fOZ3AIq5UR81Az9ek7l1-GdIM,565
|
87
|
+
dataeval/utils/data/selections/_classfilter.py,sha256=opSF8CGv4x1hUMe-GTQOu3UwJK80DzT0nJOV0l2uaW4,2404
|
89
88
|
dataeval/utils/data/selections/_indices.py,sha256=QdLgXN7GABCvGPYe28PV1RAc_RSP_nZOyCvEpKRBdWg,636
|
90
89
|
dataeval/utils/data/selections/_limit.py,sha256=ECvHRsp7OF4LZw2tE4sGqqJ085kjC-hd2c7QDMfvXr8,518
|
90
|
+
dataeval/utils/data/selections/_prioritize.py,sha256=EAA4_uFVV7MmemhhufGmP7eunnbtyTc-TzgcnvRK5OE,11333
|
91
91
|
dataeval/utils/data/selections/_reverse.py,sha256=6SWpELC9Wgx-kPqzhDrPNn4NKU6FqDJveLrxV4D2Ypk,374
|
92
|
-
dataeval/utils/data/selections/_shuffle.py,sha256=
|
92
|
+
dataeval/utils/data/selections/_shuffle.py,sha256=kY3xJvVbBArdrJu_u6mXmxk1HdNmmDE4w7MmxbevUmU,1178
|
93
93
|
dataeval/utils/metadata.py,sha256=X8Hu4LdCzAaE9uk1hI4BflmFve_VOQCqK9lXq0sk9ow,14196
|
94
94
|
dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
|
95
95
|
dataeval/utils/torch/_blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
|
96
|
-
dataeval/utils/torch/_gmm.py,sha256=
|
96
|
+
dataeval/utils/torch/_gmm.py,sha256=XM68GNEP97EjaB1U49-ZXRb81d0CEFnPS910alrcB3g,3740
|
97
97
|
dataeval/utils/torch/_internal.py,sha256=23DCnF7C7N3tZgZUpT2nyH7mMb8Pi4GcnQyjK0BKHpg,5735
|
98
98
|
dataeval/utils/torch/models.py,sha256=hmroEs6C6jQ5tAoZa71RFeIvXLxfXrTJSFH_jG2LGQU,9749
|
99
99
|
dataeval/utils/torch/trainer.py,sha256=iUotX4OdirH8-ZtjdpU8gbJavkYW9YY9qpA2mAlFy1Y,5520
|
100
100
|
dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
|
101
101
|
dataeval/workflows/sufficiency.py,sha256=mjKmfRrAjShLUFIARv5o8yT5fnFvDsS5Qu6ujIPUgQg,8497
|
102
|
-
dataeval-0.
|
103
|
-
dataeval-0.
|
104
|
-
dataeval-0.
|
105
|
-
dataeval-0.
|
102
|
+
dataeval-0.83.0.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
|
103
|
+
dataeval-0.83.0.dist-info/METADATA,sha256=lVRLNQcl2DYQDo7GHpFv_z133aD5hn-uOCkXXltGK5s,5320
|
104
|
+
dataeval-0.83.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
105
|
+
dataeval-0.83.0.dist-info/RECORD,,
|
@@ -1,91 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
__all__ = []
|
4
|
-
|
5
|
-
import numbers
|
6
|
-
import warnings
|
7
|
-
from typing import Any
|
8
|
-
|
9
|
-
import numpy as np
|
10
|
-
from numpy.typing import NDArray
|
11
|
-
from sklearn.feature_selection import mutual_info_classif
|
12
|
-
|
13
|
-
from dataeval.config import get_seed
|
14
|
-
|
15
|
-
# NATS2BITS is the reciprocal of natural log of 2. If you have an information/entropy-type quantity measured in nats,
|
16
|
-
# which is what many library functions return, multiply it by NATS2BITS to get it in bits.
|
17
|
-
NATS2BITS = 1.442695
|
18
|
-
|
19
|
-
|
20
|
-
def get_metadata_ood_mi(
|
21
|
-
metadata: dict[str, list[Any] | NDArray[Any]],
|
22
|
-
is_ood: NDArray[np.bool_],
|
23
|
-
discrete_features: str | bool | NDArray[np.bool_] = False,
|
24
|
-
) -> dict[str, float]:
|
25
|
-
"""Computes mutual information between a set of metadata features and an out-of-distribution flag.
|
26
|
-
|
27
|
-
Given a metadata dictionary `metadata` (where each key maps to one scalar metadata feature per example), and a
|
28
|
-
corresponding boolean flag `is_ood` indicating whether each example falls out-of-distribution (OOD) relative to a
|
29
|
-
reference dataset, this function finds the strength of association between each metadata feature and `is_ood` by
|
30
|
-
computing their mutual information. Metadata features may be either discrete or continuous; set the
|
31
|
-
`discrete_features` keyword to a bool array set to True for each feature that is discrete, or pass one bool to apply
|
32
|
-
to all features. Returns a dict indicating the strength of association between each individual feature and the OOD
|
33
|
-
flag, measured in bits.
|
34
|
-
|
35
|
-
Parameters
|
36
|
-
----------
|
37
|
-
metadata : dict[str, list[Any] | NDArray[Any]]
|
38
|
-
A set of arrays of values, indexed by metadata feature names, with one value per data example per feature.
|
39
|
-
is_ood : NDArray[np.bool_]
|
40
|
-
A boolean array, with one value per example, that indicates which examples are OOD.
|
41
|
-
discrete_features : str | bool | NDArray[np.bool_]
|
42
|
-
Either a boolean array or a single boolean value, indicate which features take on discrete values.
|
43
|
-
|
44
|
-
Returns
|
45
|
-
-------
|
46
|
-
dict[str, float]
|
47
|
-
A dictionary with keys corresponding to metadata feature names, and values indicating the strength of
|
48
|
-
association between each named feature and the OOD flag, as mutual information measured in bits.
|
49
|
-
|
50
|
-
Examples
|
51
|
-
--------
|
52
|
-
Imagine we have 3 data examples, and that the corresponding metadata contains 2 features called time and altitude.
|
53
|
-
|
54
|
-
>>> metadata = {"time": np.linspace(0, 10, 100), "altitude": np.linspace(0, 16, 100) ** 2}
|
55
|
-
>>> is_ood = metadata["altitude"] > 100
|
56
|
-
>>> get_metadata_ood_mi(metadata, is_ood, discrete_features=False)
|
57
|
-
{'time': 0.9359596758173668, 'altitude': 0.9407686591507002}
|
58
|
-
"""
|
59
|
-
numerical_keys = [k for k, v in metadata.items() if all(isinstance(vi, numbers.Number) for vi in v)]
|
60
|
-
if len(numerical_keys) < len(metadata):
|
61
|
-
warnings.warn(
|
62
|
-
f"Processing {numerical_keys}, others are non-numerical and will be skipped.",
|
63
|
-
UserWarning,
|
64
|
-
)
|
65
|
-
|
66
|
-
md_lengths = {len(np.atleast_1d(v)) for v in metadata.values()}
|
67
|
-
if len(md_lengths) > 1:
|
68
|
-
raise ValueError(f"Metadata features have differing sizes: {md_lengths}")
|
69
|
-
|
70
|
-
if len(is_ood) != (mdl := md_lengths.pop()):
|
71
|
-
raise ValueError(
|
72
|
-
f"OOD flag and metadata features need to be same size, but are different sizes: {len(is_ood)} and {mdl}."
|
73
|
-
)
|
74
|
-
|
75
|
-
X = np.array([metadata[k] for k in numerical_keys]).T
|
76
|
-
|
77
|
-
X0, dX = np.mean(X, axis=0), np.std(X, axis=0, ddof=1)
|
78
|
-
Xscl = (X - X0) / dX
|
79
|
-
|
80
|
-
mutual_info_values = (
|
81
|
-
mutual_info_classif(
|
82
|
-
Xscl,
|
83
|
-
is_ood,
|
84
|
-
discrete_features=discrete_features, # type: ignore
|
85
|
-
random_state=get_seed(),
|
86
|
-
)
|
87
|
-
* NATS2BITS
|
88
|
-
)
|
89
|
-
|
90
|
-
mi_dict = {k: mutual_info_values[i] for i, k in enumerate(numerical_keys)}
|
91
|
-
return mi_dict
|
File without changes
|
File without changes
|