dataeval 0.84.0__py3-none-any.whl → 0.84.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. dataeval/__init__.py +1 -1
  2. dataeval/detectors/drift/__init__.py +2 -2
  3. dataeval/detectors/drift/_base.py +55 -203
  4. dataeval/detectors/drift/_cvm.py +19 -30
  5. dataeval/detectors/drift/_ks.py +18 -30
  6. dataeval/detectors/drift/_mmd.py +189 -53
  7. dataeval/detectors/drift/_uncertainty.py +52 -56
  8. dataeval/detectors/drift/updates.py +13 -12
  9. dataeval/detectors/linters/duplicates.py +5 -3
  10. dataeval/detectors/linters/outliers.py +2 -2
  11. dataeval/detectors/ood/ae.py +1 -1
  12. dataeval/metrics/stats/_base.py +7 -7
  13. dataeval/metrics/stats/_dimensionstats.py +2 -2
  14. dataeval/metrics/stats/_hashstats.py +2 -2
  15. dataeval/metrics/stats/_imagestats.py +4 -4
  16. dataeval/metrics/stats/_pixelstats.py +2 -2
  17. dataeval/metrics/stats/_visualstats.py +2 -2
  18. dataeval/typing.py +22 -19
  19. dataeval/utils/_array.py +18 -7
  20. dataeval/utils/data/_dataset.py +6 -4
  21. dataeval/utils/data/_embeddings.py +46 -7
  22. dataeval/utils/data/_images.py +2 -2
  23. dataeval/utils/data/_metadata.py +5 -4
  24. dataeval/utils/data/datasets/_base.py +7 -4
  25. dataeval/utils/data/datasets/_cifar10.py +9 -9
  26. dataeval/utils/data/datasets/_milco.py +42 -14
  27. dataeval/utils/data/datasets/_mnist.py +9 -5
  28. dataeval/utils/data/datasets/_ships.py +8 -4
  29. dataeval/utils/data/datasets/_voc.py +40 -19
  30. dataeval/utils/data/selections/__init__.py +2 -0
  31. dataeval/utils/data/selections/_classbalance.py +38 -0
  32. dataeval/utils/data/selections/_classfilter.py +14 -29
  33. dataeval/utils/data/selections/_prioritize.py +1 -1
  34. dataeval/utils/data/selections/_shuffle.py +2 -2
  35. dataeval/utils/torch/_internal.py +12 -35
  36. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/METADATA +2 -3
  37. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/RECORD +39 -39
  38. dataeval/detectors/drift/_torch.py +0 -222
  39. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/LICENSE.txt +0 -0
  40. {dataeval-0.84.0.dist-info → dataeval-0.84.1.dist-info}/WHEEL +0 -0
@@ -1,20 +1,19 @@
1
- dataeval/__init__.py,sha256=VczdyekiNdqvi2dEUf7xXBu3Aw-MoVTvq-k6c2zjeBM,1636
1
+ dataeval/__init__.py,sha256=QzrctVrymZuLN8tnHcF1wp0RTXYM3WSWMozX3NOzIos,1636
2
2
  dataeval/_log.py,sha256=Mn5bRWO0cgtAYd5VGYSFiPgu57ta3zoktrtHAZ1m3dU,357
3
3
  dataeval/config.py,sha256=lD1YDH8HosFeRU5rQEYRBcmXMZy-csWaMlJTRZGd9iU,3582
4
4
  dataeval/detectors/__init__.py,sha256=3Sg-XWlwr75zEEH3hZKA4nWMtGvaRlnfzTWvZG_Ak6U,189
5
- dataeval/detectors/drift/__init__.py,sha256=6is_XBtG1d-vUbhHvqXGOdnAwxJ7NA5yRfURn7pCeIw,651
6
- dataeval/detectors/drift/_base.py,sha256=mJdKvyROgWvz-p1VlAIJqUI6BAj9ss8riUvR5An5wIw,13459
7
- dataeval/detectors/drift/_cvm.py,sha256=H2w-I0eMD7yP-CSmpdodeJ0-TYznJT7w_H7JuobESow,3859
8
- dataeval/detectors/drift/_ks.py,sha256=-5k3RBPA3kadX7oD14Wc52rAqQf1udwFeW7Qf3Sv4Tw,4058
9
- dataeval/detectors/drift/_mmd.py,sha256=NEXowx9UHIvmEKS8sqssw6PMLJMh0BZPhRNX1hYlkz4,7239
10
- dataeval/detectors/drift/_torch.py,sha256=VrFCyTaRrUslFPy_mYZ4UL70LZ8faH4eHwLurZ9qqNE,7628
11
- dataeval/detectors/drift/_uncertainty.py,sha256=O5h6_bJbeQEE660SLLP8k-EHqImmKegIcxzcnUKI7X4,5714
12
- dataeval/detectors/drift/updates.py,sha256=Btu2iaZW7fbO59G1w5v3ykFot0YPzy2U6VjF0d440VE,2195
5
+ dataeval/detectors/drift/__init__.py,sha256=gD8aY5PotS-S2ot7iB_z_zzSOjIbQLw5znFBNj0jtHE,646
6
+ dataeval/detectors/drift/_base.py,sha256=PdWyEuYqExFdyxvyOh7Q8yXnjNm0D3KfpDUn0bUixtY,7580
7
+ dataeval/detectors/drift/_cvm.py,sha256=CSEyNN9u1MzUI6QmCSlexTUSlHzK1kYh36Nv2L72WbY,3016
8
+ dataeval/detectors/drift/_ks.py,sha256=ifFb_0JcyykJyF9DAVkQqWCXc-3aA0AC8c8to_oOPKo,3198
9
+ dataeval/detectors/drift/_mmd.py,sha256=DOHBNyNNxosR67yM9HTxbvqp1IZ8_KSvTVlX-JtKtjM,11601
10
+ dataeval/detectors/drift/_uncertainty.py,sha256=BHlykJ-r7TGLJxdPfoazXnoAJ1qVDzbk5HjAMdsnHz8,5847
11
+ dataeval/detectors/drift/updates.py,sha256=L1PnrPlIE1x6ujCc5mCwjcAZwadVTn-Zjb6MnTDvzJQ,2251
13
12
  dataeval/detectors/linters/__init__.py,sha256=xn2zPwUcmsuf-Jd9uw6AVI11C9z1b1Y9fYtuFnXenZ0,404
14
- dataeval/detectors/linters/duplicates.py,sha256=tcxniL8rRZkDdQqfuS502UmfKxS3a7iRA22Dtt_vQIk,4935
15
- dataeval/detectors/linters/outliers.py,sha256=Hln2dPQZjF_uV2QYptA_o6ZF3ugyCImVT-XLDB2-q3A,9042
13
+ dataeval/detectors/linters/duplicates.py,sha256=x36-0EAlO_AuOttvElJOZCa0R3VzrlII0NxjwhdkrpE,4969
14
+ dataeval/detectors/linters/outliers.py,sha256=Z0Sbtluu2im1IRGsjhXF2AhrShKDrEkF8BWzAZyPwlA,9054
16
15
  dataeval/detectors/ood/__init__.py,sha256=juCYBDs7CQEAtMhnEpPqF6uTrOIH9kTBSuQ_GRw6a8o,283
17
- dataeval/detectors/ood/ae.py,sha256=YQfhB1ShQLjM1V4uCz9Oo2tCZpOfAZ_-SBCAl4Ac67Y,2921
16
+ dataeval/detectors/ood/ae.py,sha256=fTrUfFxv6xUqzKpwMC8rW3JrizA16M_bgzqLuBKMrS0,2944
18
17
  dataeval/detectors/ood/base.py,sha256=9b-Ljznf0lB1SXF4F_Aj3eJ4Y3ijGEDPMjucUsWOGJM,3051
19
18
  dataeval/detectors/ood/mixin.py,sha256=0_o-1HPvgf3-Lf1MSOIfjj5UB8LTLEBGYtJJfyCCzwc,5431
20
19
  dataeval/detectors/ood/vae.py,sha256=Fcq0-WbLhzYCgYOAJPBklHm7yuXmFJuEpBkhgwM5kiA,2291
@@ -35,14 +34,14 @@ dataeval/metrics/estimators/_clusterer.py,sha256=1HrpihGTJ63IkNSOy4Ibw633Gllkm1R
35
34
  dataeval/metrics/estimators/_divergence.py,sha256=QDWl1lyAYoO9D3Ho7qOHSk6ud8Gi2MGuXEsYwO1HxvA,4043
36
35
  dataeval/metrics/estimators/_uap.py,sha256=BULEBbJ9BQ1IcTeZf0x7iI60QHAWCccBOM97FIu9VXA,1928
37
36
  dataeval/metrics/stats/__init__.py,sha256=6tA_9nbbM5ObJ6cds8Y1VBtTQiTOxrpGQSFLu_lWGGA,1098
38
- dataeval/metrics/stats/_base.py,sha256=rA-Xt9slf2DOR5ky9gGR5s1pmzTb47DykovDp5EWEP0,10672
37
+ dataeval/metrics/stats/_base.py,sha256=YIfOVGd7E19B4dpAnzDYRQkaikvRRyJIpznJNfVtPdw,10750
39
38
  dataeval/metrics/stats/_boxratiostats.py,sha256=8Kd2FTZ5PLNYZfdAjU_R385gb0Z16JY0L9H_d5ZhgQs,6341
40
- dataeval/metrics/stats/_dimensionstats.py,sha256=h2wCLn4UuW7-GV6tM5E1SqSeGa_-4ie9oaEXpSC7EKI,2690
41
- dataeval/metrics/stats/_hashstats.py,sha256=yD6cXMvOo10-xtwUr7ftBRbCqMhReNfQJMInEWV_8Mk,4757
42
- dataeval/metrics/stats/_imagestats.py,sha256=hyjijPXAfUIJ1lwWiIyYK9VSLiq7Vg2-YhJ5Q8s1rkY,2979
39
+ dataeval/metrics/stats/_dimensionstats.py,sha256=73mFP-Myxne0peFliwvTntc0kk4cpq0krzMvSLDSIMM,2702
40
+ dataeval/metrics/stats/_hashstats.py,sha256=gp9X_pnTT3mPH9YNrWLdn2LQPK_epJ3dQRoyOCwmKlg,4758
41
+ dataeval/metrics/stats/_imagestats.py,sha256=gUPNgN5Zwzdr7WnSwbve1NXNsyxd5dy3cSnlR_7guCg,3007
43
42
  dataeval/metrics/stats/_labelstats.py,sha256=WbvXZ831a5BDfm58HF8Z8i5JUV1tgw7tcfzUh8pOXSo,2825
44
- dataeval/metrics/stats/_pixelstats.py,sha256=Q0-ldG-znDYBP_qTqm6S4qYm0ZV5FTTHf8MlyGHSYEc,3235
45
- dataeval/metrics/stats/_visualstats.py,sha256=ZxBDTerZ8ixibY2pGl7mwwcIz3DWl-k_Jb4YwBjHLNw,3686
43
+ dataeval/metrics/stats/_pixelstats.py,sha256=SfergRbjNJE4h0xqe-0c8RnKtZmEkZ9MwExdipLSGvg,3247
44
+ dataeval/metrics/stats/_visualstats.py,sha256=cq4AbF2B50Ihbzb86FphcnKQ1TSwNnP3PsnbpiPQZWw,3698
46
45
  dataeval/outputs/__init__.py,sha256=ciK-RdXgtn_s7MSCUW1UXvrXltMbltqbpfe9_V7xGrI,1701
47
46
  dataeval/outputs/_base.py,sha256=aZFbgybnZSQ3ws7QYRLTbDFqUfBFRVtIwX2LZfeGFUA,5703
48
47
  dataeval/outputs/_bias.py,sha256=GwbjLdppUODOeudYb_7ki2ejDmAYthlRKGijVwgVePE,12407
@@ -55,9 +54,9 @@ dataeval/outputs/_stats.py,sha256=c73Yc3Kkrl-MN6BGKe1V0Yr6Ix2Yp_DZZfFSp8fZMZ0,13
55
54
  dataeval/outputs/_utils.py,sha256=HHlGC7sk416m_3Bgn075Qdblz_aPup_UOafJpB0RuXY,893
56
55
  dataeval/outputs/_workflows.py,sha256=MkRD6ubI4NCBXb9v3kjXy64cUGs3G-JKkBdOpRD9XVE,10750
57
56
  dataeval/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
- dataeval/typing.py,sha256=h1wNoWovasIEAWwNWvICToJ1bE1DRba68HifP14zjZc,6598
57
+ dataeval/typing.py,sha256=zn6smomSdcO7EeZpeeSP5-8sknTdgUuU7TKe-3UFVrg,6550
59
58
  dataeval/utils/__init__.py,sha256=T8F8zJh4ZAeu0wDzfpld92I2zJg9mWBmkGCHrDPU7gk,264
60
- dataeval/utils/_array.py,sha256=B44LbVi4g7XntVvnPyi5KeyQIxRQSvYhTVD4Th76u5g,5577
59
+ dataeval/utils/_array.py,sha256=KqAdXEMjcXYvdWdYEEoEbigwQJ4S9VYxQS3sRFeY5XY,5929
61
60
  dataeval/utils/_bin.py,sha256=nylthmsC3vzLHLhlUMACvZs--h7xvAh9Pt75InaQJW8,7322
62
61
  dataeval/utils/_clusterer.py,sha256=fw5x-2QN0TIbiodDKHZxRgxKHINedpPcOklzce0Rbjg,5436
63
62
  dataeval/utils/_fast_mst.py,sha256=4_7ykVihCL5jWtxcGnrecIsDQo65kUml9SZ1JxgBZYY,7172
@@ -66,41 +65,42 @@ dataeval/utils/_method.py,sha256=9B9JQbgqWJBRhQJb7glajUtWaQzUTIUuvrZ9_bisxsM,394
66
65
  dataeval/utils/_mst.py,sha256=f0vXytTUjlOS6AyL7c6PkXmaHuuGUK-vMLpq-5xMgxk,2183
67
66
  dataeval/utils/_plot.py,sha256=mTRQNbJsA42QMiOwZbJaH8sNYgP996QFDEGVVE9HSgY,7076
68
67
  dataeval/utils/data/__init__.py,sha256=vldQ2ZXl8gnI3s4vAGqUUVi6dc_R58F3JMSpbCOyFRI,820
69
- dataeval/utils/data/_dataset.py,sha256=RZT05cfXkiPJPNCG6SVf8zmsk0pQUBViyrSo2l1_G5w,7852
70
- dataeval/utils/data/_embeddings.py,sha256=NK87PfzpQUagwU1aBknsEEihAPNR3BIqHnHkpeKEgVs,7028
71
- dataeval/utils/data/_images.py,sha256=Zn2um-oZjypwVTdpQNw7DsjxKyujswe2jLIgxmUPQ7Q,2626
72
- dataeval/utils/data/_metadata.py,sha256=VqeePp7NtoFFWzmIhH4fn-cjrnATpgzgzs-d73cnBXM,14370
68
+ dataeval/utils/data/_dataset.py,sha256=MHY582yRm4FxQkkLWUhKZBb7ZyvWypM6ldUG89vd3uE,7936
69
+ dataeval/utils/data/_embeddings.py,sha256=iDtfLJY1uHoTP4UdQoOt-3wopc6kSOXH_4CVNnmXXA4,8356
70
+ dataeval/utils/data/_images.py,sha256=WF9XJRka8ohUdyI2IKBMAy3JoJhOm1iC-8tbYl8woRM,2642
71
+ dataeval/utils/data/_metadata.py,sha256=62z_qHjoGjiMdpuT36QpNhbWy2UClHWUcjHHlIWT470,14464
73
72
  dataeval/utils/data/_selection.py,sha256=2c6DjyeDIJapbI7xL36eBxFnJHIP8Yxt3oU3rBGMqLk,3948
74
73
  dataeval/utils/data/_split.py,sha256=q-2RwllJgazwuyxB_GoBqK_nLkqIjyTVr2SQKj_7lhw,16767
75
74
  dataeval/utils/data/_targets.py,sha256=ws5d9wRiDkIuOV7GSAKNxzgSm6AWTgb0BFroQK5nAmM,3057
76
75
  dataeval/utils/data/collate.py,sha256=Z5nmBnWV_IoJzMp_tj8RCKjMJA9sSCY_zZITqISGixc,3865
77
76
  dataeval/utils/data/datasets/__init__.py,sha256=jBrswiERrvBx4pJQJZIq_B5UE-Wy8a2_SBfM2crG8R8,511
78
- dataeval/utils/data/datasets/_base.py,sha256=CZ-hb-yWPLdnTQ3pURJMcityQ42ZNYj_Lbb1P5Junn4,8793
79
- dataeval/utils/data/datasets/_cifar10.py,sha256=I6HKksE2escos1aTdiZJObtiVXChBlez5BDa0eBfJ_Y,5449
77
+ dataeval/utils/data/datasets/_base.py,sha256=827nSVhZ-tqeHw1HQ7Qj060CSDd90fEWZomN6FaWnQA,8872
78
+ dataeval/utils/data/datasets/_cifar10.py,sha256=R7QgcCHowAkqhEXOvUhybXTmMlA4BJXkTuAeV9uDgfU,5449
80
79
  dataeval/utils/data/datasets/_fileio.py,sha256=SixIk5nIlIwJdX9zjNXS10vHA3hL8aaYbqHsDg1xSpY,6447
81
- dataeval/utils/data/datasets/_milco.py,sha256=ScBe7Ux-J9Kxs33jeKffhWKeSb8GCrWznTyEUt95Vt4,6369
80
+ dataeval/utils/data/datasets/_milco.py,sha256=bVVDl5W8TdTPU2RiwoPXrfFDM1rKyb-LslwTThBXEr0,7583
82
81
  dataeval/utils/data/datasets/_mixin.py,sha256=FJgZP_cpJkgAHA3j3ai_j3Wt7aFSEjIMVmt9NpvVXzg,1757
83
- dataeval/utils/data/datasets/_mnist.py,sha256=iWWI9mq6TbZm7eTL9btzqjCNMhgXrLHQeMKENr7USsk,7988
84
- dataeval/utils/data/datasets/_ships.py,sha256=p3fScYLW2f1wUEPOroCX5nOFti0vMOSjeYltj6ox53U,4777
82
+ dataeval/utils/data/datasets/_mnist.py,sha256=kNDJw0oyqa6QgU1y9lg-3AzStavK1BB8iHnDOdv9nyE,8112
83
+ dataeval/utils/data/datasets/_ships.py,sha256=rsyIoRAIk40liFgaEb2dg0lYB7__bAGd9zh9ouzjFKg,4880
85
84
  dataeval/utils/data/datasets/_types.py,sha256=iSKyHXRlGuomXs0FHK6md8lXLQrQQ4fxgVOwr4o81bo,1089
86
- dataeval/utils/data/datasets/_voc.py,sha256=4poEer_G_mUBcz6eAro0Tc29CjdgjEAlms0Eu0tLBzE,14842
87
- dataeval/utils/data/selections/__init__.py,sha256=k86OpqGPkjT1MrOir5fOZ3AIq5UR81Az9ek7l1-GdIM,565
88
- dataeval/utils/data/selections/_classfilter.py,sha256=opSF8CGv4x1hUMe-GTQOu3UwJK80DzT0nJOV0l2uaW4,2404
85
+ dataeval/utils/data/datasets/_voc.py,sha256=QUtpbh2EpiBoicsmOo-YIfwRwPXyHj-zB2hFn7tlz0Y,15580
86
+ dataeval/utils/data/selections/__init__.py,sha256=iUbMZRDuBXwY3SNAtZTdCVu7SI4zbCyaL6ItXnnq1yI,655
87
+ dataeval/utils/data/selections/_classbalance.py,sha256=hHq9frdwzFLCUmfeJq977Sot_SXhuGANlSsetokhRDc,1465
88
+ dataeval/utils/data/selections/_classfilter.py,sha256=xdR5uX7W5Yivf-mE_CikbRi2fGrZLFrPYun3TeQHTA0,1267
89
89
  dataeval/utils/data/selections/_indices.py,sha256=QdLgXN7GABCvGPYe28PV1RAc_RSP_nZOyCvEpKRBdWg,636
90
90
  dataeval/utils/data/selections/_limit.py,sha256=ECvHRsp7OF4LZw2tE4sGqqJ085kjC-hd2c7QDMfvXr8,518
91
- dataeval/utils/data/selections/_prioritize.py,sha256=EAA4_uFVV7MmemhhufGmP7eunnbtyTc-TzgcnvRK5OE,11333
91
+ dataeval/utils/data/selections/_prioritize.py,sha256=uRQjeQiAc-vvwHMH4CQtXTGJCfjj_h5dgGlhQYFMz1c,11318
92
92
  dataeval/utils/data/selections/_reverse.py,sha256=6SWpELC9Wgx-kPqzhDrPNn4NKU6FqDJveLrxV4D2Ypk,374
93
- dataeval/utils/data/selections/_shuffle.py,sha256=kY3xJvVbBArdrJu_u6mXmxk1HdNmmDE4w7MmxbevUmU,1178
93
+ dataeval/utils/data/selections/_shuffle.py,sha256=_jwms0qcwrknf2Fx84cCXyNOJyhE_V8rcnDOTDn1S2A,1179
94
94
  dataeval/utils/metadata.py,sha256=1XeGYj_e97-nJ_IrWEHPhWICmouYU5qbXWbp7uhZrIE,14171
95
95
  dataeval/utils/torch/__init__.py,sha256=dn5mjCrFp0b1aL_UEURhONU0Ag0cmXoTOBSGagpkTiA,325
96
96
  dataeval/utils/torch/_blocks.py,sha256=HVhBTMMD5NA4qheMUgyol1KWiKZDIuc8k5j4RcMKmhk,1466
97
97
  dataeval/utils/torch/_gmm.py,sha256=XM68GNEP97EjaB1U49-ZXRb81d0CEFnPS910alrcB3g,3740
98
- dataeval/utils/torch/_internal.py,sha256=23DCnF7C7N3tZgZUpT2nyH7mMb8Pi4GcnQyjK0BKHpg,5735
98
+ dataeval/utils/torch/_internal.py,sha256=vHy-DzPhmvE8h3wmWc3aciBJ8nDGzQ1z1jTZgGjmDyM,4154
99
99
  dataeval/utils/torch/models.py,sha256=hmroEs6C6jQ5tAoZa71RFeIvXLxfXrTJSFH_jG2LGQU,9749
100
100
  dataeval/utils/torch/trainer.py,sha256=iUotX4OdirH8-ZtjdpU8gbJavkYW9YY9qpA2mAlFy1Y,5520
101
101
  dataeval/workflows/__init__.py,sha256=ou8y0KO-d6W5lgmcyLjKlf-J_ckP3vilW7wHkgiDlZ4,255
102
102
  dataeval/workflows/sufficiency.py,sha256=mjKmfRrAjShLUFIARv5o8yT5fnFvDsS5Qu6ujIPUgQg,8497
103
- dataeval-0.84.0.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
104
- dataeval-0.84.0.dist-info/METADATA,sha256=oEAANNRg8RUIWn9AdrQEV7OUnX5mJbgf4NqXr5QY8AY,5320
105
- dataeval-0.84.0.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
106
- dataeval-0.84.0.dist-info/RECORD,,
103
+ dataeval-0.84.1.dist-info/LICENSE.txt,sha256=uAooygKWvX6NbU9Ran9oG2msttoG8aeTeHSTe5JeCnY,1061
104
+ dataeval-0.84.1.dist-info/METADATA,sha256=F7L5PSWHV3z0_4pwA-JSgucW2A4bEv_dtvIMzCTGLZ8,5308
105
+ dataeval-0.84.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
106
+ dataeval-0.84.1.dist-info/RECORD,,
@@ -1,222 +0,0 @@
1
- """
2
- Source code derived from Alibi-Detect 0.11.4
3
- https://github.com/SeldonIO/alibi-detect/tree/v0.11.4
4
-
5
- Original code Copyright (c) 2023 Seldon Technologies Ltd
6
- Licensed under Apache Software License (Apache 2.0)
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- __all__ = []
12
-
13
- from typing import Any, Callable
14
-
15
- import numpy as np
16
- import torch
17
- import torch.nn as nn
18
- from numpy.typing import NDArray
19
-
20
- from dataeval.config import DeviceLike, get_device
21
- from dataeval.utils.torch._internal import predict_batch
22
-
23
-
24
def mmd2_from_kernel_matrix(
    kernel_mat: torch.Tensor, m: int, permute: bool = False, zero_diag: bool = True
) -> torch.Tensor:
    """
    Estimate the squared maximum mean discrepancy (MMD^2) of two samples x and y
    from their joint kernel matrix.

    The last ``m`` rows/columns of ``kernel_mat`` are treated as sample y and the
    remaining leading rows/columns as sample x.

    Parameters
    ----------
    kernel_mat : torch.Tensor
        Square kernel matrix computed over the concatenation of x and y.
    m : int
        Number of instances in y.
    permute : bool, default False
        If True, apply one random permutation to both rows and columns before
        splitting the matrix. Used for permutation tests.
    zero_diag : bool, default True
        If True, subtract the diagonal so self-similarities do not contribute.

    Returns
    -------
    torch.Tensor
        Scalar tensor holding the MMD^2 estimate.
    """
    total = kernel_mat.shape[0]
    n = total - m

    if zero_diag:
        kernel_mat = kernel_mat - torch.diag(kernel_mat.diag())

    if permute:
        # The same ordering is applied to rows and columns so the matrix stays
        # a valid (re-labelled) kernel matrix.
        order = torch.randperm(total)
        kernel_mat = kernel_mat[order][:, order]

    k_xx = kernel_mat[:-m, :-m]
    k_yy = kernel_mat[-m:, -m:]
    k_xy = kernel_mat[-m:, :-m]

    # Within-sample terms are normalised by the number of off-diagonal pairs.
    scale_xx = 1 / (n * (n - 1))
    scale_yy = 1 / (m * (m - 1))

    return scale_xx * k_xx.sum() + scale_yy * k_yy.sum() - 2.0 * k_xy.mean()
57
-
58
-
59
def preprocess_drift(
    x: NDArray[Any],
    model: nn.Module,
    device: DeviceLike | None = None,
    preprocess_batch_fn: Callable | None = None,
    batch_size: int = int(1e10),
    dtype: type[np.generic] | torch.dtype = np.float32,
) -> NDArray[Any] | torch.Tensor | tuple[Any, ...]:
    """
    Run a model over a batch of instances as the preprocessing step of a drift detector.

    This is a thin wrapper that resolves the device with ``get_device`` and
    forwards everything else to ``predict_batch``.

    Parameters
    ----------
    x : NDArray
        Batch of instances.
    model : nn.Module
        Model used for preprocessing.
    device : DeviceLike or None, default None
        Device specification handed to ``get_device``; when None the DataEval
        default or torch default is used.
    preprocess_batch_fn : Callable or None, default None
        Optional function applied to each batch before it reaches the model,
        e.g. to convert a list of objects into a batch the PyTorch model accepts.
    batch_size : int, default 1e10
        Batch size used during prediction.
    dtype : np.dtype or torch.dtype, default np.float32
        Model output type, either a :term:`NumPy` or torch dtype,
        e.g. np.float32 or torch.float32.

    Returns
    -------
    NDArray | torch.Tensor | tuple
        Numpy array, torch tensor or tuples of those with model outputs.
    """
    resolved_device = get_device(device)
    outputs = predict_batch(
        x,
        model,
        device=resolved_device,
        batch_size=batch_size,
        preprocess_fn=preprocess_batch_fn,
        dtype=dtype,
    )
    return outputs
100
-
101
-
102
- @torch.jit.script
103
- def _squared_pairwise_distance(
104
- x: torch.Tensor, y: torch.Tensor, a_min: float = 1e-30
105
- ) -> torch.Tensor: # pragma: no cover - torch.jit.script code is compiled and copied
106
- """
107
- PyTorch pairwise squared Euclidean distance between samples x and y.
108
-
109
- Parameters
110
- ----------
111
- x : torch.Tensor
112
- Batch of instances of shape [Nx, features].
113
- y : torch.Tensor
114
- Batch of instances of shape [Ny, features].
115
- a_min : float
116
- Lower bound to clip distance values.
117
-
118
- Returns
119
- -------
120
- torch.Tensor
121
- Pairwise squared Euclidean distance [Nx, Ny].
122
- """
123
- x2 = x.pow(2).sum(dim=-1, keepdim=True)
124
- y2 = y.pow(2).sum(dim=-1, keepdim=True)
125
- dist = torch.addmm(y2.transpose(-2, -1), x, y.transpose(-2, -1), alpha=-2).add_(x2)
126
- return dist.clamp_min_(a_min)
127
-
128
-
129
def sigma_median(x: torch.Tensor, y: torch.Tensor, dist: torch.Tensor) -> torch.Tensor:
    """
    Bandwidth estimation using the median heuristic `Gretton2012`

    When ``x`` and ``y`` are the same sample (equal shape and equal values), the
    first ``Nx`` entries of the sorted distances are skipped before taking the
    median, so that the self-distances on the diagonal of ``dist`` do not drag
    the estimate towards zero.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of instances with dimension [Nx, features].
    y : torch.Tensor
        Tensor of instances with dimension [Ny, features].
    dist : torch.Tensor
        Tensor with dimensions [Nx, Ny], containing the pairwise distances
        between `x` and `y`.

    Returns
    -------
    torch.Tensor
        The computed bandwidth, `sigma`, as a tensor of shape [1].
    """
    # Check shapes before values: comparing tensors whose trailing dimensions
    # cannot be broadcast (e.g. [N, 2, 3] vs [N, 3, 2]) would raise, whereas a
    # shape mismatch simply means the two samples are not identical.
    same_sample = x.shape == y.shape and bool((x == y).all())
    n = x.shape[0] if same_sample else 0

    # Index of the median among the entries that remain after skipping the
    # `n` smallest (self-distance) values.
    n_median = n + (dist.numel() - n) // 2 - 1

    sorted_dist = dist.flatten().sort().values
    sigma = (0.5 * sorted_dist[n_median].unsqueeze(dim=-1)) ** 0.5
    return sigma
153
-
154
-
155
class GaussianRBF(nn.Module):
    """
    Gaussian RBF kernel: k(x,y) = exp(-(1/(2*sigma^2)) * ||x-y||^2).

    A forward pass takes a batch of instances x [Nx, features] and
    y [Ny, features] and returns the kernel matrix [Nx, Ny]. When several
    bandwidths are configured, the kernel is evaluated once per bandwidth and
    the results are averaged.

    Parameters
    ----------
    sigma : torch.Tensor | None, default None
        Bandwidth used for the kernel. Needn't be specified if being inferred or
        trained. Can pass multiple values to eval kernel with and then average.
    init_sigma_fn : Callable | None, default None
        Function used to compute the bandwidth ``sigma``. Used when ``sigma`` is to be
        inferred. The function's signature should take in the tensors ``x``, ``y`` and
        ``dist`` and return ``sigma``. If ``None``, it is set to ``sigma_median``.
    trainable : bool, default False
        Whether or not to track gradients w.r.t. `sigma` to allow it to be trained.
    """

    def __init__(
        self,
        sigma: torch.Tensor | None = None,
        init_sigma_fn: Callable | None = None,
        trainable: bool = False,
    ) -> None:
        super().__init__()
        if init_sigma_fn is None:
            init_sigma_fn = sigma_median

        # Keep the constructor arguments (sigma as originally passed) for inspection.
        self.config: dict[str, Any] = {
            "sigma": sigma,
            "trainable": trainable,
            "init_sigma_fn": init_sigma_fn,
        }

        # Without an explicit sigma the parameter is a placeholder that is
        # filled in on the first forward pass.
        needs_init = sigma is None
        initial_log_sigma = torch.empty(1) if needs_init else sigma.reshape(-1).log()  # [Ns,]
        self.log_sigma: nn.Parameter = nn.Parameter(initial_log_sigma, requires_grad=trainable)
        self.init_required: bool = needs_init

        self.init_sigma_fn = init_sigma_fn
        self.trainable = trainable

    @property
    def sigma(self) -> torch.Tensor:
        """Bandwidth(s) recovered from the stored log-bandwidth parameter."""
        return self.log_sigma.exp()

    def forward(
        self,
        x: np.ndarray[Any, Any] | torch.Tensor,
        y: np.ndarray[Any, Any] | torch.Tensor,
        infer_sigma: bool = False,
    ) -> torch.Tensor:
        """
        Evaluate the kernel matrix between ``x`` and ``y``.

        Parameters
        ----------
        x : np.ndarray | torch.Tensor
            Batch of instances; flattened to [Nx, features] before use.
        y : np.ndarray | torch.Tensor
            Batch of instances; flattened to [Ny, features] before use.
        infer_sigma : bool, default False
            If True, recompute the bandwidth from this batch with ``init_sigma_fn``
            and store it before evaluating the kernel.

        Returns
        -------
        torch.Tensor
            Kernel matrix [Nx, Ny], averaged over all configured bandwidths.

        Raises
        ------
        ValueError
            If ``infer_sigma`` is True while the kernel is trainable.
        """
        x_t = torch.as_tensor(x)
        y_t = torch.as_tensor(y)
        dist = _squared_pairwise_distance(x_t.flatten(1), y_t.flatten(1))  # [Nx, Ny]

        if infer_sigma or self.init_required:
            if infer_sigma and self.trainable:
                raise ValueError("Gradients cannot be computed w.r.t. an inferred sigma value")
            inferred = self.init_sigma_fn(x_t, y_t, dist)
            # Overwrite the parameter in place without recording the copy in autograd.
            with torch.no_grad():
                self.log_sigma.copy_(inferred.log().clone())
            self.init_required = False

        gamma = 1.0 / (2.0 * self.sigma**2)  # [Ns,]
        # TODO: do matrix multiplication after all?
        scaled = torch.cat([(g * dist).unsqueeze(0) for g in gamma], dim=0)  # [Ns, Nx, Ny]
        return torch.exp(-scaled).mean(dim=0)  # [Nx, Ny]