braindecode 1.3.0.dev175435903__tar.gz → 1.3.0.dev177628147__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {braindecode-1.3.0.dev175435903/braindecode.egg-info → braindecode-1.3.0.dev177628147}/PKG-INFO +1 -5
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/serialization.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/atcnet.py +11 -11
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/attentionbasenet.py +4 -4
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/attn_sleep.py +7 -9
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/bendr.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/biot.py +23 -25
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/contrawr.py +2 -2
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/ctnet.py +33 -33
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/deep4.py +4 -4
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/deepsleepnet.py +4 -4
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegconformer.py +22 -27
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eeginception_erp.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eeginception_mi.py +3 -3
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegitnet.py +3 -3
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegnex.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegsimpleconv.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegtcnet.py +3 -3
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/fbcnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/fblightconvnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/fbmsnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/hybrid.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/ifnet.py +2 -2
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/labram.py +39 -39
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/medformer.py +14 -14
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/msvtnet.py +9 -9
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/patchedtransformer.py +46 -46
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sccnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/shallow_fbcsp.py +2 -4
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sinc_shallow.py +2 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sleep_stager_blanco_2020.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sleep_stager_chambon_2018.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sparcnet.py +4 -4
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/sstdpn.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/summary.csv +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/syncnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/tcn.py +3 -3
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/tidnet.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/tsinception.py +1 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/usleep.py +3 -3
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/util.py +14 -154
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/blocks.py +1 -3
- braindecode-1.3.0.dev177628147/braindecode/version.py +1 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147/braindecode.egg-info}/PKG-INFO +1 -5
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/SOURCES.txt +0 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/requires.txt +0 -5
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/api.rst +0 -4
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/whats_new.rst +0 -5
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/pyproject.toml +0 -5
- braindecode-1.3.0.dev175435903/braindecode/models/config.py +0 -230
- braindecode-1.3.0.dev175435903/braindecode/version.py +0 -1
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/LICENSE.txt +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/MANIFEST.in +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/NOTICE.txt +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/README.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/base.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/functional.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/augmentation/transforms.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/classifier.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/base.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bbci.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bcicomp.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/bids.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/experimental.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/hub.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/hub_validation.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/mne.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/moabb.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/nmt.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/registry.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/sleep_physionet.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/tuh.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datasets/xy.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/channel_utils.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/hub_formats.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/datautil/util.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/eegneuralnet.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/functional/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/functional/functions.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/functional/initialization.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/base.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegminer.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/eegsym.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/luna.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/signal_jepa.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/activation.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/attention.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/convolution.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/filter.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/layers.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/linear.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/parametrization.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/stats.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/util.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/modules/wrapper.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/mne_preprocess.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/preprocess.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/util.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/preprocessing/windowers.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/regressor.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/samplers/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/samplers/base.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/samplers/ssl.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/callbacks.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/losses.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/training/scoring.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/util.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/visualization/__init__.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/visualization/confusion_matrices.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/visualization/gradients.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/dependency_links.txt +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode.egg-info/top_level.txt +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/Makefile +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/_templates/autosummary/class.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/_templates/autosummary/function.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/cite.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/conf.py +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/help.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/index.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/install/install.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/install/install_pip.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/install/install_source.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/attention.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/channel.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/convolution.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/filterbank.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/gnn.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/interpretable.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/lbm.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/recurrent.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/categorization/spd.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models_categorization.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models_table.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/models/models_visualization.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/docs/sg_execution_times.rst +0 -0
- {braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/setup.cfg +0 -0
{braindecode-1.3.0.dev175435903/braindecode.egg-info → braindecode-1.3.0.dev177628147}/PKG-INFO
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.3.0.
|
|
3
|
+
Version: 1.3.0.dev177628147
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
@@ -52,9 +52,6 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
|
52
52
|
Requires-Dist: codecov; extra == "tests"
|
|
53
53
|
Requires-Dist: pytest_cases; extra == "tests"
|
|
54
54
|
Requires-Dist: mypy; extra == "tests"
|
|
55
|
-
Provides-Extra: typing
|
|
56
|
-
Requires-Dist: pydantic<3.0,>=2.0; extra == "typing"
|
|
57
|
-
Requires-Dist: numpydantic>=1.7; extra == "typing"
|
|
58
55
|
Provides-Extra: docs
|
|
59
56
|
Requires-Dist: sphinx_gallery; extra == "docs"
|
|
60
57
|
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
|
@@ -79,7 +76,6 @@ Requires-Dist: braindecode[hub]; extra == "all"
|
|
|
79
76
|
Requires-Dist: braindecode[tests]; extra == "all"
|
|
80
77
|
Requires-Dist: braindecode[docs]; extra == "all"
|
|
81
78
|
Requires-Dist: braindecode[eegprep]; extra == "all"
|
|
82
|
-
Requires-Dist: braindecode[typing]; extra == "all"
|
|
83
79
|
Dynamic: license-file
|
|
84
80
|
|
|
85
81
|
.. image:: https://badges.gitter.im/braindecodechat/community.svg
|
|
@@ -138,7 +138,7 @@ def _load_signals(fif_file, preload, is_raw):
|
|
|
138
138
|
with open(pkl_file, "rb") as f:
|
|
139
139
|
signals = pickle.load(f)
|
|
140
140
|
|
|
141
|
-
if all(
|
|
141
|
+
if all(f.exists() for f in signals.filenames):
|
|
142
142
|
if preload:
|
|
143
143
|
signals.load_data()
|
|
144
144
|
return signals
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/atcnet.py
RENAMED
|
@@ -141,7 +141,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
|
|
|
141
141
|
- Pool sizes ``P1,P2`` trade temporal resolution for stability/compute; they set
|
|
142
142
|
``T_c = T/(P1·P2)`` and thus window width ``T_w``.
|
|
143
143
|
- ``n_windows`` controls the ensemble over shifts (compute ∝ windows).
|
|
144
|
-
- ``
|
|
144
|
+
- ``att_num_heads``, ``att_head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
|
|
145
145
|
- ``tcn_depth``, ``tcn_kernel_size`` govern receptive field; larger values demand
|
|
146
146
|
longer inputs (see minimum length above). The implementation warns and *rescales*
|
|
147
147
|
kernels/pools/windows if inputs are too short.
|
|
@@ -194,10 +194,10 @@ class ATCNet(EEGModuleMixin, nn.Module):
|
|
|
194
194
|
table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
|
|
195
195
|
n_windows : int
|
|
196
196
|
Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
|
|
197
|
-
|
|
197
|
+
att_head_dim : int
|
|
198
198
|
Embedding dimension used in each self-attention head, denoted dh in
|
|
199
199
|
table 1 of the paper [1]_. Defaults to 8 as in [1]_.
|
|
200
|
-
|
|
200
|
+
att_num_heads : int
|
|
201
201
|
Number of attention heads, denoted H in table 1 of the paper [1]_.
|
|
202
202
|
Defaults to 2 as in [1]_.
|
|
203
203
|
att_dropout : float
|
|
@@ -248,13 +248,13 @@ class ATCNet(EEGModuleMixin, nn.Module):
|
|
|
248
248
|
conv_block_depth_mult=2,
|
|
249
249
|
conv_block_dropout=0.3,
|
|
250
250
|
n_windows=5,
|
|
251
|
-
|
|
252
|
-
|
|
251
|
+
att_head_dim=8,
|
|
252
|
+
att_num_heads=2,
|
|
253
253
|
att_drop_prob=0.5,
|
|
254
254
|
tcn_depth=2,
|
|
255
255
|
tcn_kernel_size=4,
|
|
256
256
|
tcn_drop_prob=0.3,
|
|
257
|
-
tcn_activation:
|
|
257
|
+
tcn_activation: nn.Module = nn.ELU,
|
|
258
258
|
concat=False,
|
|
259
259
|
max_norm_const=0.25,
|
|
260
260
|
chs_info=None,
|
|
@@ -316,8 +316,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
|
|
|
316
316
|
self.conv_block_depth_mult = conv_block_depth_mult
|
|
317
317
|
self.conv_block_dropout = conv_block_dropout
|
|
318
318
|
self.n_windows = n_windows
|
|
319
|
-
self.
|
|
320
|
-
self.
|
|
319
|
+
self.att_head_dim = att_head_dim
|
|
320
|
+
self.att_num_heads = att_num_heads
|
|
321
321
|
self.att_dropout = att_drop_prob
|
|
322
322
|
self.tcn_depth = tcn_depth
|
|
323
323
|
self.tcn_kernel_size = tcn_kernel_size
|
|
@@ -356,8 +356,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
|
|
|
356
356
|
[
|
|
357
357
|
_AttentionBlock(
|
|
358
358
|
in_shape=self.F2,
|
|
359
|
-
head_dim=self.
|
|
360
|
-
num_heads=
|
|
359
|
+
head_dim=self.att_head_dim,
|
|
360
|
+
num_heads=att_num_heads,
|
|
361
361
|
dropout=att_drop_prob,
|
|
362
362
|
)
|
|
363
363
|
for _ in range(self.n_windows)
|
|
@@ -656,7 +656,7 @@ class _TCNResidualBlock(nn.Module):
|
|
|
656
656
|
kernel_size=4,
|
|
657
657
|
n_filters=32,
|
|
658
658
|
dropout=0.3,
|
|
659
|
-
activation:
|
|
659
|
+
activation: nn.Module = nn.ELU,
|
|
660
660
|
dilation=1,
|
|
661
661
|
):
|
|
662
662
|
super().__init__()
|
|
@@ -235,7 +235,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
|
|
|
235
235
|
kernel_size : int, default=9
|
|
236
236
|
The kernel size used in certain types of attention mechanisms for convolution
|
|
237
237
|
operations.
|
|
238
|
-
activation:
|
|
238
|
+
activation: nn.Module, default=nn.ELU
|
|
239
239
|
Activation function class to apply. Should be a PyTorch activation
|
|
240
240
|
module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
|
|
241
241
|
extra_params : bool, default=False
|
|
@@ -277,7 +277,7 @@ class AttentionBaseNet(EEGModuleMixin, nn.Module):
|
|
|
277
277
|
freq_idx: int = 0,
|
|
278
278
|
n_codewords: int = 4,
|
|
279
279
|
kernel_size: int = 9,
|
|
280
|
-
activation:
|
|
280
|
+
activation: nn.Module = nn.ELU,
|
|
281
281
|
extra_params: bool = False,
|
|
282
282
|
):
|
|
283
283
|
super(AttentionBaseNet, self).__init__()
|
|
@@ -453,7 +453,7 @@ class _FeatureExtractor(nn.Module):
|
|
|
453
453
|
pool_length: int = 75,
|
|
454
454
|
pool_stride: int = 15,
|
|
455
455
|
drop_prob: float = 0.5,
|
|
456
|
-
activation:
|
|
456
|
+
activation: nn.Module = nn.ELU,
|
|
457
457
|
):
|
|
458
458
|
super().__init__()
|
|
459
459
|
|
|
@@ -592,7 +592,7 @@ class _ChannelAttentionBlock(nn.Module):
|
|
|
592
592
|
n_codewords: int = 4,
|
|
593
593
|
kernel_size: int = 9,
|
|
594
594
|
extra_params: bool = False,
|
|
595
|
-
activation:
|
|
595
|
+
activation: nn.Module = nn.ELU,
|
|
596
596
|
):
|
|
597
597
|
super().__init__()
|
|
598
598
|
self.conv = nn.Sequential(
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/attn_sleep.py
RENAMED
|
@@ -90,8 +90,8 @@ class AttnSleep(EEGModuleMixin, nn.Module):
|
|
|
90
90
|
d_ff=120,
|
|
91
91
|
n_attn_heads=5,
|
|
92
92
|
drop_prob=0.1,
|
|
93
|
-
activation_mrcnn:
|
|
94
|
-
activation:
|
|
93
|
+
activation_mrcnn: nn.Module = nn.GELU,
|
|
94
|
+
activation: nn.Module = nn.ReLU,
|
|
95
95
|
input_window_seconds=None,
|
|
96
96
|
n_outputs=None,
|
|
97
97
|
after_reduced_cnn_size=30,
|
|
@@ -230,7 +230,7 @@ class _SEBasicBlock(nn.Module):
|
|
|
230
230
|
planes,
|
|
231
231
|
stride=1,
|
|
232
232
|
downsample=None,
|
|
233
|
-
activation:
|
|
233
|
+
activation: nn.Module = nn.ReLU,
|
|
234
234
|
*,
|
|
235
235
|
reduction=16,
|
|
236
236
|
):
|
|
@@ -278,8 +278,8 @@ class _MRCNN(nn.Module):
|
|
|
278
278
|
self,
|
|
279
279
|
after_reduced_cnn_size,
|
|
280
280
|
kernel_size=7,
|
|
281
|
-
activation:
|
|
282
|
-
activation_se:
|
|
281
|
+
activation: nn.Module = nn.GELU,
|
|
282
|
+
activation_se: nn.Module = nn.ReLU,
|
|
283
283
|
):
|
|
284
284
|
super(_MRCNN, self).__init__()
|
|
285
285
|
drate = 0.5
|
|
@@ -325,7 +325,7 @@ class _MRCNN(nn.Module):
|
|
|
325
325
|
)
|
|
326
326
|
|
|
327
327
|
def _make_layer(
|
|
328
|
-
self, block, planes, blocks, stride=1, activate:
|
|
328
|
+
self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
|
|
329
329
|
): # makes residual SE block
|
|
330
330
|
downsample = None
|
|
331
331
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
@@ -526,9 +526,7 @@ class _EncoderLayer(nn.Module):
|
|
|
526
526
|
class _PositionwiseFeedForward(nn.Module):
|
|
527
527
|
"""Positionwise feed-forward network."""
|
|
528
528
|
|
|
529
|
-
def __init__(
|
|
530
|
-
self, d_model, d_ff, dropout=0.1, activation: type[nn.Module] = nn.ReLU
|
|
531
|
-
):
|
|
529
|
+
def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
|
|
532
530
|
super().__init__()
|
|
533
531
|
self.w_1 = nn.Linear(d_model, d_ff)
|
|
534
532
|
self.w_2 = nn.Linear(d_ff, d_model)
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/bendr.py
RENAMED
|
@@ -176,7 +176,7 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
176
176
|
projection_head=False, # Whether encoder should project back to input feature size (unused in original fine-tuning)
|
|
177
177
|
drop_prob=0.1, # General dropout probability (paper: 0.15 for pretraining, 0.0 for fine-tuning)
|
|
178
178
|
layer_drop=0.0, # Probability of dropping transformer layers during training (paper: 0.01 for pretraining)
|
|
179
|
-
activation
|
|
179
|
+
activation=nn.GELU, # Activation function
|
|
180
180
|
# Transformer specific parameters
|
|
181
181
|
transformer_layers=8,
|
|
182
182
|
transformer_heads=8,
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/biot.py
RENAMED
|
@@ -45,11 +45,11 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
45
45
|
|
|
46
46
|
Parameters
|
|
47
47
|
----------
|
|
48
|
-
|
|
48
|
+
emb_size : int, optional
|
|
49
49
|
The size of the embedding layer, by default 256
|
|
50
|
-
|
|
50
|
+
att_num_heads : int, optional
|
|
51
51
|
The number of attention heads, by default 8
|
|
52
|
-
|
|
52
|
+
n_layers : int, optional
|
|
53
53
|
The number of transformer layers, by default 4
|
|
54
54
|
activation: nn.Module, default=nn.ELU
|
|
55
55
|
Activation function class to apply. Should be a PyTorch activation
|
|
@@ -76,9 +76,9 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
76
76
|
|
|
77
77
|
def __init__(
|
|
78
78
|
self,
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
79
|
+
emb_size=256,
|
|
80
|
+
att_num_heads=8,
|
|
81
|
+
n_layers=4,
|
|
82
82
|
sfreq=200,
|
|
83
83
|
hop_length=100,
|
|
84
84
|
return_feature=False,
|
|
@@ -87,12 +87,12 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
87
87
|
chs_info=None,
|
|
88
88
|
n_times=None,
|
|
89
89
|
input_window_seconds=None,
|
|
90
|
-
activation:
|
|
90
|
+
activation: nn.Module = nn.ELU,
|
|
91
91
|
drop_prob: float = 0.5,
|
|
92
92
|
# Parameters for the encoder
|
|
93
93
|
max_seq_len: int = 1024,
|
|
94
|
-
|
|
95
|
-
|
|
94
|
+
attn_dropout=0.2,
|
|
95
|
+
attn_layer_dropout=0.2,
|
|
96
96
|
):
|
|
97
97
|
super().__init__(
|
|
98
98
|
n_outputs=n_outputs,
|
|
@@ -103,10 +103,10 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
103
103
|
sfreq=sfreq,
|
|
104
104
|
)
|
|
105
105
|
del n_outputs, n_chans, chs_info, n_times, sfreq
|
|
106
|
-
self.
|
|
106
|
+
self.emb_size = emb_size
|
|
107
107
|
self.hop_length = hop_length
|
|
108
|
-
self.
|
|
109
|
-
self.
|
|
108
|
+
self.att_num_heads = att_num_heads
|
|
109
|
+
self.n_layers = n_layers
|
|
110
110
|
self.return_feature = return_feature
|
|
111
111
|
if (self.sfreq != 200) & (self.sfreq is not None):
|
|
112
112
|
warn(
|
|
@@ -114,7 +114,7 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
114
114
|
+ "no guarantee to generalize well with the default parameters",
|
|
115
115
|
UserWarning,
|
|
116
116
|
)
|
|
117
|
-
if self.n_chans >
|
|
117
|
+
if self.n_chans > emb_size:
|
|
118
118
|
warn(
|
|
119
119
|
"The number of channels is larger than the embedding size. "
|
|
120
120
|
+ "This may cause overfitting. Consider using a larger "
|
|
@@ -142,20 +142,20 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
142
142
|
self.n_fft = int(self.sfreq)
|
|
143
143
|
|
|
144
144
|
self.encoder = _BIOTEncoder(
|
|
145
|
-
emb_size=
|
|
146
|
-
|
|
147
|
-
n_layers=
|
|
145
|
+
emb_size=emb_size,
|
|
146
|
+
att_num_heads=att_num_heads,
|
|
147
|
+
n_layers=n_layers,
|
|
148
148
|
n_chans=self.n_chans,
|
|
149
149
|
n_fft=self.n_fft,
|
|
150
150
|
hop_length=hop_length,
|
|
151
151
|
drop_prob=drop_prob,
|
|
152
152
|
max_seq_len=max_seq_len,
|
|
153
|
-
attn_dropout=
|
|
154
|
-
attn_layer_dropout=
|
|
153
|
+
attn_dropout=attn_dropout,
|
|
154
|
+
attn_layer_dropout=attn_layer_dropout,
|
|
155
155
|
)
|
|
156
156
|
|
|
157
157
|
self.final_layer = _ClassificationHead(
|
|
158
|
-
emb_size=
|
|
158
|
+
emb_size=emb_size,
|
|
159
159
|
n_outputs=self.n_outputs,
|
|
160
160
|
activation=activation,
|
|
161
161
|
)
|
|
@@ -250,9 +250,7 @@ class _ClassificationHead(nn.Sequential):
|
|
|
250
250
|
(batch, n_outputs)
|
|
251
251
|
"""
|
|
252
252
|
|
|
253
|
-
def __init__(
|
|
254
|
-
self, emb_size: int, n_outputs: int, activation: type[nn.Module] = nn.ELU
|
|
255
|
-
):
|
|
253
|
+
def __init__(self, emb_size: int, n_outputs: int, activation: nn.Module = nn.ELU):
|
|
256
254
|
super().__init__()
|
|
257
255
|
self.activation_layer = activation()
|
|
258
256
|
self.classification_head = nn.Linear(emb_size, n_outputs)
|
|
@@ -347,7 +345,7 @@ class _BIOTEncoder(nn.Module):
|
|
|
347
345
|
The number of channels
|
|
348
346
|
emb_size: int
|
|
349
347
|
The size of the embedding layer
|
|
350
|
-
|
|
348
|
+
att_num_heads: int
|
|
351
349
|
The number of attention heads
|
|
352
350
|
n_layers: int
|
|
353
351
|
The number of transformer layers
|
|
@@ -360,7 +358,7 @@ class _BIOTEncoder(nn.Module):
|
|
|
360
358
|
def __init__(
|
|
361
359
|
self,
|
|
362
360
|
emb_size=256, # The size of the embedding layer
|
|
363
|
-
|
|
361
|
+
att_num_heads=8, # The number of attention heads
|
|
364
362
|
n_chans=16, # The number of channels
|
|
365
363
|
n_layers=4, # The number of transformer layers
|
|
366
364
|
n_fft=200, # Related with the frequency resolution
|
|
@@ -380,7 +378,7 @@ class _BIOTEncoder(nn.Module):
|
|
|
380
378
|
)
|
|
381
379
|
self.transformer = LinearAttentionTransformer(
|
|
382
380
|
dim=emb_size,
|
|
383
|
-
heads=
|
|
381
|
+
heads=att_num_heads,
|
|
384
382
|
depth=n_layers,
|
|
385
383
|
max_seq_len=max_seq_len,
|
|
386
384
|
attn_layer_dropout=attn_layer_dropout,
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/contrawr.py
RENAMED
|
@@ -58,7 +58,7 @@ class ContraWR(EEGModuleMixin, nn.Module):
|
|
|
58
58
|
emb_size: int = 256,
|
|
59
59
|
res_channels: list[int] = [32, 64, 128],
|
|
60
60
|
steps=20,
|
|
61
|
-
activation:
|
|
61
|
+
activation: nn.Module = nn.ELU,
|
|
62
62
|
drop_prob: float = 0.5,
|
|
63
63
|
stride_res: int = 2,
|
|
64
64
|
kernel_size_res: int = 3,
|
|
@@ -195,7 +195,7 @@ class _ResBlock(nn.Module):
|
|
|
195
195
|
kernel_size=3,
|
|
196
196
|
padding=1,
|
|
197
197
|
drop_prob=0.5,
|
|
198
|
-
activation:
|
|
198
|
+
activation: nn.Module = nn.ReLU,
|
|
199
199
|
):
|
|
200
200
|
super().__init__()
|
|
201
201
|
self.conv1 = nn.Conv2d(
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/ctnet.py
RENAMED
|
@@ -61,11 +61,11 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
61
61
|
----------
|
|
62
62
|
activation : nn.Module, default=nn.GELU
|
|
63
63
|
Activation function to use in the network.
|
|
64
|
-
|
|
64
|
+
heads : int, default=4
|
|
65
65
|
Number of attention heads in the Transformer encoder.
|
|
66
|
-
|
|
66
|
+
emb_size : int or None, default=None
|
|
67
67
|
Embedding size (dimensionality) for the Transformer encoder.
|
|
68
|
-
|
|
68
|
+
depth : int, default=6
|
|
69
69
|
Number of encoder layers in the Transformer.
|
|
70
70
|
n_filters_time : int, default=20
|
|
71
71
|
Number of temporal filters in the first convolutional layer.
|
|
@@ -77,11 +77,11 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
77
77
|
Pooling size for the first average pooling layer.
|
|
78
78
|
pool_size_2 : int, default=8
|
|
79
79
|
Pooling size for the second average pooling layer.
|
|
80
|
-
|
|
80
|
+
drop_prob_cnn : float, default=0.3
|
|
81
81
|
Dropout probability after convolutional layers.
|
|
82
|
-
|
|
82
|
+
drop_prob_posi : float, default=0.1
|
|
83
83
|
Dropout probability for the positional encoding in the Transformer.
|
|
84
|
-
|
|
84
|
+
drop_prob_final : float, default=0.5
|
|
85
85
|
Dropout probability before the final classification layer.
|
|
86
86
|
|
|
87
87
|
Notes
|
|
@@ -109,15 +109,15 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
109
109
|
n_times=None,
|
|
110
110
|
input_window_seconds=None,
|
|
111
111
|
# Model specific arguments
|
|
112
|
-
activation_patch:
|
|
113
|
-
activation_transformer:
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
112
|
+
activation_patch: nn.Module = nn.ELU,
|
|
113
|
+
activation_transformer: nn.Module = nn.GELU,
|
|
114
|
+
drop_prob_cnn: float = 0.3,
|
|
115
|
+
drop_prob_posi: float = 0.1,
|
|
116
|
+
drop_prob_final: float = 0.5,
|
|
117
117
|
# other parameters
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
118
|
+
heads: int = 4,
|
|
119
|
+
emb_size: Optional[int] = 40,
|
|
120
|
+
depth: int = 6,
|
|
121
121
|
n_filters_time: Optional[int] = None,
|
|
122
122
|
kernel_size: int = 64,
|
|
123
123
|
depth_multiplier: Optional[int] = 2,
|
|
@@ -136,14 +136,14 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
136
136
|
|
|
137
137
|
self.activation_patch = activation_patch
|
|
138
138
|
self.activation_transformer = activation_transformer
|
|
139
|
-
self.
|
|
139
|
+
self.drop_prob_cnn = drop_prob_cnn
|
|
140
140
|
self.pool_size_1 = pool_size_1
|
|
141
141
|
self.pool_size_2 = pool_size_2
|
|
142
142
|
self.kernel_size = kernel_size
|
|
143
|
-
self.
|
|
144
|
-
self.
|
|
145
|
-
self.
|
|
146
|
-
self.
|
|
143
|
+
self.drop_prob_posi = drop_prob_posi
|
|
144
|
+
self.drop_prob_final = drop_prob_final
|
|
145
|
+
self.heads = heads
|
|
146
|
+
self.depth = depth
|
|
147
147
|
# n_times - pool_size_1 / p
|
|
148
148
|
self.sequence_length = math.floor(
|
|
149
149
|
(
|
|
@@ -154,8 +154,8 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
154
154
|
+ 1
|
|
155
155
|
)
|
|
156
156
|
|
|
157
|
-
self.depth_multiplier, self.n_filters_time, self.
|
|
158
|
-
depth_multiplier, n_filters_time,
|
|
157
|
+
self.depth_multiplier, self.n_filters_time, self.emb_size = self._resolve_dims(
|
|
158
|
+
depth_multiplier, n_filters_time, emb_size
|
|
159
159
|
)
|
|
160
160
|
|
|
161
161
|
# Layers
|
|
@@ -168,32 +168,32 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
168
168
|
depth_multiplier=self.depth_multiplier,
|
|
169
169
|
pool_size_1=self.pool_size_1,
|
|
170
170
|
pool_size_2=self.pool_size_2,
|
|
171
|
-
drop_prob=self.
|
|
171
|
+
drop_prob=self.drop_prob_cnn,
|
|
172
172
|
n_chans=self.n_chans,
|
|
173
173
|
activation=self.activation_patch,
|
|
174
174
|
)
|
|
175
175
|
|
|
176
176
|
self.position = _PositionalEncoding(
|
|
177
|
-
emb_size=self.
|
|
178
|
-
drop_prob=self.
|
|
177
|
+
emb_size=self.emb_size,
|
|
178
|
+
drop_prob=self.drop_prob_posi,
|
|
179
179
|
n_times=self.n_times,
|
|
180
180
|
pool_size=self.pool_size_1,
|
|
181
181
|
)
|
|
182
182
|
|
|
183
183
|
self.trans = _TransformerEncoder(
|
|
184
|
-
self.
|
|
185
|
-
self.
|
|
186
|
-
self.
|
|
184
|
+
self.heads,
|
|
185
|
+
self.depth,
|
|
186
|
+
self.emb_size,
|
|
187
187
|
activation=self.activation_transformer,
|
|
188
188
|
)
|
|
189
189
|
|
|
190
190
|
self.flatten_drop_layer = nn.Sequential(
|
|
191
191
|
nn.Flatten(),
|
|
192
|
-
nn.Dropout(p=self.
|
|
192
|
+
nn.Dropout(p=self.drop_prob_final),
|
|
193
193
|
)
|
|
194
194
|
|
|
195
195
|
self.final_layer = nn.Linear(
|
|
196
|
-
in_features=int(self.
|
|
196
|
+
in_features=int(self.emb_size * self.sequence_length),
|
|
197
197
|
out_features=self.n_outputs,
|
|
198
198
|
)
|
|
199
199
|
|
|
@@ -213,7 +213,7 @@ class CTNet(EEGModuleMixin, nn.Module):
|
|
|
213
213
|
"""
|
|
214
214
|
x = self.ensuredim(x)
|
|
215
215
|
cnn = self.cnn(x)
|
|
216
|
-
cnn = cnn * math.sqrt(self.
|
|
216
|
+
cnn = cnn * math.sqrt(self.emb_size)
|
|
217
217
|
cnn = self.position(cnn)
|
|
218
218
|
trans = self.trans(cnn)
|
|
219
219
|
features = cnn + trans
|
|
@@ -312,7 +312,7 @@ class _PatchEmbeddingEEGNet(nn.Module):
|
|
|
312
312
|
pool_size_2: int = 8,
|
|
313
313
|
drop_prob: float = 0.3,
|
|
314
314
|
n_chans: int = 22,
|
|
315
|
-
activation:
|
|
315
|
+
activation: nn.Module = nn.ELU,
|
|
316
316
|
):
|
|
317
317
|
super().__init__()
|
|
318
318
|
n_filters_out = depth_multiplier * n_filters_time
|
|
@@ -416,7 +416,7 @@ class _TransformerEncoderBlock(nn.Module):
|
|
|
416
416
|
drop_prob: float = 0.5,
|
|
417
417
|
forward_expansion: int = 4,
|
|
418
418
|
forward_drop_p: float = 0.5,
|
|
419
|
-
activation:
|
|
419
|
+
activation: nn.Module = nn.GELU,
|
|
420
420
|
):
|
|
421
421
|
super().__init__()
|
|
422
422
|
self.attention = _ResidualAdd(
|
|
@@ -466,7 +466,7 @@ class _TransformerEncoder(nn.Module):
|
|
|
466
466
|
nheads: int,
|
|
467
467
|
depth: int,
|
|
468
468
|
dim_feedforward: int,
|
|
469
|
-
activation:
|
|
469
|
+
activation: nn.Module = nn.GELU,
|
|
470
470
|
):
|
|
471
471
|
super().__init__()
|
|
472
472
|
self.layers = nn.Sequential(
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/deep4.py
RENAMED
|
@@ -109,12 +109,12 @@ class Deep4Net(EEGModuleMixin, nn.Sequential):
|
|
|
109
109
|
filter_length_3=10,
|
|
110
110
|
n_filters_4=200,
|
|
111
111
|
filter_length_4=10,
|
|
112
|
-
activation_first_conv_nonlin:
|
|
112
|
+
activation_first_conv_nonlin: nn.Module = nn.ELU,
|
|
113
113
|
first_pool_mode="max",
|
|
114
|
-
first_pool_nonlin:
|
|
115
|
-
activation_later_conv_nonlin:
|
|
114
|
+
first_pool_nonlin: nn.Module = nn.Identity,
|
|
115
|
+
activation_later_conv_nonlin: nn.Module = nn.ELU,
|
|
116
116
|
later_pool_mode="max",
|
|
117
|
-
later_pool_nonlin:
|
|
117
|
+
later_pool_nonlin: nn.Module = nn.Identity,
|
|
118
118
|
drop_prob=0.5,
|
|
119
119
|
split_first_layer=True,
|
|
120
120
|
batch_norm=True,
|
{braindecode-1.3.0.dev175435903 → braindecode-1.3.0.dev177628147}/braindecode/models/deepsleepnet.py
RENAMED
|
@@ -172,8 +172,8 @@ class DeepSleepNet(EEGModuleMixin, nn.Module):
|
|
|
172
172
|
n_times=None,
|
|
173
173
|
input_window_seconds=None,
|
|
174
174
|
sfreq=None,
|
|
175
|
-
activation_large:
|
|
176
|
-
activation_small:
|
|
175
|
+
activation_large: nn.Module = nn.ELU,
|
|
176
|
+
activation_small: nn.Module = nn.ReLU,
|
|
177
177
|
drop_prob: float = 0.5,
|
|
178
178
|
):
|
|
179
179
|
super().__init__(
|
|
@@ -252,7 +252,7 @@ class _SmallCNN(nn.Module):
|
|
|
252
252
|
The dropout rate for regularization. Values should be between 0 and 1.
|
|
253
253
|
"""
|
|
254
254
|
|
|
255
|
-
def __init__(self, activation:
|
|
255
|
+
def __init__(self, activation: nn.Module = nn.ReLU, drop_prob: float = 0.5):
|
|
256
256
|
super().__init__()
|
|
257
257
|
self.conv1 = nn.Sequential(
|
|
258
258
|
nn.Conv2d(
|
|
@@ -328,7 +328,7 @@ class _LargeCNN(nn.Module):
|
|
|
328
328
|
|
|
329
329
|
"""
|
|
330
330
|
|
|
331
|
-
def __init__(self, activation:
|
|
331
|
+
def __init__(self, activation: nn.Module = nn.ELU, drop_prob: float = 0.5):
|
|
332
332
|
super().__init__()
|
|
333
333
|
|
|
334
334
|
self.conv1 = nn.Sequential(
|