braindecode 1.2.0.dev184328194__py3-none-any.whl → 1.3.0.dev171478045__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; see the package's registry page for more details.

Files changed (47)
  1. braindecode/augmentation/functional.py +154 -54
  2. braindecode/augmentation/transforms.py +2 -2
  3. braindecode/datasets/base.py +1 -1
  4. braindecode/datasets/experimental.py +218 -0
  5. braindecode/datasets/sleep_physio_challe_18.py +2 -1
  6. braindecode/datautil/serialization.py +11 -6
  7. braindecode/eegneuralnet.py +2 -0
  8. braindecode/models/__init__.py +12 -8
  9. braindecode/models/atcnet.py +157 -17
  10. braindecode/models/attentionbasenet.py +153 -26
  11. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  12. braindecode/models/base.py +280 -2
  13. braindecode/models/bendr.py +469 -0
  14. braindecode/models/biot.py +3 -1
  15. braindecode/models/ctnet.py +1 -1
  16. braindecode/models/deep4.py +6 -2
  17. braindecode/models/deepsleepnet.py +118 -5
  18. braindecode/models/eegconformer.py +114 -15
  19. braindecode/models/eeginception_erp.py +76 -7
  20. braindecode/models/eeginception_mi.py +2 -0
  21. braindecode/models/eegnet.py +64 -177
  22. braindecode/models/eegnex.py +113 -6
  23. braindecode/models/eegsimpleconv.py +2 -0
  24. braindecode/models/eegtcnet.py +1 -1
  25. braindecode/models/labram.py +170 -69
  26. braindecode/models/patchedtransformer.py +640 -0
  27. braindecode/models/sccnet.py +81 -8
  28. braindecode/models/shallow_fbcsp.py +2 -0
  29. braindecode/models/signal_jepa.py +109 -27
  30. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  31. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  32. braindecode/models/sparcnet.py +2 -0
  33. braindecode/models/sstdpn.py +869 -0
  34. braindecode/models/summary.csv +42 -41
  35. braindecode/models/tidnet.py +2 -0
  36. braindecode/models/tsinception.py +15 -3
  37. braindecode/models/usleep.py +103 -9
  38. braindecode/models/util.py +8 -5
  39. braindecode/preprocessing/preprocess.py +31 -28
  40. braindecode/version.py +1 -1
  41. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171478045.dist-info}/METADATA +10 -3
  42. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171478045.dist-info}/RECORD +46 -43
  43. braindecode/models/eegresnet.py +0 -362
  44. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171478045.dist-info}/WHEEL +0 -0
  45. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171478045.dist-info}/licenses/LICENSE.txt +0 -0
  46. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171478045.dist-info}/licenses/NOTICE.txt +0 -0
  47. {braindecode-1.2.0.dev184328194.dist-info → braindecode-1.3.0.dev171478045.dist-info}/top_level.txt +0 -0
@@ -1,72 +1,75 @@
1
1
  braindecode/__init__.py,sha256=Ac3LEEyIHWFY_fFh3eAY1GZUqXcUxVSJwOSUCwGEDvQ,182
2
2
  braindecode/classifier.py,sha256=k9vSCtfQbld0YVleDi5rrrmk6k_k5JYEPPBYcNxYjZ8,9807
3
- braindecode/eegneuralnet.py,sha256=dz8k_-2jV7WqkaX4bQG-dmr-vRT7ZtOwJqomXyC9PTw,15287
3
+ braindecode/eegneuralnet.py,sha256=U6kRdT2u8A2Ca0axMTR8IAESBsvgjLMusAbYappKAOk,15368
4
4
  braindecode/regressor.py,sha256=VLfrpiXklwI4onkwue3QmzlBWcvspu0tlrLo9RT1Oiw,9375
5
5
  braindecode/util.py,sha256=J-tBcDJNlMTIFW2mfOy6Ko0nsgdP4obRoEVDeg2rFH0,12686
6
- braindecode/version.py,sha256=Adl2q0noMgIED1dlngWz_nvDbzU6GpgOYSGTS9Fs6io,35
6
+ braindecode/version.py,sha256=g5gWQg7AVLR1mAmDgIYLLB-iuLq1Y2n9miSSXRiLYNY,35
7
7
  braindecode/augmentation/__init__.py,sha256=LG7ONqCufYAF9NZt8POIp10lYXb8iSueYkF-CWGK2Ls,1001
8
8
  braindecode/augmentation/base.py,sha256=gg7wYsVfa9jfqBddtE03B5ZrPHFFmPl2sa3LOrRnGfo,7325
9
- braindecode/augmentation/functional.py,sha256=ygkMNEFHaUdRQfk7meMML19FnM406Uf34h-ztKXdJwM,37978
10
- braindecode/augmentation/transforms.py,sha256=QgLoX6MFaiBH8WoVBgB8eY4x9jZNPMvj20zlwUM8AOs,44245
9
+ braindecode/augmentation/functional.py,sha256=lPhGpZcVtgfQ3oV6p6IQLBCWM_Psa60TwxH3Wj1WyOQ,41133
10
+ braindecode/augmentation/transforms.py,sha256=Ur05yLdROm5pfKTsS2opCWI--X6JwWjP7YMa2KTTZTw,44243
11
11
  braindecode/datasets/__init__.py,sha256=CTl8ucbG948ZJqntEBELb-Pn8GsZLfFZLgVcB-fhw4k,891
12
- braindecode/datasets/base.py,sha256=ED8RQWusMyWf0T7b_HXwouR2Ax47qppEc506AlSzBt0,32155
12
+ braindecode/datasets/base.py,sha256=_qUuMripcBrc04R7j5wOqW41myo5Y_Ku3OqKa4uRqx4,32176
13
13
  braindecode/datasets/bbci.py,sha256=BC9o1thEyYBREAo930O7zZz3xZB-l4Odt5j8E_1huXI,19277
14
14
  braindecode/datasets/bcicomp.py,sha256=ER_XmqxhpoO2FWELMesQXQ40OTe7BXoy7nYDSiZG9kE,7556
15
15
  braindecode/datasets/bids.py,sha256=4asq1HyQHgJjwW7w-GMlvTVQhi-hR2HWLJ8Z__UrUS4,8846
16
+ braindecode/datasets/experimental.py,sha256=Z_uzMNA875-l878LAv7bWiWYJX3QAefmb5quBkcPp7M,8514
16
17
  braindecode/datasets/mne.py,sha256=Dg6RZAAwd8TVGrvLOPF5B_JrbyGUWg52vWmn6fLMOQM,6135
17
18
  braindecode/datasets/moabb.py,sha256=JmBcFV7QJT8GCgLNNKWgxJVnEVnO5wd9U_uiIqTIxDM,7091
18
19
  braindecode/datasets/nmt.py,sha256=E4T8OYBEwWRSjh7VFzmyxaZbf5ufFVEBYYmQEd1ghUU,10430
19
- braindecode/datasets/sleep_physio_challe_18.py,sha256=KTvUtuarOOYu6PHN6H1vcy4W9xilwtZE08n7JSrk8Cs,15414
20
+ braindecode/datasets/sleep_physio_challe_18.py,sha256=66A86_9VssszKrVXowb0oFyL3xbF1VRqQK5FtW33QlM,15427
20
21
  braindecode/datasets/sleep_physionet.py,sha256=jieRx6u-MQ4jn_5Zox_pVV8WjBwXKLv9uq4GXRAZ_58,4087
21
22
  braindecode/datasets/tuh.py,sha256=iG1hOtdevzKGEVpeuRFDBOnsW_rWa5zEmMFJfYR1hqg,22867
22
23
  braindecode/datasets/xy.py,sha256=xT-nS_5jpuVKJ0SGqc7Ia0FVpqj86UfuzcYQdEGZdp0,2986
23
24
  braindecode/datautil/__init__.py,sha256=GB9xOudUhJGDyG08PBrnotw6HnWoWIXAHfRNFO-pxSk,1797
24
- braindecode/datautil/serialization.py,sha256=gLIm9bcuR-XfVdII-RTplUWFRms9qVvVZ0-M6gTucNc,13028
25
+ braindecode/datautil/serialization.py,sha256=g_EVg3oTieqFRattw9OdwMaYjfjANVG-uCS3xVkuHjg,13293
25
26
  braindecode/datautil/util.py,sha256=ZfDoxLieKsgI8xcWQqebV-vJ5pJYRvRRHkEwhwpgoKU,674
26
27
  braindecode/functional/__init__.py,sha256=JPUDFeKtfogEzfrwPaZRBmxexPjBw7AglYMlImaAnWc,413
27
28
  braindecode/functional/functions.py,sha256=CoEweM6YLhigx0tNmmz6yAc8iQ078sTFY2GeCjK5fFs,8622
28
29
  braindecode/functional/initialization.py,sha256=BUSC7y2TMsfShpMYBVwm3xg3ODFqWp-STH7yD4sn8zk,1388
29
- braindecode/models/__init__.py,sha256=xv1QPELZxocPgbc_mz-eYM5w08ZDNOsDV4pOnIFhUww,2551
30
- braindecode/models/atcnet.py,sha256=PhDJl6nBChButabjsmLz_heRcGFCCMKoeUt7k7neNzs,24483
31
- braindecode/models/attentionbasenet.py,sha256=1uwrtsdEGiBwokkO8A_2SR5zapOTQUBZd4q7hIpR0cw,23359
32
- braindecode/models/base.py,sha256=9icrWNZBGbh_VLyB9m8g_K1QyK7s3mh8X-hJ29gEbWs,10802
33
- braindecode/models/biot.py,sha256=T4PymX3penMJcrdfb5Nq6B3P-jyP2laAIu_R9o3uCXo,17512
30
+ braindecode/models/__init__.py,sha256=ovF_WX8ZkXEkleRwYsMMS7ldLPh8_2NzTeYGVqH9ilg,2581
31
+ braindecode/models/atcnet.py,sha256=H2IWMscm3IM4PH8DA_iLkUaeMXgA120DmVld4jBFOCM,32242
32
+ braindecode/models/attentionbasenet.py,sha256=_bml0Ofy7yB12X19a026EYkcLuzZIab0v3sQTqZ5HGQ,30485
33
+ braindecode/models/attn_sleep.py,sha256=m6sdFfD4en2hHf_TpotLPC1hVweJcYZvjgf12bV5FZg,17822
34
+ braindecode/models/base.py,sha256=KjsHVQDdUCAJB4nS-a6eze-H7ayvU4565tsFUcDVxVQ,20212
35
+ braindecode/models/bendr.py,sha256=MZQdYFERVeBJnynEXDlCLdn_I0mJtgzzFuMhCXkbMkg,21591
36
+ braindecode/models/biot.py,sha256=LpJ8tXqQL2Zh_vcQnpUHEpAGQrPHtn2cBSTUPFCW8jQ,17546
34
37
  braindecode/models/contrawr.py,sha256=eeR_ik4gNZ3rJLM6Mw9gJ2gTMkZ8CU8C4rN_GQMQTAE,10044
35
- braindecode/models/ctnet.py,sha256=-J9QtUM8kcntz_xinfuBBvwDMECHiMPMcr2MS4GDPEY,17308
36
- braindecode/models/deep4.py,sha256=YJQUw-0EuFUi4qjm8caJGB8wRM_aeJa5X_d8jrGaQAI,14588
37
- braindecode/models/deepsleepnet.py,sha256=RrciuVJtZ-fhiUl-yLPfK2FP-G29V5Wor6pPlrMHQWQ,9218
38
- braindecode/models/eegconformer.py,sha256=_Y0SXprBD74zD8nKPcS9HQ6PoWzfpu-VCY7Tj6R7Xrs,11612
39
- braindecode/models/eeginception_erp.py,sha256=mwh3rGSHAJVvnbOlYTuWWkKxlmFAdAXBNCrq4IPgOS4,11408
40
- braindecode/models/eeginception_mi.py,sha256=aKJRFuYrpbcRbmmT2xVghKbK8pnl7fzu5hrV0ybRKso,12424
38
+ braindecode/models/ctnet.py,sha256=ce5F31q2weBKvg7PL80iDm7za9fhGaCFvNfHoJW_dtg,17315
39
+ braindecode/models/deep4.py,sha256=-s-R3H7so2xlSiPsU226eSwscv1X9xJMYLm3LhZ3mSU,14645
40
+ braindecode/models/deepsleepnet.py,sha256=wGSAXW73Ga1-HFbn7kXiLeGsJceiqZyMLZnX2UZZXWw,15207
41
+ braindecode/models/eegconformer.py,sha256=rxMAmqErDVLq7nS77CnTtpcC3C2OR_EoZ8-jG-dKP9I,17433
42
+ braindecode/models/eeginception_erp.py,sha256=FYXoM-u4kOodMzGgvKDn7IwJwHl9Z0iiWx9bVHiO9EY,16324
43
+ braindecode/models/eeginception_mi.py,sha256=VoWtsaWj1xQ4FlrvCbnPvo8eosufYUmTrL4uvFtqKcg,12456
41
44
  braindecode/models/eegitnet.py,sha256=feXFmPCd-Ejxt7jgWPen1Ag0-oSclDVQai0Atwu9d_A,9827
42
45
  braindecode/models/eegminer.py,sha256=ouKZah9Q7_sxT7DJJMcPObwVxNQE87sEljJg6QwiQNw,9847
43
- braindecode/models/eegnet.py,sha256=1ZAG0KLDedkodDfqgnGGsoZj6iuU55kGmBlyQo1b47w,16284
44
- braindecode/models/eegnex.py,sha256=KNJIh8pFNhY087Bey2OPzDD4Uqw9pS6UkwMjnOngBzg,8497
45
- braindecode/models/eegresnet.py,sha256=cqWOSGqfJN_dNYUU9l8nYd_S3T1N-UX5-encKQzfBlg,12057
46
- braindecode/models/eegsimpleconv.py,sha256=sHpK-7ZGOCMuXsdkSVuarFTd1T0jMJUP_xwXP3gxQwc,7268
47
- braindecode/models/eegtcnet.py,sha256=np-93Ttctp2uaEYpMrfXfH5bJmCOUZZHLjv8GJEEym4,10830
46
+ braindecode/models/eegnet.py,sha256=i5HzBKTd82fTlKDfB42uc14HpDYxN29SGPfCa4ON5gk,13686
47
+ braindecode/models/eegnex.py,sha256=eahHolFl15LwNWeC5qjQqUGqURibQZIV425rI1p-dG8,13604
48
+ braindecode/models/eegsimpleconv.py,sha256=6V5ZQNWijmd3-2wv7lJB_HGBS3wHWWVrKoNIeWTXu-w,7300
49
+ braindecode/models/eegtcnet.py,sha256=Y53uJEX_hoB6eHCew9SIfzNxCYea8UhljDARJTk-Tq8,10837
48
50
  braindecode/models/fbcnet.py,sha256=RBCLOaiUvivfsT2mq6FN0Kp1-rR3iB0ElzVpHxRl4oI,7486
49
51
  braindecode/models/fblightconvnet.py,sha256=d5MwhawhkjilAMo0ckaYMxJhdGMEuorWgHX-TBgwv6s,11041
50
52
  braindecode/models/fbmsnet.py,sha256=9bZn2_n1dTrI1Qh3Sz9zMZnH_a-Yq-13UHYSmF6r_UE,11659
51
53
  braindecode/models/hybrid.py,sha256=hA8jwD3_3LL71BxUjRM1dkhqlHU9E9hjuDokh-jBq-4,4024
52
54
  braindecode/models/ifnet.py,sha256=Y2bwfko3SDjD74AzgUEzgMhKJFGCCw_Q_Noh5VONEjQ,15137
53
- braindecode/models/labram.py,sha256=vcrpwiu4F-djtIPscFbtP2Y0jTosyR_cXnOMQQRGPLw,41798
55
+ braindecode/models/labram.py,sha256=1BVGJpPNqXtM_XQ__20p1_9KWVCgvxJ1ICU4HqZW3d8,46600
54
56
  braindecode/models/msvtnet.py,sha256=hxeCLkHS6w2w89YlLfEPCyQ4XQQpt45bEYPiQJ9SFzY,12642
55
- braindecode/models/sccnet.py,sha256=baGsNpVRdyWzbkTizOthJoJGejLb8BxMpN9ODwZinio,7919
56
- braindecode/models/shallow_fbcsp.py,sha256=-sL6XCmCUZVhKKrC84-KWgwhWKQQvev1oNSmH_d6FA4,7499
57
- braindecode/models/signal_jepa.py,sha256=UeSkeAM3Qmx8bbAqHCj5nP-PtZM00_5SGA8ibo9mptc,37079
57
+ braindecode/models/patchedtransformer.py,sha256=9TY9l2X4EoCuE9IoOObjubKFRdmsN5lbrVQLnmr66VY,23444
58
+ braindecode/models/sccnet.py,sha256=C7vdwIR5cI6wJCl5f8TnGQG6qinq21y4HG6l-D5AwbY,11971
59
+ braindecode/models/shallow_fbcsp.py,sha256=7U07DJBrm2JHV8v5ja-xuE5-IH5tfmryhJtrfO1n4jk,7531
60
+ braindecode/models/signal_jepa.py,sha256=bBujhM9ItIJisKvbxEi5e1yuV-0mBb41GlyMeEs_TkA,41124
58
61
  braindecode/models/sinc_shallow.py,sha256=Ilv8K1XhMGiRTBtQdq7L595i6cEFYOBe0_UDv-LqL7s,11907
59
- braindecode/models/sleep_stager_blanco_2020.py,sha256=qPKMDLuv4J7et4dZHyTe-j0oB6ESYn9mA_aW7RMC-rU,6002
60
- braindecode/models/sleep_stager_chambon_2018.py,sha256=62x2Rdjd5UZDX8YlnfAtdRCrjLsPvPpnUweGElZLdkw,5213
61
- braindecode/models/sleep_stager_eldele_2021.py,sha256=-4ISuznykDy9ZFzUM-OeiGCwmgM3U-LuyoDSrhPbRDw,17555
62
- braindecode/models/sparcnet.py,sha256=eZMoJOxlcIyHPdQiX7KXUKuUBlAWkTwsXNWmNma_KAI,13941
63
- braindecode/models/summary.csv,sha256=l7HYYwv3Z69JRPVIhVq-wr_nC1J1KIz6IGw_zeRSk58,6110
62
+ braindecode/models/sleep_stager_blanco_2020.py,sha256=vXulnDYutEFLM0UPXyAI0YIj5QImUMVEmYZb78j34H8,6034
63
+ braindecode/models/sleep_stager_chambon_2018.py,sha256=8w8IR2PsfG0jSc3o0YVopgHpOvCHNIuMi7-QRJOYEW4,5245
64
+ braindecode/models/sparcnet.py,sha256=MG1OB91guI7ssKRk8GvWlzUvaxo_otaYnbEGzNUZVyg,13973
65
+ braindecode/models/sstdpn.py,sha256=wJv-UYP1q8cMGp2wU1efzIZiigRmkJ8uY22rNB2D7Wc,35077
66
+ braindecode/models/summary.csv,sha256=vFmhpCGFZlxC9Zm8KLBaGRHvZZfdRY85NAGj1Wyv1yU,7209
64
67
  braindecode/models/syncnet.py,sha256=nrWJC5ijCSWKVZyRn-dmOuc1t5vk2C6tx8U3U4j5d5Y,8362
65
68
  braindecode/models/tcn.py,sha256=SQu56H9zdbcbbDIXZVgZtJg7es8CRAJ7z-IBnmf4UWM,8158
66
- braindecode/models/tidnet.py,sha256=k7Q0yAnEBmq1sqhsvoV4-g8wfYSUQ-C3iYxfLp5m8xQ,11805
67
- braindecode/models/tsinception.py,sha256=EcfLDDJXZloh_vrKRuxAHYRZ1EVWlEKHNXqybTRrTbQ,10116
68
- braindecode/models/usleep.py,sha256=dFh3KiZITu13gMxcbPGoK4hq2ySDWzVSCQXkj1006w0,11605
69
- braindecode/models/util.py,sha256=VrhwG1YBGwKohCej6TmhrNAIoleQHRu3YdiBPuHFY_E,5302
69
+ braindecode/models/tidnet.py,sha256=HSUL1al6gaRbJ-BRYAAs4KDvLuKEvh0NnBfAsPeWMpM,11837
70
+ braindecode/models/tsinception.py,sha256=nnQxzpqRy9FPuN5xgh9fNQ386VbreQ_nZBSFNkSfal0,10356
71
+ braindecode/models/usleep.py,sha256=5uztUHX70T_LurqRob_XmVnKkZDwt74x2Iz181M7s54,17233
72
+ braindecode/models/util.py,sha256=nrYBdd0FTCoYxgg21oz1UlW-PACx-0-_EyvMQua0QI8,5414
70
73
  braindecode/modules/__init__.py,sha256=PD2LpeSHWW_MgEef7-G8ief5gheGObzsIoacchxWuyA,1756
71
74
  braindecode/modules/activation.py,sha256=lTO2IjZWBDeXZ4ZVDgLmTDmxHdqyAny3Fsy07HY9tmQ,1466
72
75
  braindecode/modules/attention.py,sha256=ISE11jXAvMqKpawZilg8i7lDX5mkuvpEplrh_CtGEkk,24102
@@ -81,7 +84,7 @@ braindecode/modules/util.py,sha256=tVXEhzeTsYrr_wZ5CiXaq3VYGtC5TmGEEW2hMYjTQAE,2
81
84
  braindecode/modules/wrapper.py,sha256=Z-aZ4wxA0psYefMOfj03r7D1XjD4az6GpZpaQoDPJv0,2421
82
85
  braindecode/preprocessing/__init__.py,sha256=V0iwdzb6DzpUaCabA7I6HmOqXK_XvTbpP5HaEduSJ4s,776
83
86
  braindecode/preprocessing/mne_preprocess.py,sha256=_Jczaitqbx16utsUOhnonEcoExf6jPsWNwVOVvoKFfU,2210
84
- braindecode/preprocessing/preprocess.py,sha256=-9IKjb0THq36m54TK-YRzV18wIkxmVgTcGO2sEH6q98,17665
87
+ braindecode/preprocessing/preprocess.py,sha256=da_-Tn1NLPunsZC2-uzzgCYgdm_Xj-CIJjwf_CTMuFs,17899
85
88
  braindecode/preprocessing/windowers.py,sha256=6w6mOnroGWnV7tS23UagZZepswaxaL00S45Jr5AViRE,36551
86
89
  braindecode/samplers/__init__.py,sha256=TLuO6gXv2WioJdX671MI_CHVSsOfbjnly1Xv9K3_WdA,452
87
90
  braindecode/samplers/base.py,sha256=z_Txp9cEwUmIBL0J6FPJbx1cMSsU9l9mxymRCGqNss0,15111
@@ -93,9 +96,9 @@ braindecode/training/scoring.py,sha256=WRkwqbitA3m_dzRnGp2ZIZPge5Nhx9gAEQhIHzeH4
93
96
  braindecode/visualization/__init__.py,sha256=4EER_xHqZIDzEvmgUEm7K1bgNKpyZAIClR9ZCkMuY4M,240
94
97
  braindecode/visualization/confusion_matrices.py,sha256=qIWMLEHow5CJ7PhGggD8mnD55Le6xhma9HSzt4R33fc,9509
95
98
  braindecode/visualization/gradients.py,sha256=KZo-GA0uwiwty2_94j2IjmCR2SKcfPb1Bi3sQq7vpTk,2170
96
- braindecode-1.2.0.dev184328194.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
97
- braindecode-1.2.0.dev184328194.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
98
- braindecode-1.2.0.dev184328194.dist-info/METADATA,sha256=PgPq5CmBC6TDByTBtGn3Gtf6yaAJW96CZ_3J5BgGhDc,6883
99
- braindecode-1.2.0.dev184328194.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
100
- braindecode-1.2.0.dev184328194.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
101
- braindecode-1.2.0.dev184328194.dist-info/RECORD,,
99
+ braindecode-1.3.0.dev171478045.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
100
+ braindecode-1.3.0.dev171478045.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
101
+ braindecode-1.3.0.dev171478045.dist-info/METADATA,sha256=mwyspEBAyIpNfEvNLbm8yY2DePcO_EuJFLYuVI5dePk,7215
102
+ braindecode-1.3.0.dev171478045.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
103
+ braindecode-1.3.0.dev171478045.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
104
+ braindecode-1.3.0.dev171478045.dist-info/RECORD,,
@@ -1,362 +0,0 @@
1
- # Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
2
- # Tonio Ball
3
- #
4
- # License: BSD-3
5
-
6
- import numpy as np
7
- import torch
8
- from einops.layers.torch import Rearrange
9
- from torch import nn
10
- from torch.nn import init
11
-
12
- from braindecode.models.base import EEGModuleMixin
13
- from braindecode.modules import (
14
- AvgPool2dWithConv,
15
- Ensure4d,
16
- SqueezeFinalOutput,
17
- )
18
-
19
-
20
class EEGResNet(EEGModuleMixin, nn.Sequential):
    """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.

    .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
        :align: center
        :alt: EEGResNet Architecture

    Model described in [Schirrmeister2017]_.

    The network is built as a sequence of named modules:

    - a stem: a temporal convolution (optionally followed by a spatial
      convolution spanning all ``n_chans`` channels), batch normalization
      and a non-linearity;
    - six stages of residual blocks (``res_<stage>_<layer>``) whose temporal
      dilation doubles from stage to stage (1, 2, 4, 8, 16, 32). The first
      block of stage 2 doubles the number of filters and the first block of
      stage 3 multiplies it by 1.5; all other blocks keep it unchanged;
    - a temporal mean pooling (``mean_pool``) and a ``(1, 1)`` convolutional
      classifier followed by ``SqueezeFinalOutput`` (``final_layer``).

    Parameters
    ----------
    final_pool_length : int or "auto", default="auto"
        If ``"auto"``, ``mean_pool`` is an ``nn.AdaptiveAvgPool2d((1, 1))``.
        Otherwise it is an ``AvgPool2dWithConv`` with kernel
        ``(final_pool_length, 1)``, stride ``(1, 1)`` and the temporal
        dilation of the last residual stage.
    n_first_filters : int, default=20
        Number of filters produced by the stem convolutions.
    n_layers_per_block : int, default=2
        Number of residual blocks per stage. Stages 2-6 always contain at
        least their first block, even when this is 0.
    first_filter_length : int, default=3
        Temporal kernel length of the stem convolution. Must be odd so that
        the ``first_filter_length // 2`` padding keeps the time length.
    activation : type[nn.Module], default=nn.ELU
        Activation class (e.g. ``nn.ReLU`` or ``nn.ELU``). It is instantiated
        after the stem and inside every residual block.
    split_first_layer : bool, default=True
        If True, the stem is a temporal convolution over a single input
        plane followed by a bias-free spatial convolution across all
        channels. If False, a single bias-free temporal convolution takes the
        EEG channels as its input channels.
    batch_norm_alpha : float, default=0.1
        ``momentum`` of every batch-normalization layer.
    batch_norm_epsilon : float, default=1e-4
        ``eps`` of the batch-normalization layers inside the residual blocks.
        The stem batch normalization keeps a fixed ``eps=1e-5``.
    conv_weight_init_fn : callable, default=Kaiming normal init (``a=0``)
        Called on the weight of every convolution (``AvgPool2dWithConv`` is
        skipped). Convolution biases are set to zero.

    References
    ----------
    .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
       L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
       & Ball, T. (2017). Deep learning with convolutional neural networks for
       EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
       Online: http://dx.doi.org/10.1002/hbm.23730
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_times=None,
        final_pool_length="auto",
        n_first_filters=20,
        n_layers_per_block=2,
        first_filter_length=3,
        activation=nn.ELU,
        split_first_layer=True,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
        conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
        chs_info=None,
        input_window_seconds=None,
        sfreq=250,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        if final_pool_length == "auto":
            assert self.n_times is not None, (
                "n_times must be known when final_pool_length='auto'"
            )
        # An odd kernel with padding ``length // 2`` and stride 1 keeps the
        # time length unchanged.
        assert first_filter_length % 2 == 1, "first_filter_length must be odd"

        self.final_pool_length = final_pool_length
        self.n_first_filters = n_first_filters
        self.n_layers_per_block = n_layers_per_block
        self.first_filter_length = first_filter_length
        self.nonlinearity = activation
        self.split_first_layer = split_first_layer
        self.batch_norm_alpha = batch_norm_alpha
        self.batch_norm_epsilon = batch_norm_epsilon
        self.conv_weight_init_fn = conv_weight_init_fn

        # Maps legacy state-dict keys to the current module layout.
        self.mapping = {
            "conv_classifier.weight": "final_layer.conv_classifier.weight",
            "conv_classifier.bias": "final_layer.conv_classifier.bias",
        }

        # ---- Stem ----------------------------------------------------------
        self.add_module("ensuredims", Ensure4d())
        if self.split_first_layer:
            # Move channels to the last axis so that ``conv_spat`` (kernel
            # width ``n_chans``) can span all of them.
            self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    1,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=1,
                    padding=(self.first_filter_length // 2, 0),
                ),
            )
            self.add_module(
                "conv_spat",
                nn.Conv2d(
                    self.n_first_filters,
                    self.n_first_filters,
                    (1, self.n_chans),
                    stride=(1, 1),
                    bias=False,
                ),
            )
        else:
            self.add_module(
                "conv_time",
                nn.Conv2d(
                    self.n_chans,
                    self.n_first_filters,
                    (self.first_filter_length, 1),
                    stride=(1, 1),
                    padding=(self.first_filter_length // 2, 0),
                    bias=False,
                ),
            )
        n_cur_filters = self.n_first_filters
        # NOTE: the stem keeps a fixed eps of 1e-5 (historical behaviour);
        # ``batch_norm_epsilon`` only applies to the residual blocks.
        self.add_module(
            "bnorm",
            nn.BatchNorm2d(
                n_cur_filters, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
            ),
        )
        self.add_module("conv_nonlin", self.nonlinearity())

        # ---- Residual stages -----------------------------------------------
        def make_block(in_filters, out_filters, time_dilation):
            # Build one residual block that shares the model-level settings.
            return _ResidualBlock(
                in_filters,
                out_filters,
                dilation=(time_dilation, 1),
                nonlinearity=self.nonlinearity,
                batch_norm_alpha=self.batch_norm_alpha,
                batch_norm_epsilon=self.batch_norm_epsilon,
            )

        time_dilation = 1
        # Stage 1: undilated blocks that keep the filter count.
        for i_layer in range(self.n_layers_per_block):
            self.add_module(
                f"res_1_{i_layer}", make_block(n_cur_filters, n_cur_filters, time_dilation)
            )
        # Stages 2-6: each stage doubles the temporal dilation, and its first
        # block rescales the number of filters by the given growth factor.
        for i_block, growth in enumerate((2, 1.5, 1, 1, 1), start=2):
            time_dilation *= 2
            n_out_filters = int(growth * n_cur_filters)
            self.add_module(
                f"res_{i_block}_0", make_block(n_cur_filters, n_out_filters, time_dilation)
            )
            n_cur_filters = n_out_filters
            for i_layer in range(1, self.n_layers_per_block):
                self.add_module(
                    f"res_{i_block}_{i_layer}",
                    make_block(n_cur_filters, n_cur_filters, time_dilation),
                )

        # ---- Pooling -------------------------------------------------------
        if self.final_pool_length == "auto":
            self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
        else:
            self.add_module(
                "mean_pool",
                AvgPool2dWithConv(
                    (self.final_pool_length, 1), (1, 1), dilation=(time_dilation, 1)
                ),
            )

        # ---- Classifier ----------------------------------------------------
        # Classification conv and the output squeeze live in one sub-module.
        final_layer = nn.Sequential()
        final_layer.add_module(
            "conv_classifier",
            nn.Conv2d(
                n_cur_filters,
                self.n_outputs,
                (1, 1),
                bias=True,
            ),
        )
        final_layer.add_module("squeeze", SqueezeFinalOutput())
        self.add_module("final_layer", final_layer)

        # Initialize all weights.
        self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))

    @staticmethod
    def _weights_init(module, conv_weight_init_fn):
        """Initialize the parameters of a single sub-module.

        Convolutions (except ``AvgPool2dWithConv``) get ``conv_weight_init_fn``
        applied to their weight and a zero bias. Batch-norm layers get weight
        1 and bias 0. Other modules are left untouched.
        """
        classname = module.__class__.__name__
        if "Conv" in classname and classname != "AvgPool2dWithConv":
            conv_weight_init_fn(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif "BatchNorm" in classname:
            init.constant_(module.weight, 1)
            init.constant_(module.bias, 0)
294
-
295
-
296
class _ResidualBlock(nn.Module):
    """Residual block with two stacked dilated temporal convolutions.

    Computes ``nonlin(shortcut(x) + bn2(conv_2(nonlin(bn1(conv_1(x))))))``.
    Both convolutions use a ``(filter_time_length, 1)`` kernel, stride 1 and
    ``(filter_time_length - 1) * dilation[0] / 2`` temporal padding, so the
    sizes of the last two dimensions are preserved. When the block increases
    the number of filters, ``shortcut(x)`` zero-pads ``x`` along dim 1 with
    ``(out_num_filters - in_filters) / 2`` channels on each side. Otherwise
    ``shortcut(x)`` is ``x``.

    Parameters
    ----------
    in_filters : int
        Number of input channels (dim 1).
    out_num_filters : int
        Number of output channels. Must be ``>= in_filters``, and the
        difference must be even.
    dilation : sequence of two numbers
        Dilation of both convolutions; converted to ``(int, int)``.
    filter_time_length : int, default=3
        Temporal kernel length.
    nonlinearity : type[nn.Module], default=nn.ELU
        Activation class, instantiated once and used twice.
    batch_norm_alpha : float, default=0.1
        ``momentum`` of both batch-norm layers.
    batch_norm_epsilon : float, default=1e-4
        ``eps`` of both batch-norm layers.

    Raises
    ------
    AssertionError
        If the total temporal padding is odd, if ``out_num_filters`` is
        smaller than ``in_filters``, or if their difference is odd.
    """

    def __init__(
        self,
        in_filters,
        out_num_filters,
        dilation,
        filter_time_length=3,
        nonlinearity: nn.Module = nn.ELU,
        batch_norm_alpha=0.1,
        batch_norm_epsilon=1e-4,
    ):
        super().__init__()
        # Total padding needed to keep the time length; it must split evenly
        # between both sides.
        time_padding = int((filter_time_length - 1) * dilation[0])
        assert time_padding % 2 == 0, (
            "(filter_time_length - 1) * dilation[0] must be even for symmetric padding"
        )
        time_padding = time_padding // 2
        dilation = (int(dilation[0]), int(dilation[1]))
        # The shortcut can only add (zero) channels, never remove them.
        assert out_num_filters >= in_filters, (
            "out_num_filters must be >= in_filters for the zero-padded shortcut"
        )
        assert (out_num_filters - in_filters) % 2 == 0, (
            "Need even number of extra channels in order to be able to pad correctly"
        )
        self.n_pad_chans = out_num_filters - in_filters

        self.conv_1 = nn.Conv2d(
            in_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn1 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.conv_2 = nn.Conv2d(
            out_num_filters,
            out_num_filters,
            (filter_time_length, 1),
            stride=(1, 1),
            dilation=dilation,
            padding=(time_padding, 0),
        )
        self.bn2 = nn.BatchNorm2d(
            out_num_filters,
            momentum=batch_norm_alpha,
            affine=True,
            eps=batch_norm_epsilon,
        )
        self.nonlinearity = nonlinearity()

    def forward(self, x):
        """Apply the residual block to a 4-d tensor ``x``."""
        stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
        # No activation here: it is applied after the residual sum.
        stack_2 = self.bn2(self.conv_2(stack_1))
        if self.n_pad_chans != 0:
            # Zero-pad the shortcut symmetrically along the channel axis so
            # it matches the number of output filters.
            zeros_for_padding = x.new_zeros(
                (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
            )
            x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
        return self.nonlinearity(x + stack_2)