braindecode 1.2.0.dev168908090__py3-none-any.whl → 1.2.0.dev175267687__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic. See the registry's advisory page for more details.

Files changed (33)
  1. braindecode/datasets/experimental.py +218 -0
  2. braindecode/models/__init__.py +6 -8
  3. braindecode/models/atcnet.py +1 -1
  4. braindecode/models/attentionbasenet.py +151 -26
  5. braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
  6. braindecode/models/ctnet.py +1 -1
  7. braindecode/models/deep4.py +6 -2
  8. braindecode/models/deepsleepnet.py +118 -5
  9. braindecode/models/eegconformer.py +15 -12
  10. braindecode/models/eeginception_erp.py +76 -7
  11. braindecode/models/eeginception_mi.py +2 -0
  12. braindecode/models/eegnet.py +25 -189
  13. braindecode/models/eegnex.py +113 -6
  14. braindecode/models/eegsimpleconv.py +2 -0
  15. braindecode/models/eegtcnet.py +1 -1
  16. braindecode/models/sccnet.py +79 -8
  17. braindecode/models/shallow_fbcsp.py +2 -0
  18. braindecode/models/sleep_stager_blanco_2020.py +2 -0
  19. braindecode/models/sleep_stager_chambon_2018.py +2 -0
  20. braindecode/models/sparcnet.py +2 -0
  21. braindecode/models/summary.csv +39 -41
  22. braindecode/models/tidnet.py +2 -0
  23. braindecode/models/tsinception.py +15 -3
  24. braindecode/models/usleep.py +103 -9
  25. braindecode/models/util.py +5 -5
  26. braindecode/version.py +1 -1
  27. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/METADATA +7 -2
  28. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/RECORD +32 -32
  29. braindecode/models/eegresnet.py +0 -362
  30. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/WHEEL +0 -0
  31. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/licenses/LICENSE.txt +0 -0
  32. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/licenses/NOTICE.txt +0 -0
  33. {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: braindecode
3
- Version: 1.2.0.dev168908090
3
+ Version: 1.2.0.dev175267687
4
4
  Summary: Deep learning software to decode EEG, ECG or MEG signals
5
5
  Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
6
6
  Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
17
  Classifier: Programming Language :: Python :: 3.10
18
18
  Classifier: Programming Language :: Python :: 3.11
19
19
  Classifier: Programming Language :: Python :: 3.12
20
- Requires-Python: >3.10
20
+ Requires-Python: >=3.10
21
21
  Description-Content-Type: text/x-rst
22
22
  License-File: LICENSE.txt
23
23
  License-File: NOTICE.txt
@@ -49,6 +49,10 @@ Requires-Dist: mypy; extra == "tests"
49
49
  Provides-Extra: docs
50
50
  Requires-Dist: sphinx_gallery; extra == "docs"
51
51
  Requires-Dist: sphinx_rtd_theme; extra == "docs"
52
+ Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
53
+ Requires-Dist: sphinx-autobuild; extra == "docs"
54
+ Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
55
+ Requires-Dist: sphinx_sitemap; extra == "docs"
52
56
  Requires-Dist: pydata_sphinx_theme; extra == "docs"
53
57
  Requires-Dist: numpydoc; extra == "docs"
54
58
  Requires-Dist: memory_profiler; extra == "docs"
@@ -59,6 +63,7 @@ Requires-Dist: lightning; extra == "docs"
59
63
  Requires-Dist: seaborn; extra == "docs"
60
64
  Requires-Dist: pre-commit; extra == "docs"
61
65
  Requires-Dist: openneuro-py; extra == "docs"
66
+ Requires-Dist: plotly; extra == "docs"
62
67
  Provides-Extra: all
63
68
  Requires-Dist: braindecode[docs,moabb,tests]; extra == "all"
64
69
  Dynamic: license-file
@@ -3,7 +3,7 @@ braindecode/classifier.py,sha256=k9vSCtfQbld0YVleDi5rrrmk6k_k5JYEPPBYcNxYjZ8,980
3
3
  braindecode/eegneuralnet.py,sha256=dz8k_-2jV7WqkaX4bQG-dmr-vRT7ZtOwJqomXyC9PTw,15287
4
4
  braindecode/regressor.py,sha256=VLfrpiXklwI4onkwue3QmzlBWcvspu0tlrLo9RT1Oiw,9375
5
5
  braindecode/util.py,sha256=J-tBcDJNlMTIFW2mfOy6Ko0nsgdP4obRoEVDeg2rFH0,12686
6
- braindecode/version.py,sha256=b8BdgKb6YgI65s6lSGpGOwj1zjSE3ObTFYtA1t6_lrU,35
6
+ braindecode/version.py,sha256=DjAGynn_F-_dyPfKBBkITwv_h_JYqDSa0IISxiwpVAE,35
7
7
  braindecode/augmentation/__init__.py,sha256=LG7ONqCufYAF9NZt8POIp10lYXb8iSueYkF-CWGK2Ls,1001
8
8
  braindecode/augmentation/base.py,sha256=gg7wYsVfa9jfqBddtE03B5ZrPHFFmPl2sa3LOrRnGfo,7325
9
9
  braindecode/augmentation/functional.py,sha256=ygkMNEFHaUdRQfk7meMML19FnM406Uf34h-ztKXdJwM,37978
@@ -13,6 +13,7 @@ braindecode/datasets/base.py,sha256=ED8RQWusMyWf0T7b_HXwouR2Ax47qppEc506AlSzBt0,
13
13
  braindecode/datasets/bbci.py,sha256=BC9o1thEyYBREAo930O7zZz3xZB-l4Odt5j8E_1huXI,19277
14
14
  braindecode/datasets/bcicomp.py,sha256=ER_XmqxhpoO2FWELMesQXQ40OTe7BXoy7nYDSiZG9kE,7556
15
15
  braindecode/datasets/bids.py,sha256=4asq1HyQHgJjwW7w-GMlvTVQhi-hR2HWLJ8Z__UrUS4,8846
16
+ braindecode/datasets/experimental.py,sha256=Z_uzMNA875-l878LAv7bWiWYJX3QAefmb5quBkcPp7M,8514
16
17
  braindecode/datasets/mne.py,sha256=Dg6RZAAwd8TVGrvLOPF5B_JrbyGUWg52vWmn6fLMOQM,6135
17
18
  braindecode/datasets/moabb.py,sha256=JmBcFV7QJT8GCgLNNKWgxJVnEVnO5wd9U_uiIqTIxDM,7091
18
19
  braindecode/datasets/nmt.py,sha256=E4T8OYBEwWRSjh7VFzmyxaZbf5ufFVEBYYmQEd1ghUU,10430
@@ -26,25 +27,25 @@ braindecode/datautil/util.py,sha256=ZfDoxLieKsgI8xcWQqebV-vJ5pJYRvRRHkEwhwpgoKU,
26
27
  braindecode/functional/__init__.py,sha256=JPUDFeKtfogEzfrwPaZRBmxexPjBw7AglYMlImaAnWc,413
27
28
  braindecode/functional/functions.py,sha256=CoEweM6YLhigx0tNmmz6yAc8iQ078sTFY2GeCjK5fFs,8622
28
29
  braindecode/functional/initialization.py,sha256=BUSC7y2TMsfShpMYBVwm3xg3ODFqWp-STH7yD4sn8zk,1388
29
- braindecode/models/__init__.py,sha256=xv1QPELZxocPgbc_mz-eYM5w08ZDNOsDV4pOnIFhUww,2551
30
- braindecode/models/atcnet.py,sha256=Pn5KzQjv7YxSNDr_CY6O_Yg9K4m9XJ7btCIqyzkcPxc,32102
31
- braindecode/models/attentionbasenet.py,sha256=1uwrtsdEGiBwokkO8A_2SR5zapOTQUBZd4q7hIpR0cw,23359
30
+ braindecode/models/__init__.py,sha256=v2Pn0H-rM_9xr1EEoKIFygmhbS9r52qh8XwFzXuhK70,2455
31
+ braindecode/models/atcnet.py,sha256=jA_18BOaasmiqGbLJOvfBY5q2xHtKdoRFKzN_aqpDoQ,32107
32
+ braindecode/models/attentionbasenet.py,sha256=AK78VvwrZXyJY20zadzDUHl17C-5zcWCd5xPRN7Lr4o,30385
33
+ braindecode/models/attn_sleep.py,sha256=m6sdFfD4en2hHf_TpotLPC1hVweJcYZvjgf12bV5FZg,17822
32
34
  braindecode/models/base.py,sha256=9icrWNZBGbh_VLyB9m8g_K1QyK7s3mh8X-hJ29gEbWs,10802
33
35
  braindecode/models/biot.py,sha256=T4PymX3penMJcrdfb5Nq6B3P-jyP2laAIu_R9o3uCXo,17512
34
36
  braindecode/models/contrawr.py,sha256=eeR_ik4gNZ3rJLM6Mw9gJ2gTMkZ8CU8C4rN_GQMQTAE,10044
35
- braindecode/models/ctnet.py,sha256=-J9QtUM8kcntz_xinfuBBvwDMECHiMPMcr2MS4GDPEY,17308
36
- braindecode/models/deep4.py,sha256=YJQUw-0EuFUi4qjm8caJGB8wRM_aeJa5X_d8jrGaQAI,14588
37
- braindecode/models/deepsleepnet.py,sha256=RrciuVJtZ-fhiUl-yLPfK2FP-G29V5Wor6pPlrMHQWQ,9218
38
- braindecode/models/eegconformer.py,sha256=6scz0Axm97JV7-4u5yd6HGE7PldAMR39x5qSNzjSqxQ,17404
39
- braindecode/models/eeginception_erp.py,sha256=mwh3rGSHAJVvnbOlYTuWWkKxlmFAdAXBNCrq4IPgOS4,11408
40
- braindecode/models/eeginception_mi.py,sha256=aKJRFuYrpbcRbmmT2xVghKbK8pnl7fzu5hrV0ybRKso,12424
37
+ braindecode/models/ctnet.py,sha256=ce5F31q2weBKvg7PL80iDm7za9fhGaCFvNfHoJW_dtg,17315
38
+ braindecode/models/deep4.py,sha256=-s-R3H7so2xlSiPsU226eSwscv1X9xJMYLm3LhZ3mSU,14645
39
+ braindecode/models/deepsleepnet.py,sha256=wGSAXW73Ga1-HFbn7kXiLeGsJceiqZyMLZnX2UZZXWw,15207
40
+ braindecode/models/eegconformer.py,sha256=rxMAmqErDVLq7nS77CnTtpcC3C2OR_EoZ8-jG-dKP9I,17433
41
+ braindecode/models/eeginception_erp.py,sha256=FYXoM-u4kOodMzGgvKDn7IwJwHl9Z0iiWx9bVHiO9EY,16324
42
+ braindecode/models/eeginception_mi.py,sha256=VoWtsaWj1xQ4FlrvCbnPvo8eosufYUmTrL4uvFtqKcg,12456
41
43
  braindecode/models/eegitnet.py,sha256=feXFmPCd-Ejxt7jgWPen1Ag0-oSclDVQai0Atwu9d_A,9827
42
44
  braindecode/models/eegminer.py,sha256=ouKZah9Q7_sxT7DJJMcPObwVxNQE87sEljJg6QwiQNw,9847
43
- braindecode/models/eegnet.py,sha256=YeBCmU6Al9FDS4MZQTOLd0MCUfPbM6tmVlGWpb59Qzg,19256
44
- braindecode/models/eegnex.py,sha256=KNJIh8pFNhY087Bey2OPzDD4Uqw9pS6UkwMjnOngBzg,8497
45
- braindecode/models/eegresnet.py,sha256=cqWOSGqfJN_dNYUU9l8nYd_S3T1N-UX5-encKQzfBlg,12057
46
- braindecode/models/eegsimpleconv.py,sha256=sHpK-7ZGOCMuXsdkSVuarFTd1T0jMJUP_xwXP3gxQwc,7268
47
- braindecode/models/eegtcnet.py,sha256=np-93Ttctp2uaEYpMrfXfH5bJmCOUZZHLjv8GJEEym4,10830
45
+ braindecode/models/eegnet.py,sha256=dIaHZoz7xMII1qKrS0___IWdy1xg2QrMMiqUgTJM9E8,13682
46
+ braindecode/models/eegnex.py,sha256=eahHolFl15LwNWeC5qjQqUGqURibQZIV425rI1p-dG8,13604
47
+ braindecode/models/eegsimpleconv.py,sha256=6V5ZQNWijmd3-2wv7lJB_HGBS3wHWWVrKoNIeWTXu-w,7300
48
+ braindecode/models/eegtcnet.py,sha256=Y53uJEX_hoB6eHCew9SIfzNxCYea8UhljDARJTk-Tq8,10837
48
49
  braindecode/models/fbcnet.py,sha256=RBCLOaiUvivfsT2mq6FN0Kp1-rR3iB0ElzVpHxRl4oI,7486
49
50
  braindecode/models/fblightconvnet.py,sha256=d5MwhawhkjilAMo0ckaYMxJhdGMEuorWgHX-TBgwv6s,11041
50
51
  braindecode/models/fbmsnet.py,sha256=9bZn2_n1dTrI1Qh3Sz9zMZnH_a-Yq-13UHYSmF6r_UE,11659
@@ -52,21 +53,20 @@ braindecode/models/hybrid.py,sha256=hA8jwD3_3LL71BxUjRM1dkhqlHU9E9hjuDokh-jBq-4,
52
53
  braindecode/models/ifnet.py,sha256=Y2bwfko3SDjD74AzgUEzgMhKJFGCCw_Q_Noh5VONEjQ,15137
53
54
  braindecode/models/labram.py,sha256=vcrpwiu4F-djtIPscFbtP2Y0jTosyR_cXnOMQQRGPLw,41798
54
55
  braindecode/models/msvtnet.py,sha256=hxeCLkHS6w2w89YlLfEPCyQ4XQQpt45bEYPiQJ9SFzY,12642
55
- braindecode/models/sccnet.py,sha256=baGsNpVRdyWzbkTizOthJoJGejLb8BxMpN9ODwZinio,7919
56
- braindecode/models/shallow_fbcsp.py,sha256=-sL6XCmCUZVhKKrC84-KWgwhWKQQvev1oNSmH_d6FA4,7499
56
+ braindecode/models/sccnet.py,sha256=ragbEzNrua0S84H4JR_j2QGLZWrFKGQ4CfIS2epIYEk,11919
57
+ braindecode/models/shallow_fbcsp.py,sha256=7U07DJBrm2JHV8v5ja-xuE5-IH5tfmryhJtrfO1n4jk,7531
57
58
  braindecode/models/signal_jepa.py,sha256=UeSkeAM3Qmx8bbAqHCj5nP-PtZM00_5SGA8ibo9mptc,37079
58
59
  braindecode/models/sinc_shallow.py,sha256=Ilv8K1XhMGiRTBtQdq7L595i6cEFYOBe0_UDv-LqL7s,11907
59
- braindecode/models/sleep_stager_blanco_2020.py,sha256=qPKMDLuv4J7et4dZHyTe-j0oB6ESYn9mA_aW7RMC-rU,6002
60
- braindecode/models/sleep_stager_chambon_2018.py,sha256=62x2Rdjd5UZDX8YlnfAtdRCrjLsPvPpnUweGElZLdkw,5213
61
- braindecode/models/sleep_stager_eldele_2021.py,sha256=-4ISuznykDy9ZFzUM-OeiGCwmgM3U-LuyoDSrhPbRDw,17555
62
- braindecode/models/sparcnet.py,sha256=eZMoJOxlcIyHPdQiX7KXUKuUBlAWkTwsXNWmNma_KAI,13941
63
- braindecode/models/summary.csv,sha256=l7HYYwv3Z69JRPVIhVq-wr_nC1J1KIz6IGw_zeRSk58,6110
60
+ braindecode/models/sleep_stager_blanco_2020.py,sha256=vXulnDYutEFLM0UPXyAI0YIj5QImUMVEmYZb78j34H8,6034
61
+ braindecode/models/sleep_stager_chambon_2018.py,sha256=8w8IR2PsfG0jSc3o0YVopgHpOvCHNIuMi7-QRJOYEW4,5245
62
+ braindecode/models/sparcnet.py,sha256=MG1OB91guI7ssKRk8GvWlzUvaxo_otaYnbEGzNUZVyg,13973
63
+ braindecode/models/summary.csv,sha256=NfrmnjyfDmWVe2zyNqgczEQcLI910BOS4sICtcKS3gc,6765
64
64
  braindecode/models/syncnet.py,sha256=nrWJC5ijCSWKVZyRn-dmOuc1t5vk2C6tx8U3U4j5d5Y,8362
65
65
  braindecode/models/tcn.py,sha256=SQu56H9zdbcbbDIXZVgZtJg7es8CRAJ7z-IBnmf4UWM,8158
66
- braindecode/models/tidnet.py,sha256=k7Q0yAnEBmq1sqhsvoV4-g8wfYSUQ-C3iYxfLp5m8xQ,11805
67
- braindecode/models/tsinception.py,sha256=EcfLDDJXZloh_vrKRuxAHYRZ1EVWlEKHNXqybTRrTbQ,10116
68
- braindecode/models/usleep.py,sha256=dFh3KiZITu13gMxcbPGoK4hq2ySDWzVSCQXkj1006w0,11605
69
- braindecode/models/util.py,sha256=VrhwG1YBGwKohCej6TmhrNAIoleQHRu3YdiBPuHFY_E,5302
66
+ braindecode/models/tidnet.py,sha256=HSUL1al6gaRbJ-BRYAAs4KDvLuKEvh0NnBfAsPeWMpM,11837
67
+ braindecode/models/tsinception.py,sha256=nnQxzpqRy9FPuN5xgh9fNQ386VbreQ_nZBSFNkSfal0,10356
68
+ braindecode/models/usleep.py,sha256=5uztUHX70T_LurqRob_XmVnKkZDwt74x2Iz181M7s54,17233
69
+ braindecode/models/util.py,sha256=VZGVPhUSsoP47pta0_UhC2-g5n5-EFZAW93ZVccrEHU,5232
70
70
  braindecode/modules/__init__.py,sha256=PD2LpeSHWW_MgEef7-G8ief5gheGObzsIoacchxWuyA,1756
71
71
  braindecode/modules/activation.py,sha256=lTO2IjZWBDeXZ4ZVDgLmTDmxHdqyAny3Fsy07HY9tmQ,1466
72
72
  braindecode/modules/attention.py,sha256=ISE11jXAvMqKpawZilg8i7lDX5mkuvpEplrh_CtGEkk,24102
@@ -93,9 +93,9 @@ braindecode/training/scoring.py,sha256=WRkwqbitA3m_dzRnGp2ZIZPge5Nhx9gAEQhIHzeH4
93
93
  braindecode/visualization/__init__.py,sha256=4EER_xHqZIDzEvmgUEm7K1bgNKpyZAIClR9ZCkMuY4M,240
94
94
  braindecode/visualization/confusion_matrices.py,sha256=qIWMLEHow5CJ7PhGggD8mnD55Le6xhma9HSzt4R33fc,9509
95
95
  braindecode/visualization/gradients.py,sha256=KZo-GA0uwiwty2_94j2IjmCR2SKcfPb1Bi3sQq7vpTk,2170
96
- braindecode-1.2.0.dev168908090.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
97
- braindecode-1.2.0.dev168908090.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
98
- braindecode-1.2.0.dev168908090.dist-info/METADATA,sha256=k8t6kilsNd1r7kcjt85QWco9RF9s4XVwFDXSv3wqv8M,6883
99
- braindecode-1.2.0.dev168908090.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
100
- braindecode-1.2.0.dev168908090.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
101
- braindecode-1.2.0.dev168908090.dist-info/RECORD,,
96
+ braindecode-1.2.0.dev175267687.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
97
+ braindecode-1.2.0.dev175267687.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
98
+ braindecode-1.2.0.dev175267687.dist-info/METADATA,sha256=oQss190zuseCAh-RsSLYE_Ak5oLrSIQzh0G2PrdZwNA,7129
99
+ braindecode-1.2.0.dev175267687.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
100
+ braindecode-1.2.0.dev175267687.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
101
+ braindecode-1.2.0.dev175267687.dist-info/RECORD,,
@@ -1,362 +0,0 @@
1
- # Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
2
- # Tonio Ball
3
- #
4
- # License: BSD-3
5
-
6
- import numpy as np
7
- import torch
8
- from einops.layers.torch import Rearrange
9
- from torch import nn
10
- from torch.nn import init
11
-
12
- from braindecode.models.base import EEGModuleMixin
13
- from braindecode.modules import (
14
- AvgPool2dWithConv,
15
- Ensure4d,
16
- SqueezeFinalOutput,
17
- )
18
-
19
-
20
- class EEGResNet(EEGModuleMixin, nn.Sequential):
21
- """EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.
22
-
23
- .. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
24
- :align: center
25
- :alt: EEGResNet Architecture
26
-
27
- Model described in [Schirrmeister2017]_.
28
-
29
- Parameters
30
- ----------
31
- in_chans :
32
- Alias for ``n_chans``.
33
- n_classes :
34
- Alias for ``n_outputs``.
35
- input_window_samples :
36
- Alias for ``n_times``.
37
- activation: nn.Module, default=nn.ELU
38
- Activation function class to apply. Should be a PyTorch activation
39
- module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
40
-
41
- References
42
- ----------
43
- .. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
44
- L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
45
- & Ball, T. (2017). Deep learning with convolutional neural networks for ,
46
- EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
47
- Online: http://dx.doi.org/10.1002/hbm.23730
48
- """
49
-
50
- def __init__(
51
- self,
52
- n_chans=None,
53
- n_outputs=None,
54
- n_times=None,
55
- final_pool_length="auto",
56
- n_first_filters=20,
57
- n_layers_per_block=2,
58
- first_filter_length=3,
59
- activation=nn.ELU,
60
- split_first_layer=True,
61
- batch_norm_alpha=0.1,
62
- batch_norm_epsilon=1e-4,
63
- conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
64
- chs_info=None,
65
- input_window_seconds=None,
66
- sfreq=250,
67
- ):
68
- super().__init__(
69
- n_outputs=n_outputs,
70
- n_chans=n_chans,
71
- chs_info=chs_info,
72
- n_times=n_times,
73
- input_window_seconds=input_window_seconds,
74
- sfreq=sfreq,
75
- )
76
- del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
77
-
78
- if final_pool_length == "auto":
79
- assert self.n_times is not None
80
- assert first_filter_length % 2 == 1
81
- self.final_pool_length = final_pool_length
82
- self.n_first_filters = n_first_filters
83
- self.n_layers_per_block = n_layers_per_block
84
- self.first_filter_length = first_filter_length
85
- self.nonlinearity = activation
86
- self.split_first_layer = split_first_layer
87
- self.batch_norm_alpha = batch_norm_alpha
88
- self.batch_norm_epsilon = batch_norm_epsilon
89
- self.conv_weight_init_fn = conv_weight_init_fn
90
-
91
- self.mapping = {
92
- "conv_classifier.weight": "final_layer.conv_classifier.weight",
93
- "conv_classifier.bias": "final_layer.conv_classifier.bias",
94
- }
95
-
96
- self.add_module("ensuredims", Ensure4d())
97
- if self.split_first_layer:
98
- self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
99
- self.add_module(
100
- "conv_time",
101
- nn.Conv2d(
102
- 1,
103
- self.n_first_filters,
104
- (self.first_filter_length, 1),
105
- stride=1,
106
- padding=(self.first_filter_length // 2, 0),
107
- ),
108
- )
109
- self.add_module(
110
- "conv_spat",
111
- nn.Conv2d(
112
- self.n_first_filters,
113
- self.n_first_filters,
114
- (1, self.n_chans),
115
- stride=(1, 1),
116
- bias=False,
117
- ),
118
- )
119
- else:
120
- self.add_module(
121
- "conv_time",
122
- nn.Conv2d(
123
- self.n_chans,
124
- self.n_first_filters,
125
- (self.first_filter_length, 1),
126
- stride=(1, 1),
127
- padding=(self.first_filter_length // 2, 0),
128
- bias=False,
129
- ),
130
- )
131
- n_filters_conv = self.n_first_filters
132
- self.add_module(
133
- "bnorm",
134
- nn.BatchNorm2d(
135
- n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
136
- ),
137
- )
138
- self.add_module("conv_nonlin", self.nonlinearity())
139
- cur_dilation = np.array([1, 1])
140
- n_cur_filters = n_filters_conv
141
- i_block = 1
142
- for i_layer in range(self.n_layers_per_block):
143
- self.add_module(
144
- "res_{:d}_{:d}".format(i_block, i_layer),
145
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
146
- )
147
- i_block += 1
148
- cur_dilation[0] *= 2
149
- n_out_filters = int(2 * n_cur_filters)
150
- self.add_module(
151
- "res_{:d}_{:d}".format(i_block, 0),
152
- _ResidualBlock(
153
- n_cur_filters,
154
- n_out_filters,
155
- dilation=cur_dilation,
156
- ),
157
- )
158
- n_cur_filters = n_out_filters
159
- for i_layer in range(1, self.n_layers_per_block):
160
- self.add_module(
161
- "res_{:d}_{:d}".format(i_block, i_layer),
162
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
163
- )
164
-
165
- i_block += 1
166
- cur_dilation[0] *= 2
167
- n_out_filters = int(1.5 * n_cur_filters)
168
- self.add_module(
169
- "res_{:d}_{:d}".format(i_block, 0),
170
- _ResidualBlock(
171
- n_cur_filters,
172
- n_out_filters,
173
- dilation=cur_dilation,
174
- ),
175
- )
176
- n_cur_filters = n_out_filters
177
- for i_layer in range(1, self.n_layers_per_block):
178
- self.add_module(
179
- "res_{:d}_{:d}".format(i_block, i_layer),
180
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
181
- )
182
-
183
- i_block += 1
184
- cur_dilation[0] *= 2
185
- self.add_module(
186
- "res_{:d}_{:d}".format(i_block, 0),
187
- _ResidualBlock(
188
- n_cur_filters,
189
- n_cur_filters,
190
- dilation=cur_dilation,
191
- ),
192
- )
193
- for i_layer in range(1, self.n_layers_per_block):
194
- self.add_module(
195
- "res_{:d}_{:d}".format(i_block, i_layer),
196
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
197
- )
198
-
199
- i_block += 1
200
- cur_dilation[0] *= 2
201
- self.add_module(
202
- "res_{:d}_{:d}".format(i_block, 0),
203
- _ResidualBlock(
204
- n_cur_filters,
205
- n_cur_filters,
206
- dilation=cur_dilation,
207
- ),
208
- )
209
- for i_layer in range(1, self.n_layers_per_block):
210
- self.add_module(
211
- "res_{:d}_{:d}".format(i_block, i_layer),
212
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
213
- )
214
-
215
- i_block += 1
216
- cur_dilation[0] *= 2
217
- self.add_module(
218
- "res_{:d}_{:d}".format(i_block, 0),
219
- _ResidualBlock(
220
- n_cur_filters,
221
- n_cur_filters,
222
- dilation=cur_dilation,
223
- ),
224
- )
225
- for i_layer in range(1, self.n_layers_per_block):
226
- self.add_module(
227
- "res_{:d}_{:d}".format(i_block, i_layer),
228
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
229
- )
230
- i_block += 1
231
- cur_dilation[0] *= 2
232
- self.add_module(
233
- "res_{:d}_{:d}".format(i_block, 0),
234
- _ResidualBlock(
235
- n_cur_filters,
236
- n_cur_filters,
237
- dilation=cur_dilation,
238
- ),
239
- )
240
- for i_layer in range(1, self.n_layers_per_block):
241
- self.add_module(
242
- "res_{:d}_{:d}".format(i_block, i_layer),
243
- _ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
244
- )
245
-
246
- self.eval()
247
- if self.final_pool_length == "auto":
248
- self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
249
- else:
250
- pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
251
- self.add_module(
252
- "mean_pool",
253
- AvgPool2dWithConv(
254
- (self.final_pool_length, 1), (1, 1), dilation=pool_dilation
255
- ),
256
- )
257
-
258
- # Incorporating classification module and subsequent ones in one final layer
259
- module = nn.Sequential()
260
-
261
- module.add_module(
262
- "conv_classifier",
263
- nn.Conv2d(
264
- n_cur_filters,
265
- self.n_outputs,
266
- (1, 1),
267
- bias=True,
268
- ),
269
- )
270
-
271
- module.add_module("squeeze", SqueezeFinalOutput())
272
-
273
- self.add_module("final_layer", module)
274
-
275
- # Initialize all weights
276
- self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))
277
-
278
- # Start in train mode
279
- self.train()
280
-
281
- @staticmethod
282
- def _weights_init(module, conv_weight_init_fn):
283
- """
284
- initialize weights
285
- """
286
- classname = module.__class__.__name__
287
- if "Conv" in classname and classname != "AvgPool2dWithConv":
288
- conv_weight_init_fn(module.weight)
289
- if module.bias is not None:
290
- init.constant_(module.bias, 0)
291
- elif "BatchNorm" in classname:
292
- init.constant_(module.weight, 1)
293
- init.constant_(module.bias, 0)
294
-
295
-
296
- class _ResidualBlock(nn.Module):
297
- """
298
- create a residual learning building block with two stacked 3x3 convlayers as in paper
299
- """
300
-
301
- def __init__(
302
- self,
303
- in_filters,
304
- out_num_filters,
305
- dilation,
306
- filter_time_length=3,
307
- nonlinearity: nn.Module = nn.ELU,
308
- batch_norm_alpha=0.1,
309
- batch_norm_epsilon=1e-4,
310
- ):
311
- super(_ResidualBlock, self).__init__()
312
- time_padding = int((filter_time_length - 1) * dilation[0])
313
- assert time_padding % 2 == 0
314
- time_padding = int(time_padding // 2)
315
- dilation = (int(dilation[0]), int(dilation[1]))
316
- assert (out_num_filters - in_filters) % 2 == 0, (
317
- "Need even number of extra channels in order to be able to pad correctly"
318
- )
319
- self.n_pad_chans = out_num_filters - in_filters
320
-
321
- self.conv_1 = nn.Conv2d(
322
- in_filters,
323
- out_num_filters,
324
- (filter_time_length, 1),
325
- stride=(1, 1),
326
- dilation=dilation,
327
- padding=(time_padding, 0),
328
- )
329
- self.bn1 = nn.BatchNorm2d(
330
- out_num_filters,
331
- momentum=batch_norm_alpha,
332
- affine=True,
333
- eps=batch_norm_epsilon,
334
- )
335
- self.conv_2 = nn.Conv2d(
336
- out_num_filters,
337
- out_num_filters,
338
- (filter_time_length, 1),
339
- stride=(1, 1),
340
- dilation=dilation,
341
- padding=(time_padding, 0),
342
- )
343
- self.bn2 = nn.BatchNorm2d(
344
- out_num_filters,
345
- momentum=batch_norm_alpha,
346
- affine=True,
347
- eps=batch_norm_epsilon,
348
- )
349
- # also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
350
- # for resnet options as ilya used them
351
- self.nonlinearity = nonlinearity()
352
-
353
- def forward(self, x):
354
- stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
355
- stack_2 = self.bn2(self.conv_2(stack_1)) # next nonlin after sum
356
- if self.n_pad_chans != 0:
357
- zeros_for_padding = x.new_zeros(
358
- (x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
359
- )
360
- x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
361
- out = self.nonlinearity(x + stack_2)
362
- return out