braindecode 1.2.0.dev168908090__py3-none-any.whl → 1.2.0.dev175267687__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of braindecode might be problematic. Click here for more details.
- braindecode/datasets/experimental.py +218 -0
- braindecode/models/__init__.py +6 -8
- braindecode/models/atcnet.py +1 -1
- braindecode/models/attentionbasenet.py +151 -26
- braindecode/models/{sleep_stager_eldele_2021.py → attn_sleep.py} +12 -2
- braindecode/models/ctnet.py +1 -1
- braindecode/models/deep4.py +6 -2
- braindecode/models/deepsleepnet.py +118 -5
- braindecode/models/eegconformer.py +15 -12
- braindecode/models/eeginception_erp.py +76 -7
- braindecode/models/eeginception_mi.py +2 -0
- braindecode/models/eegnet.py +25 -189
- braindecode/models/eegnex.py +113 -6
- braindecode/models/eegsimpleconv.py +2 -0
- braindecode/models/eegtcnet.py +1 -1
- braindecode/models/sccnet.py +79 -8
- braindecode/models/shallow_fbcsp.py +2 -0
- braindecode/models/sleep_stager_blanco_2020.py +2 -0
- braindecode/models/sleep_stager_chambon_2018.py +2 -0
- braindecode/models/sparcnet.py +2 -0
- braindecode/models/summary.csv +39 -41
- braindecode/models/tidnet.py +2 -0
- braindecode/models/tsinception.py +15 -3
- braindecode/models/usleep.py +103 -9
- braindecode/models/util.py +5 -5
- braindecode/version.py +1 -1
- {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/METADATA +7 -2
- {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/RECORD +32 -32
- braindecode/models/eegresnet.py +0 -362
- {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/WHEEL +0 -0
- {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/licenses/LICENSE.txt +0 -0
- {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/licenses/NOTICE.txt +0 -0
- {braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/top_level.txt +0 -0
{braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/METADATA
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.2.0.
|
|
3
|
+
Version: 1.2.0.dev175267687
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
@@ -17,7 +17,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
17
17
|
Classifier: Programming Language :: Python :: 3.10
|
|
18
18
|
Classifier: Programming Language :: Python :: 3.11
|
|
19
19
|
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
-
Requires-Python:
|
|
20
|
+
Requires-Python: >=3.10
|
|
21
21
|
Description-Content-Type: text/x-rst
|
|
22
22
|
License-File: LICENSE.txt
|
|
23
23
|
License-File: NOTICE.txt
|
|
@@ -49,6 +49,10 @@ Requires-Dist: mypy; extra == "tests"
|
|
|
49
49
|
Provides-Extra: docs
|
|
50
50
|
Requires-Dist: sphinx_gallery; extra == "docs"
|
|
51
51
|
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
|
52
|
+
Requires-Dist: sphinx-autodoc-typehints; extra == "docs"
|
|
53
|
+
Requires-Dist: sphinx-autobuild; extra == "docs"
|
|
54
|
+
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
|
|
55
|
+
Requires-Dist: sphinx_sitemap; extra == "docs"
|
|
52
56
|
Requires-Dist: pydata_sphinx_theme; extra == "docs"
|
|
53
57
|
Requires-Dist: numpydoc; extra == "docs"
|
|
54
58
|
Requires-Dist: memory_profiler; extra == "docs"
|
|
@@ -59,6 +63,7 @@ Requires-Dist: lightning; extra == "docs"
|
|
|
59
63
|
Requires-Dist: seaborn; extra == "docs"
|
|
60
64
|
Requires-Dist: pre-commit; extra == "docs"
|
|
61
65
|
Requires-Dist: openneuro-py; extra == "docs"
|
|
66
|
+
Requires-Dist: plotly; extra == "docs"
|
|
62
67
|
Provides-Extra: all
|
|
63
68
|
Requires-Dist: braindecode[docs,moabb,tests]; extra == "all"
|
|
64
69
|
Dynamic: license-file
|
{braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/RECORD
RENAMED
|
@@ -3,7 +3,7 @@ braindecode/classifier.py,sha256=k9vSCtfQbld0YVleDi5rrrmk6k_k5JYEPPBYcNxYjZ8,980
|
|
|
3
3
|
braindecode/eegneuralnet.py,sha256=dz8k_-2jV7WqkaX4bQG-dmr-vRT7ZtOwJqomXyC9PTw,15287
|
|
4
4
|
braindecode/regressor.py,sha256=VLfrpiXklwI4onkwue3QmzlBWcvspu0tlrLo9RT1Oiw,9375
|
|
5
5
|
braindecode/util.py,sha256=J-tBcDJNlMTIFW2mfOy6Ko0nsgdP4obRoEVDeg2rFH0,12686
|
|
6
|
-
braindecode/version.py,sha256=
|
|
6
|
+
braindecode/version.py,sha256=DjAGynn_F-_dyPfKBBkITwv_h_JYqDSa0IISxiwpVAE,35
|
|
7
7
|
braindecode/augmentation/__init__.py,sha256=LG7ONqCufYAF9NZt8POIp10lYXb8iSueYkF-CWGK2Ls,1001
|
|
8
8
|
braindecode/augmentation/base.py,sha256=gg7wYsVfa9jfqBddtE03B5ZrPHFFmPl2sa3LOrRnGfo,7325
|
|
9
9
|
braindecode/augmentation/functional.py,sha256=ygkMNEFHaUdRQfk7meMML19FnM406Uf34h-ztKXdJwM,37978
|
|
@@ -13,6 +13,7 @@ braindecode/datasets/base.py,sha256=ED8RQWusMyWf0T7b_HXwouR2Ax47qppEc506AlSzBt0,
|
|
|
13
13
|
braindecode/datasets/bbci.py,sha256=BC9o1thEyYBREAo930O7zZz3xZB-l4Odt5j8E_1huXI,19277
|
|
14
14
|
braindecode/datasets/bcicomp.py,sha256=ER_XmqxhpoO2FWELMesQXQ40OTe7BXoy7nYDSiZG9kE,7556
|
|
15
15
|
braindecode/datasets/bids.py,sha256=4asq1HyQHgJjwW7w-GMlvTVQhi-hR2HWLJ8Z__UrUS4,8846
|
|
16
|
+
braindecode/datasets/experimental.py,sha256=Z_uzMNA875-l878LAv7bWiWYJX3QAefmb5quBkcPp7M,8514
|
|
16
17
|
braindecode/datasets/mne.py,sha256=Dg6RZAAwd8TVGrvLOPF5B_JrbyGUWg52vWmn6fLMOQM,6135
|
|
17
18
|
braindecode/datasets/moabb.py,sha256=JmBcFV7QJT8GCgLNNKWgxJVnEVnO5wd9U_uiIqTIxDM,7091
|
|
18
19
|
braindecode/datasets/nmt.py,sha256=E4T8OYBEwWRSjh7VFzmyxaZbf5ufFVEBYYmQEd1ghUU,10430
|
|
@@ -26,25 +27,25 @@ braindecode/datautil/util.py,sha256=ZfDoxLieKsgI8xcWQqebV-vJ5pJYRvRRHkEwhwpgoKU,
|
|
|
26
27
|
braindecode/functional/__init__.py,sha256=JPUDFeKtfogEzfrwPaZRBmxexPjBw7AglYMlImaAnWc,413
|
|
27
28
|
braindecode/functional/functions.py,sha256=CoEweM6YLhigx0tNmmz6yAc8iQ078sTFY2GeCjK5fFs,8622
|
|
28
29
|
braindecode/functional/initialization.py,sha256=BUSC7y2TMsfShpMYBVwm3xg3ODFqWp-STH7yD4sn8zk,1388
|
|
29
|
-
braindecode/models/__init__.py,sha256=
|
|
30
|
-
braindecode/models/atcnet.py,sha256=
|
|
31
|
-
braindecode/models/attentionbasenet.py,sha256=
|
|
30
|
+
braindecode/models/__init__.py,sha256=v2Pn0H-rM_9xr1EEoKIFygmhbS9r52qh8XwFzXuhK70,2455
|
|
31
|
+
braindecode/models/atcnet.py,sha256=jA_18BOaasmiqGbLJOvfBY5q2xHtKdoRFKzN_aqpDoQ,32107
|
|
32
|
+
braindecode/models/attentionbasenet.py,sha256=AK78VvwrZXyJY20zadzDUHl17C-5zcWCd5xPRN7Lr4o,30385
|
|
33
|
+
braindecode/models/attn_sleep.py,sha256=m6sdFfD4en2hHf_TpotLPC1hVweJcYZvjgf12bV5FZg,17822
|
|
32
34
|
braindecode/models/base.py,sha256=9icrWNZBGbh_VLyB9m8g_K1QyK7s3mh8X-hJ29gEbWs,10802
|
|
33
35
|
braindecode/models/biot.py,sha256=T4PymX3penMJcrdfb5Nq6B3P-jyP2laAIu_R9o3uCXo,17512
|
|
34
36
|
braindecode/models/contrawr.py,sha256=eeR_ik4gNZ3rJLM6Mw9gJ2gTMkZ8CU8C4rN_GQMQTAE,10044
|
|
35
|
-
braindecode/models/ctnet.py,sha256
|
|
36
|
-
braindecode/models/deep4.py,sha256
|
|
37
|
-
braindecode/models/deepsleepnet.py,sha256=
|
|
38
|
-
braindecode/models/eegconformer.py,sha256=
|
|
39
|
-
braindecode/models/eeginception_erp.py,sha256=
|
|
40
|
-
braindecode/models/eeginception_mi.py,sha256=
|
|
37
|
+
braindecode/models/ctnet.py,sha256=ce5F31q2weBKvg7PL80iDm7za9fhGaCFvNfHoJW_dtg,17315
|
|
38
|
+
braindecode/models/deep4.py,sha256=-s-R3H7so2xlSiPsU226eSwscv1X9xJMYLm3LhZ3mSU,14645
|
|
39
|
+
braindecode/models/deepsleepnet.py,sha256=wGSAXW73Ga1-HFbn7kXiLeGsJceiqZyMLZnX2UZZXWw,15207
|
|
40
|
+
braindecode/models/eegconformer.py,sha256=rxMAmqErDVLq7nS77CnTtpcC3C2OR_EoZ8-jG-dKP9I,17433
|
|
41
|
+
braindecode/models/eeginception_erp.py,sha256=FYXoM-u4kOodMzGgvKDn7IwJwHl9Z0iiWx9bVHiO9EY,16324
|
|
42
|
+
braindecode/models/eeginception_mi.py,sha256=VoWtsaWj1xQ4FlrvCbnPvo8eosufYUmTrL4uvFtqKcg,12456
|
|
41
43
|
braindecode/models/eegitnet.py,sha256=feXFmPCd-Ejxt7jgWPen1Ag0-oSclDVQai0Atwu9d_A,9827
|
|
42
44
|
braindecode/models/eegminer.py,sha256=ouKZah9Q7_sxT7DJJMcPObwVxNQE87sEljJg6QwiQNw,9847
|
|
43
|
-
braindecode/models/eegnet.py,sha256=
|
|
44
|
-
braindecode/models/eegnex.py,sha256=
|
|
45
|
-
braindecode/models/
|
|
46
|
-
braindecode/models/
|
|
47
|
-
braindecode/models/eegtcnet.py,sha256=np-93Ttctp2uaEYpMrfXfH5bJmCOUZZHLjv8GJEEym4,10830
|
|
45
|
+
braindecode/models/eegnet.py,sha256=dIaHZoz7xMII1qKrS0___IWdy1xg2QrMMiqUgTJM9E8,13682
|
|
46
|
+
braindecode/models/eegnex.py,sha256=eahHolFl15LwNWeC5qjQqUGqURibQZIV425rI1p-dG8,13604
|
|
47
|
+
braindecode/models/eegsimpleconv.py,sha256=6V5ZQNWijmd3-2wv7lJB_HGBS3wHWWVrKoNIeWTXu-w,7300
|
|
48
|
+
braindecode/models/eegtcnet.py,sha256=Y53uJEX_hoB6eHCew9SIfzNxCYea8UhljDARJTk-Tq8,10837
|
|
48
49
|
braindecode/models/fbcnet.py,sha256=RBCLOaiUvivfsT2mq6FN0Kp1-rR3iB0ElzVpHxRl4oI,7486
|
|
49
50
|
braindecode/models/fblightconvnet.py,sha256=d5MwhawhkjilAMo0ckaYMxJhdGMEuorWgHX-TBgwv6s,11041
|
|
50
51
|
braindecode/models/fbmsnet.py,sha256=9bZn2_n1dTrI1Qh3Sz9zMZnH_a-Yq-13UHYSmF6r_UE,11659
|
|
@@ -52,21 +53,20 @@ braindecode/models/hybrid.py,sha256=hA8jwD3_3LL71BxUjRM1dkhqlHU9E9hjuDokh-jBq-4,
|
|
|
52
53
|
braindecode/models/ifnet.py,sha256=Y2bwfko3SDjD74AzgUEzgMhKJFGCCw_Q_Noh5VONEjQ,15137
|
|
53
54
|
braindecode/models/labram.py,sha256=vcrpwiu4F-djtIPscFbtP2Y0jTosyR_cXnOMQQRGPLw,41798
|
|
54
55
|
braindecode/models/msvtnet.py,sha256=hxeCLkHS6w2w89YlLfEPCyQ4XQQpt45bEYPiQJ9SFzY,12642
|
|
55
|
-
braindecode/models/sccnet.py,sha256=
|
|
56
|
-
braindecode/models/shallow_fbcsp.py,sha256
|
|
56
|
+
braindecode/models/sccnet.py,sha256=ragbEzNrua0S84H4JR_j2QGLZWrFKGQ4CfIS2epIYEk,11919
|
|
57
|
+
braindecode/models/shallow_fbcsp.py,sha256=7U07DJBrm2JHV8v5ja-xuE5-IH5tfmryhJtrfO1n4jk,7531
|
|
57
58
|
braindecode/models/signal_jepa.py,sha256=UeSkeAM3Qmx8bbAqHCj5nP-PtZM00_5SGA8ibo9mptc,37079
|
|
58
59
|
braindecode/models/sinc_shallow.py,sha256=Ilv8K1XhMGiRTBtQdq7L595i6cEFYOBe0_UDv-LqL7s,11907
|
|
59
|
-
braindecode/models/sleep_stager_blanco_2020.py,sha256=
|
|
60
|
-
braindecode/models/sleep_stager_chambon_2018.py,sha256=
|
|
61
|
-
braindecode/models/
|
|
62
|
-
braindecode/models/
|
|
63
|
-
braindecode/models/summary.csv,sha256=l7HYYwv3Z69JRPVIhVq-wr_nC1J1KIz6IGw_zeRSk58,6110
|
|
60
|
+
braindecode/models/sleep_stager_blanco_2020.py,sha256=vXulnDYutEFLM0UPXyAI0YIj5QImUMVEmYZb78j34H8,6034
|
|
61
|
+
braindecode/models/sleep_stager_chambon_2018.py,sha256=8w8IR2PsfG0jSc3o0YVopgHpOvCHNIuMi7-QRJOYEW4,5245
|
|
62
|
+
braindecode/models/sparcnet.py,sha256=MG1OB91guI7ssKRk8GvWlzUvaxo_otaYnbEGzNUZVyg,13973
|
|
63
|
+
braindecode/models/summary.csv,sha256=NfrmnjyfDmWVe2zyNqgczEQcLI910BOS4sICtcKS3gc,6765
|
|
64
64
|
braindecode/models/syncnet.py,sha256=nrWJC5ijCSWKVZyRn-dmOuc1t5vk2C6tx8U3U4j5d5Y,8362
|
|
65
65
|
braindecode/models/tcn.py,sha256=SQu56H9zdbcbbDIXZVgZtJg7es8CRAJ7z-IBnmf4UWM,8158
|
|
66
|
-
braindecode/models/tidnet.py,sha256=
|
|
67
|
-
braindecode/models/tsinception.py,sha256=
|
|
68
|
-
braindecode/models/usleep.py,sha256=
|
|
69
|
-
braindecode/models/util.py,sha256=
|
|
66
|
+
braindecode/models/tidnet.py,sha256=HSUL1al6gaRbJ-BRYAAs4KDvLuKEvh0NnBfAsPeWMpM,11837
|
|
67
|
+
braindecode/models/tsinception.py,sha256=nnQxzpqRy9FPuN5xgh9fNQ386VbreQ_nZBSFNkSfal0,10356
|
|
68
|
+
braindecode/models/usleep.py,sha256=5uztUHX70T_LurqRob_XmVnKkZDwt74x2Iz181M7s54,17233
|
|
69
|
+
braindecode/models/util.py,sha256=VZGVPhUSsoP47pta0_UhC2-g5n5-EFZAW93ZVccrEHU,5232
|
|
70
70
|
braindecode/modules/__init__.py,sha256=PD2LpeSHWW_MgEef7-G8ief5gheGObzsIoacchxWuyA,1756
|
|
71
71
|
braindecode/modules/activation.py,sha256=lTO2IjZWBDeXZ4ZVDgLmTDmxHdqyAny3Fsy07HY9tmQ,1466
|
|
72
72
|
braindecode/modules/attention.py,sha256=ISE11jXAvMqKpawZilg8i7lDX5mkuvpEplrh_CtGEkk,24102
|
|
@@ -93,9 +93,9 @@ braindecode/training/scoring.py,sha256=WRkwqbitA3m_dzRnGp2ZIZPge5Nhx9gAEQhIHzeH4
|
|
|
93
93
|
braindecode/visualization/__init__.py,sha256=4EER_xHqZIDzEvmgUEm7K1bgNKpyZAIClR9ZCkMuY4M,240
|
|
94
94
|
braindecode/visualization/confusion_matrices.py,sha256=qIWMLEHow5CJ7PhGggD8mnD55Le6xhma9HSzt4R33fc,9509
|
|
95
95
|
braindecode/visualization/gradients.py,sha256=KZo-GA0uwiwty2_94j2IjmCR2SKcfPb1Bi3sQq7vpTk,2170
|
|
96
|
-
braindecode-1.2.0.
|
|
97
|
-
braindecode-1.2.0.
|
|
98
|
-
braindecode-1.2.0.
|
|
99
|
-
braindecode-1.2.0.
|
|
100
|
-
braindecode-1.2.0.
|
|
101
|
-
braindecode-1.2.0.
|
|
96
|
+
braindecode-1.2.0.dev175267687.dist-info/licenses/LICENSE.txt,sha256=7rg7k6hyj8m9whQ7dpKbqnCssoOEx_Mbtqb4uSOjljE,1525
|
|
97
|
+
braindecode-1.2.0.dev175267687.dist-info/licenses/NOTICE.txt,sha256=sOxuTbalPxTM8H6VqtvGbXCt_BoOF7JevEYG_knqbm4,620
|
|
98
|
+
braindecode-1.2.0.dev175267687.dist-info/METADATA,sha256=oQss190zuseCAh-RsSLYE_Ak5oLrSIQzh0G2PrdZwNA,7129
|
|
99
|
+
braindecode-1.2.0.dev175267687.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
100
|
+
braindecode-1.2.0.dev175267687.dist-info/top_level.txt,sha256=pHsWQmSy0uhIez62-HA9j0iaXKvSbUL39ifFRkFnChA,12
|
|
101
|
+
braindecode-1.2.0.dev175267687.dist-info/RECORD,,
|
braindecode/models/eegresnet.py
DELETED
|
@@ -1,362 +0,0 @@
|
|
|
1
|
-
# Authors: Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
2
|
-
# Tonio Ball
|
|
3
|
-
#
|
|
4
|
-
# License: BSD-3
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import torch
|
|
8
|
-
from einops.layers.torch import Rearrange
|
|
9
|
-
from torch import nn
|
|
10
|
-
from torch.nn import init
|
|
11
|
-
|
|
12
|
-
from braindecode.models.base import EEGModuleMixin
|
|
13
|
-
from braindecode.modules import (
|
|
14
|
-
AvgPool2dWithConv,
|
|
15
|
-
Ensure4d,
|
|
16
|
-
SqueezeFinalOutput,
|
|
17
|
-
)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class EEGResNet(EEGModuleMixin, nn.Sequential):
|
|
21
|
-
"""EEGResNet from Schirrmeister et al. 2017 [Schirrmeister2017]_.
|
|
22
|
-
|
|
23
|
-
.. figure:: https://onlinelibrary.wiley.com/cms/asset/bed1b768-809f-4bc6-b942-b36970d81271/hbm23730-fig-0003-m.jpg
|
|
24
|
-
:align: center
|
|
25
|
-
:alt: EEGResNet Architecture
|
|
26
|
-
|
|
27
|
-
Model described in [Schirrmeister2017]_.
|
|
28
|
-
|
|
29
|
-
Parameters
|
|
30
|
-
----------
|
|
31
|
-
in_chans :
|
|
32
|
-
Alias for ``n_chans``.
|
|
33
|
-
n_classes :
|
|
34
|
-
Alias for ``n_outputs``.
|
|
35
|
-
input_window_samples :
|
|
36
|
-
Alias for ``n_times``.
|
|
37
|
-
activation: nn.Module, default=nn.ELU
|
|
38
|
-
Activation function class to apply. Should be a PyTorch activation
|
|
39
|
-
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
40
|
-
|
|
41
|
-
References
|
|
42
|
-
----------
|
|
43
|
-
.. [Schirrmeister2017] Schirrmeister, R. T., Springenberg, J. T., Fiederer,
|
|
44
|
-
L. D. J., Glasstetter, M., Eggensperger, K., Tangermann, M., Hutter, F.
|
|
45
|
-
& Ball, T. (2017). Deep learning with convolutional neural networks for ,
|
|
46
|
-
EEG decoding and visualization. Human Brain Mapping, Aug. 2017.
|
|
47
|
-
Online: http://dx.doi.org/10.1002/hbm.23730
|
|
48
|
-
"""
|
|
49
|
-
|
|
50
|
-
def __init__(
|
|
51
|
-
self,
|
|
52
|
-
n_chans=None,
|
|
53
|
-
n_outputs=None,
|
|
54
|
-
n_times=None,
|
|
55
|
-
final_pool_length="auto",
|
|
56
|
-
n_first_filters=20,
|
|
57
|
-
n_layers_per_block=2,
|
|
58
|
-
first_filter_length=3,
|
|
59
|
-
activation=nn.ELU,
|
|
60
|
-
split_first_layer=True,
|
|
61
|
-
batch_norm_alpha=0.1,
|
|
62
|
-
batch_norm_epsilon=1e-4,
|
|
63
|
-
conv_weight_init_fn=lambda w: init.kaiming_normal_(w, a=0),
|
|
64
|
-
chs_info=None,
|
|
65
|
-
input_window_seconds=None,
|
|
66
|
-
sfreq=250,
|
|
67
|
-
):
|
|
68
|
-
super().__init__(
|
|
69
|
-
n_outputs=n_outputs,
|
|
70
|
-
n_chans=n_chans,
|
|
71
|
-
chs_info=chs_info,
|
|
72
|
-
n_times=n_times,
|
|
73
|
-
input_window_seconds=input_window_seconds,
|
|
74
|
-
sfreq=sfreq,
|
|
75
|
-
)
|
|
76
|
-
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
77
|
-
|
|
78
|
-
if final_pool_length == "auto":
|
|
79
|
-
assert self.n_times is not None
|
|
80
|
-
assert first_filter_length % 2 == 1
|
|
81
|
-
self.final_pool_length = final_pool_length
|
|
82
|
-
self.n_first_filters = n_first_filters
|
|
83
|
-
self.n_layers_per_block = n_layers_per_block
|
|
84
|
-
self.first_filter_length = first_filter_length
|
|
85
|
-
self.nonlinearity = activation
|
|
86
|
-
self.split_first_layer = split_first_layer
|
|
87
|
-
self.batch_norm_alpha = batch_norm_alpha
|
|
88
|
-
self.batch_norm_epsilon = batch_norm_epsilon
|
|
89
|
-
self.conv_weight_init_fn = conv_weight_init_fn
|
|
90
|
-
|
|
91
|
-
self.mapping = {
|
|
92
|
-
"conv_classifier.weight": "final_layer.conv_classifier.weight",
|
|
93
|
-
"conv_classifier.bias": "final_layer.conv_classifier.bias",
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
self.add_module("ensuredims", Ensure4d())
|
|
97
|
-
if self.split_first_layer:
|
|
98
|
-
self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 T C"))
|
|
99
|
-
self.add_module(
|
|
100
|
-
"conv_time",
|
|
101
|
-
nn.Conv2d(
|
|
102
|
-
1,
|
|
103
|
-
self.n_first_filters,
|
|
104
|
-
(self.first_filter_length, 1),
|
|
105
|
-
stride=1,
|
|
106
|
-
padding=(self.first_filter_length // 2, 0),
|
|
107
|
-
),
|
|
108
|
-
)
|
|
109
|
-
self.add_module(
|
|
110
|
-
"conv_spat",
|
|
111
|
-
nn.Conv2d(
|
|
112
|
-
self.n_first_filters,
|
|
113
|
-
self.n_first_filters,
|
|
114
|
-
(1, self.n_chans),
|
|
115
|
-
stride=(1, 1),
|
|
116
|
-
bias=False,
|
|
117
|
-
),
|
|
118
|
-
)
|
|
119
|
-
else:
|
|
120
|
-
self.add_module(
|
|
121
|
-
"conv_time",
|
|
122
|
-
nn.Conv2d(
|
|
123
|
-
self.n_chans,
|
|
124
|
-
self.n_first_filters,
|
|
125
|
-
(self.first_filter_length, 1),
|
|
126
|
-
stride=(1, 1),
|
|
127
|
-
padding=(self.first_filter_length // 2, 0),
|
|
128
|
-
bias=False,
|
|
129
|
-
),
|
|
130
|
-
)
|
|
131
|
-
n_filters_conv = self.n_first_filters
|
|
132
|
-
self.add_module(
|
|
133
|
-
"bnorm",
|
|
134
|
-
nn.BatchNorm2d(
|
|
135
|
-
n_filters_conv, momentum=self.batch_norm_alpha, affine=True, eps=1e-5
|
|
136
|
-
),
|
|
137
|
-
)
|
|
138
|
-
self.add_module("conv_nonlin", self.nonlinearity())
|
|
139
|
-
cur_dilation = np.array([1, 1])
|
|
140
|
-
n_cur_filters = n_filters_conv
|
|
141
|
-
i_block = 1
|
|
142
|
-
for i_layer in range(self.n_layers_per_block):
|
|
143
|
-
self.add_module(
|
|
144
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
145
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
146
|
-
)
|
|
147
|
-
i_block += 1
|
|
148
|
-
cur_dilation[0] *= 2
|
|
149
|
-
n_out_filters = int(2 * n_cur_filters)
|
|
150
|
-
self.add_module(
|
|
151
|
-
"res_{:d}_{:d}".format(i_block, 0),
|
|
152
|
-
_ResidualBlock(
|
|
153
|
-
n_cur_filters,
|
|
154
|
-
n_out_filters,
|
|
155
|
-
dilation=cur_dilation,
|
|
156
|
-
),
|
|
157
|
-
)
|
|
158
|
-
n_cur_filters = n_out_filters
|
|
159
|
-
for i_layer in range(1, self.n_layers_per_block):
|
|
160
|
-
self.add_module(
|
|
161
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
162
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
i_block += 1
|
|
166
|
-
cur_dilation[0] *= 2
|
|
167
|
-
n_out_filters = int(1.5 * n_cur_filters)
|
|
168
|
-
self.add_module(
|
|
169
|
-
"res_{:d}_{:d}".format(i_block, 0),
|
|
170
|
-
_ResidualBlock(
|
|
171
|
-
n_cur_filters,
|
|
172
|
-
n_out_filters,
|
|
173
|
-
dilation=cur_dilation,
|
|
174
|
-
),
|
|
175
|
-
)
|
|
176
|
-
n_cur_filters = n_out_filters
|
|
177
|
-
for i_layer in range(1, self.n_layers_per_block):
|
|
178
|
-
self.add_module(
|
|
179
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
180
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
181
|
-
)
|
|
182
|
-
|
|
183
|
-
i_block += 1
|
|
184
|
-
cur_dilation[0] *= 2
|
|
185
|
-
self.add_module(
|
|
186
|
-
"res_{:d}_{:d}".format(i_block, 0),
|
|
187
|
-
_ResidualBlock(
|
|
188
|
-
n_cur_filters,
|
|
189
|
-
n_cur_filters,
|
|
190
|
-
dilation=cur_dilation,
|
|
191
|
-
),
|
|
192
|
-
)
|
|
193
|
-
for i_layer in range(1, self.n_layers_per_block):
|
|
194
|
-
self.add_module(
|
|
195
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
196
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
197
|
-
)
|
|
198
|
-
|
|
199
|
-
i_block += 1
|
|
200
|
-
cur_dilation[0] *= 2
|
|
201
|
-
self.add_module(
|
|
202
|
-
"res_{:d}_{:d}".format(i_block, 0),
|
|
203
|
-
_ResidualBlock(
|
|
204
|
-
n_cur_filters,
|
|
205
|
-
n_cur_filters,
|
|
206
|
-
dilation=cur_dilation,
|
|
207
|
-
),
|
|
208
|
-
)
|
|
209
|
-
for i_layer in range(1, self.n_layers_per_block):
|
|
210
|
-
self.add_module(
|
|
211
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
212
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
213
|
-
)
|
|
214
|
-
|
|
215
|
-
i_block += 1
|
|
216
|
-
cur_dilation[0] *= 2
|
|
217
|
-
self.add_module(
|
|
218
|
-
"res_{:d}_{:d}".format(i_block, 0),
|
|
219
|
-
_ResidualBlock(
|
|
220
|
-
n_cur_filters,
|
|
221
|
-
n_cur_filters,
|
|
222
|
-
dilation=cur_dilation,
|
|
223
|
-
),
|
|
224
|
-
)
|
|
225
|
-
for i_layer in range(1, self.n_layers_per_block):
|
|
226
|
-
self.add_module(
|
|
227
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
228
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
229
|
-
)
|
|
230
|
-
i_block += 1
|
|
231
|
-
cur_dilation[0] *= 2
|
|
232
|
-
self.add_module(
|
|
233
|
-
"res_{:d}_{:d}".format(i_block, 0),
|
|
234
|
-
_ResidualBlock(
|
|
235
|
-
n_cur_filters,
|
|
236
|
-
n_cur_filters,
|
|
237
|
-
dilation=cur_dilation,
|
|
238
|
-
),
|
|
239
|
-
)
|
|
240
|
-
for i_layer in range(1, self.n_layers_per_block):
|
|
241
|
-
self.add_module(
|
|
242
|
-
"res_{:d}_{:d}".format(i_block, i_layer),
|
|
243
|
-
_ResidualBlock(n_cur_filters, n_cur_filters, dilation=cur_dilation),
|
|
244
|
-
)
|
|
245
|
-
|
|
246
|
-
self.eval()
|
|
247
|
-
if self.final_pool_length == "auto":
|
|
248
|
-
self.add_module("mean_pool", nn.AdaptiveAvgPool2d((1, 1)))
|
|
249
|
-
else:
|
|
250
|
-
pool_dilation = int(cur_dilation[0]), int(cur_dilation[1])
|
|
251
|
-
self.add_module(
|
|
252
|
-
"mean_pool",
|
|
253
|
-
AvgPool2dWithConv(
|
|
254
|
-
(self.final_pool_length, 1), (1, 1), dilation=pool_dilation
|
|
255
|
-
),
|
|
256
|
-
)
|
|
257
|
-
|
|
258
|
-
# Incorporating classification module and subsequent ones in one final layer
|
|
259
|
-
module = nn.Sequential()
|
|
260
|
-
|
|
261
|
-
module.add_module(
|
|
262
|
-
"conv_classifier",
|
|
263
|
-
nn.Conv2d(
|
|
264
|
-
n_cur_filters,
|
|
265
|
-
self.n_outputs,
|
|
266
|
-
(1, 1),
|
|
267
|
-
bias=True,
|
|
268
|
-
),
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
module.add_module("squeeze", SqueezeFinalOutput())
|
|
272
|
-
|
|
273
|
-
self.add_module("final_layer", module)
|
|
274
|
-
|
|
275
|
-
# Initialize all weights
|
|
276
|
-
self.apply(lambda module: self._weights_init(module, self.conv_weight_init_fn))
|
|
277
|
-
|
|
278
|
-
# Start in train mode
|
|
279
|
-
self.train()
|
|
280
|
-
|
|
281
|
-
@staticmethod
|
|
282
|
-
def _weights_init(module, conv_weight_init_fn):
|
|
283
|
-
"""
|
|
284
|
-
initialize weights
|
|
285
|
-
"""
|
|
286
|
-
classname = module.__class__.__name__
|
|
287
|
-
if "Conv" in classname and classname != "AvgPool2dWithConv":
|
|
288
|
-
conv_weight_init_fn(module.weight)
|
|
289
|
-
if module.bias is not None:
|
|
290
|
-
init.constant_(module.bias, 0)
|
|
291
|
-
elif "BatchNorm" in classname:
|
|
292
|
-
init.constant_(module.weight, 1)
|
|
293
|
-
init.constant_(module.bias, 0)
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
class _ResidualBlock(nn.Module):
|
|
297
|
-
"""
|
|
298
|
-
create a residual learning building block with two stacked 3x3 convlayers as in paper
|
|
299
|
-
"""
|
|
300
|
-
|
|
301
|
-
def __init__(
|
|
302
|
-
self,
|
|
303
|
-
in_filters,
|
|
304
|
-
out_num_filters,
|
|
305
|
-
dilation,
|
|
306
|
-
filter_time_length=3,
|
|
307
|
-
nonlinearity: nn.Module = nn.ELU,
|
|
308
|
-
batch_norm_alpha=0.1,
|
|
309
|
-
batch_norm_epsilon=1e-4,
|
|
310
|
-
):
|
|
311
|
-
super(_ResidualBlock, self).__init__()
|
|
312
|
-
time_padding = int((filter_time_length - 1) * dilation[0])
|
|
313
|
-
assert time_padding % 2 == 0
|
|
314
|
-
time_padding = int(time_padding // 2)
|
|
315
|
-
dilation = (int(dilation[0]), int(dilation[1]))
|
|
316
|
-
assert (out_num_filters - in_filters) % 2 == 0, (
|
|
317
|
-
"Need even number of extra channels in order to be able to pad correctly"
|
|
318
|
-
)
|
|
319
|
-
self.n_pad_chans = out_num_filters - in_filters
|
|
320
|
-
|
|
321
|
-
self.conv_1 = nn.Conv2d(
|
|
322
|
-
in_filters,
|
|
323
|
-
out_num_filters,
|
|
324
|
-
(filter_time_length, 1),
|
|
325
|
-
stride=(1, 1),
|
|
326
|
-
dilation=dilation,
|
|
327
|
-
padding=(time_padding, 0),
|
|
328
|
-
)
|
|
329
|
-
self.bn1 = nn.BatchNorm2d(
|
|
330
|
-
out_num_filters,
|
|
331
|
-
momentum=batch_norm_alpha,
|
|
332
|
-
affine=True,
|
|
333
|
-
eps=batch_norm_epsilon,
|
|
334
|
-
)
|
|
335
|
-
self.conv_2 = nn.Conv2d(
|
|
336
|
-
out_num_filters,
|
|
337
|
-
out_num_filters,
|
|
338
|
-
(filter_time_length, 1),
|
|
339
|
-
stride=(1, 1),
|
|
340
|
-
dilation=dilation,
|
|
341
|
-
padding=(time_padding, 0),
|
|
342
|
-
)
|
|
343
|
-
self.bn2 = nn.BatchNorm2d(
|
|
344
|
-
out_num_filters,
|
|
345
|
-
momentum=batch_norm_alpha,
|
|
346
|
-
affine=True,
|
|
347
|
-
eps=batch_norm_epsilon,
|
|
348
|
-
)
|
|
349
|
-
# also see https://mail.google.com/mail/u/0/#search/ilya+joos/1576137dd34c3127
|
|
350
|
-
# for resnet options as ilya used them
|
|
351
|
-
self.nonlinearity = nonlinearity()
|
|
352
|
-
|
|
353
|
-
def forward(self, x):
|
|
354
|
-
stack_1 = self.nonlinearity(self.bn1(self.conv_1(x)))
|
|
355
|
-
stack_2 = self.bn2(self.conv_2(stack_1)) # next nonlin after sum
|
|
356
|
-
if self.n_pad_chans != 0:
|
|
357
|
-
zeros_for_padding = x.new_zeros(
|
|
358
|
-
(x.shape[0], self.n_pad_chans // 2, x.shape[2], x.shape[3])
|
|
359
|
-
)
|
|
360
|
-
x = torch.cat((zeros_for_padding, x, zeros_for_padding), dim=1)
|
|
361
|
-
out = self.nonlinearity(x + stack_2)
|
|
362
|
-
return out
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{braindecode-1.2.0.dev168908090.dist-info → braindecode-1.2.0.dev175267687.dist-info}/top_level.txt
RENAMED
|
File without changes
|