braindecode 1.5.0.dev985__tar.gz → 1.5.0.dev172791986__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {braindecode-1.5.0.dev985/braindecode.egg-info → braindecode-1.5.0.dev172791986}/PKG-INFO +1 -1
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/__init__.py +10 -3
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/bendr.py +92 -28
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/biot.py +61 -1
- braindecode-1.5.0.dev172791986/braindecode/models/interpolated.py +182 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/labram.py +208 -266
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/reve.py +1 -1
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/signal_jepa.py +17 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/summary.csv +4 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/util.py +121 -20
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/__init__.py +2 -0
- braindecode-1.5.0.dev172791986/braindecode/modules/interpolation.py +201 -0
- braindecode-1.5.0.dev172791986/braindecode/version.py +1 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986/braindecode.egg-info}/PKG-INFO +1 -1
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/SOURCES.txt +2 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/api.rst +6 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/whats_new.rst +40 -0
- braindecode-1.5.0.dev985/braindecode/version.py +0 -1
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/LICENSE.txt +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/MANIFEST.in +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/NOTICE.txt +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/README.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/base.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/functional.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/augmentation/transforms.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/classifier.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/base.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bbci.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bcicomp.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/datasets.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/format.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub_format.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub_io.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/hub_validation.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/bids/iterable.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/chb_mit.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/mne.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/moabb.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/nmt.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/registry.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/siena.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/sleep_physionet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/tuh.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/utils.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datasets/xy.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/channel_utils.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/hub_formats.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/serialization.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/datautil/util.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/eegneuralnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/functional/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/functional/functions.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/functional/initialization.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/atcnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/attentionbasenet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/attn_sleep.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/base.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/brainmodule.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/cbramod.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/codebrain.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/config.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/contrawr.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/ctnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/deep4.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/deepsleepnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/dgcnn.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegconformer.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eeginception_erp.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eeginception_mi.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegitnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegminer.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegnex.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegpt.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegsimpleconv.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegsym.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/eegtcnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/fbcnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/fblightconvnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/fbmsnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/hybrid.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/ifnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/luna.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/medformer.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/msvtnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/patchedtransformer.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sccnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/shallow_fbcsp.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sinc_shallow.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sparcnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/sstdpn.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/syncnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/tcn.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/tidnet.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/tsinception.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/models/usleep.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/activation.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/attention.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/blocks.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/convolution.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/filter.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/layers.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/linear.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/parametrization.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/stats.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/util.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/modules/wrapper.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/mne_preprocess.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/preprocess.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/util.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/preprocessing/windowers.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/regressor.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/samplers/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/samplers/base.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/samplers/ssl.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/callbacks.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/losses.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/training/scoring.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/util.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/visualization/__init__.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/visualization/confusion_matrices.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode/visualization/gradients.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/dependency_links.txt +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/requires.txt +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/braindecode.egg-info/top_level.txt +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/Makefile +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/class.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/class_in_subdir.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/function.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/_templates/autosummary/function_in_subdir.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/cite.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/conf.py +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/help.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/index.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/install/install.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/install/install_pip.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/install/install_source.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/attention.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/channel.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/convolution.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/filterbank.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/gnn.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/interpretable.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/lbm.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/recurrent.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/categorization/spd.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models_categorization.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models_table.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/models/models_visualization.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/docs/sg_execution_times.rst +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/pyproject.toml +0 -0
- {braindecode-1.5.0.dev985 → braindecode-1.5.0.dev172791986}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.5.0.
|
|
3
|
+
Version: 1.5.0.dev172791986
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
@@ -4,8 +4,8 @@ from .atcnet import ATCNet
|
|
|
4
4
|
from .attentionbasenet import AttentionBaseNet
|
|
5
5
|
from .attn_sleep import AttnSleep
|
|
6
6
|
from .base import EEGModuleMixin
|
|
7
|
-
from .bendr import BENDR
|
|
8
|
-
from .biot import BIOT
|
|
7
|
+
from .bendr import BENDR, InterpolatedBENDR
|
|
8
|
+
from .biot import BIOT, InterpolatedBIOT
|
|
9
9
|
from .brainmodule import BrainModule
|
|
10
10
|
from .cbramod import CBraMod
|
|
11
11
|
from .codebrain import CodeBrain
|
|
@@ -30,7 +30,8 @@ from .fblightconvnet import FBLightConvNet
|
|
|
30
30
|
from .fbmsnet import FBMSNet
|
|
31
31
|
from .hybrid import HybridNet
|
|
32
32
|
from .ifnet import IFNet
|
|
33
|
-
from .
|
|
33
|
+
from .interpolated import InterpolatedModel
|
|
34
|
+
from .labram import InterpolatedLaBraM, Labram
|
|
34
35
|
from .luna import LUNA
|
|
35
36
|
from .medformer import MEDFormer
|
|
36
37
|
from .msvtnet import MSVTNet
|
|
@@ -39,6 +40,7 @@ from .reve import REVE
|
|
|
39
40
|
from .sccnet import SCCNet
|
|
40
41
|
from .shallow_fbcsp import ShallowFBCSPNet
|
|
41
42
|
from .signal_jepa import (
|
|
43
|
+
InterpolatedSignalJEPA,
|
|
42
44
|
SignalJEPA,
|
|
43
45
|
SignalJEPA_Contextual,
|
|
44
46
|
SignalJEPA_PostLocal,
|
|
@@ -97,6 +99,11 @@ __all__ = [
|
|
|
97
99
|
"FBMSNet",
|
|
98
100
|
"HybridNet",
|
|
99
101
|
"IFNet",
|
|
102
|
+
"InterpolatedBENDR",
|
|
103
|
+
"InterpolatedBIOT",
|
|
104
|
+
"InterpolatedLaBraM",
|
|
105
|
+
"InterpolatedModel",
|
|
106
|
+
"InterpolatedSignalJEPA",
|
|
100
107
|
"Labram",
|
|
101
108
|
"LUNA",
|
|
102
109
|
"extract_channel_locations_from_chs_info",
|
|
@@ -1,12 +1,60 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
|
|
3
|
+
import numpy as np
|
|
3
4
|
import torch
|
|
4
5
|
from einops.layers.torch import Rearrange
|
|
5
6
|
from torch import nn
|
|
6
|
-
from torch.nn.utils.parametrize import register_parametrization
|
|
7
7
|
|
|
8
8
|
from braindecode.models.base import EEGModuleMixin
|
|
9
|
-
|
|
9
|
+
|
|
10
|
+
# The 20 channels used to pre-train BENDR, in the order expected by the
|
|
11
|
+
# `braindecode/braindecode-bendr` checkpoint. The first 19 entries are the
|
|
12
|
+
# EEG channels taken verbatim from `dn3.transforms.instance.To1020.EEG_20_div`
|
|
13
|
+
# (https://github.com/SPOClab-ca/dn3/blob/master/dn3/transforms/instance.py).
|
|
14
|
+
# Their positions come from MNE's ``standard_1005`` montage (T5/T6 are
|
|
15
|
+
# legacy names that share positions with P7/P8 there).
|
|
16
|
+
#
|
|
17
|
+
# The 20th entry is ``SCALE``, a relative-amplitude statistic (not an
|
|
18
|
+
# electrode) appended by ``To1020(include_scale_ch=True)`` during
|
|
19
|
+
# pre-training. Since it has no physical position, the ``loc`` below is
|
|
20
|
+
# the centroid of the 19 EEG positions — purely a placeholder so that
|
|
21
|
+
# :class:`~braindecode.modules.ChannelInterpolationLayer` (used by
|
|
22
|
+
# :class:`InterpolatedBENDR`) can build a valid spline interpolation
|
|
23
|
+
# matrix. It is NOT the SCALE the pre-training pipeline computes
|
|
24
|
+
# (which is an RMS-like amplitude via ``dn3.MappingDeep1010``); users
|
|
25
|
+
# who need a faithful SCALE must compute it themselves and feed 20
|
|
26
|
+
# channels to :class:`BENDR` directly.
|
|
27
|
+
_BENDR_TARGET_CHS_TUPLES: list[tuple[str, tuple[float, float, float]]] = [
|
|
28
|
+
("FP1", (-0.0294367, +0.0839171, -0.0069900)), # standard_1005
|
|
29
|
+
("FP2", (+0.0298723, +0.0848959, -0.0070800)), # standard_1005
|
|
30
|
+
("F7", (-0.0702629, +0.0424743, -0.0114200)), # standard_1005
|
|
31
|
+
("F3", (-0.0502438, +0.0531112, +0.0421920)), # standard_1005
|
|
32
|
+
("FZ", (+0.0003122, +0.0585120, +0.0664620)), # standard_1005
|
|
33
|
+
("F4", (+0.0518362, +0.0543048, +0.0408140)), # standard_1005
|
|
34
|
+
("F8", (+0.0730431, +0.0444217, -0.0120000)), # standard_1005
|
|
35
|
+
("T7", (-0.0841611, -0.0160187, -0.0093460)), # standard_1005
|
|
36
|
+
("C3", (-0.0653581, -0.0116317, +0.0643580)), # standard_1005
|
|
37
|
+
("CZ", (+0.0004009, -0.0091670, +0.1002440)), # standard_1005
|
|
38
|
+
("C4", (+0.0671179, -0.0109003, +0.0635800)), # standard_1005
|
|
39
|
+
("T8", (+0.0850799, -0.0150203, -0.0094900)), # standard_1005
|
|
40
|
+
("T5", (-0.0724343, -0.0734527, -0.0024870)), # standard_1005 (= P7)
|
|
41
|
+
("P3", (-0.0530073, -0.0787878, +0.0559400)), # standard_1005
|
|
42
|
+
("PZ", (+0.0003247, -0.0811150, +0.0826150)), # standard_1005
|
|
43
|
+
("P4", (+0.0556667, -0.0785602, +0.0565610)), # standard_1005
|
|
44
|
+
("T6", (+0.0730557, -0.0730683, -0.0025400)), # standard_1005 (= P8)
|
|
45
|
+
("O1", (-0.0294134, -0.1124490, +0.0088390)), # standard_1005
|
|
46
|
+
("O2", (+0.0298426, -0.1121560, +0.0088000)), # standard_1005
|
|
47
|
+
(
|
|
48
|
+
"SCALE",
|
|
49
|
+
(+0.0006439, -0.0131942, +0.0278448),
|
|
50
|
+
), # centroid of the 19 EEG positions (placeholder; see comment above)
|
|
51
|
+
]
|
|
52
|
+
|
|
53
|
+
_BENDR_TARGET_CHS_INFO: list[dict] = [
|
|
54
|
+
{"ch_name": ch, "kind": "eeg", "loc": np.asarray(loc, dtype=float)}
|
|
55
|
+
for ch, loc in _BENDR_TARGET_CHS_TUPLES
|
|
56
|
+
]
|
|
57
|
+
BENDR_CHANNEL_ORDER: list[str] = [ch for ch, _ in _BENDR_TARGET_CHS_TUPLES]
|
|
10
58
|
|
|
11
59
|
|
|
12
60
|
class BENDR(EEGModuleMixin, nn.Module):
|
|
@@ -195,15 +243,6 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
195
243
|
The contextualizer is still created (to allow loading pretrained weights) but is not
|
|
196
244
|
used in the forward pass. Requires input length of at least
|
|
197
245
|
``4 * product(enc_downsample)`` samples (384 with default downsampling of 96x).
|
|
198
|
-
n_chans_pretrained : int or None, default=None
|
|
199
|
-
Number of input channels the pretrained weights expect (20 for the official BENDR
|
|
200
|
-
checkpoint). When set and ``n_chans != n_chans_pretrained``, a 1x1 Conv1d with
|
|
201
|
-
max-norm constraint projects from ``n_chans`` to ``n_chans_pretrained`` before the
|
|
202
|
-
encoder. This allows fine-tuning pretrained BENDR on datasets with arbitrary channel
|
|
203
|
-
counts. When using ``from_pretrained``, pass ``strict=False`` since the checkpoint
|
|
204
|
-
will not contain ``channel_projection`` weights.
|
|
205
|
-
chan_proj_max_norm : float, default=1.0
|
|
206
|
-
Max-norm constraint value for the channel projection weights.
|
|
207
246
|
"""
|
|
208
247
|
|
|
209
248
|
def __init__(
|
|
@@ -231,8 +270,6 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
231
270
|
start_token=-5, # Value for start token embedding
|
|
232
271
|
final_layer=True, # Whether to include the final linear layer
|
|
233
272
|
encoder_only=False, # If True, bypass contextualizer and use 4-chunk pooling
|
|
234
|
-
n_chans_pretrained=None, # Expected input channels of pretrained weights
|
|
235
|
-
chan_proj_max_norm=1.0, # Max-norm for channel projection weights
|
|
236
273
|
):
|
|
237
274
|
super().__init__(
|
|
238
275
|
n_outputs=n_outputs,
|
|
@@ -246,25 +283,34 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
246
283
|
# Keep these parameters if needed later, otherwise they are captured by the mixin
|
|
247
284
|
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
|
|
248
285
|
|
|
286
|
+
# If the user supplies chs_info, require it to match BENDR_CHANNEL_ORDER
|
|
287
|
+
# exactly (case-insensitive). Arbitrary channel sets should go through
|
|
288
|
+
# InterpolatedBENDR — same pattern as Labram / InterpolatedLaBraM.
|
|
289
|
+
# When chs_info is absent (the usual n_chans=20 path, incl.
|
|
290
|
+
# from_pretrained), no check is performed.
|
|
291
|
+
try:
|
|
292
|
+
_chs_info = self.chs_info
|
|
293
|
+
except ValueError:
|
|
294
|
+
_chs_info = None
|
|
295
|
+
if _chs_info is not None:
|
|
296
|
+
user_names = [ch["ch_name"] for ch in _chs_info] # type: ignore[index]
|
|
297
|
+
canonical = BENDR_CHANNEL_ORDER
|
|
298
|
+
if [n.lower() for n in user_names] != [n.lower() for n in canonical]:
|
|
299
|
+
raise ValueError(
|
|
300
|
+
f"BENDR requires chs_info to match BENDR_CHANNEL_ORDER exactly "
|
|
301
|
+
f"({len(canonical)} channels, specific order; last is 'SCALE'). "
|
|
302
|
+
f"Got {len(user_names)} channel(s). For arbitrary channel sets, "
|
|
303
|
+
f"use InterpolatedBENDR "
|
|
304
|
+
f"(from braindecode.models import InterpolatedBENDR)."
|
|
305
|
+
)
|
|
306
|
+
|
|
249
307
|
self.encoder_h = encoder_h
|
|
250
308
|
self.contextualizer_hidden = contextualizer_hidden
|
|
251
309
|
self.include_final_layer = final_layer
|
|
252
310
|
self.encoder_only = encoder_only
|
|
253
311
|
|
|
254
|
-
# Channel projection for pretrained weight compatibility
|
|
255
|
-
encoder_n_chans = self.n_chans
|
|
256
|
-
if n_chans_pretrained is not None and self.n_chans != n_chans_pretrained:
|
|
257
|
-
conv = nn.Conv1d(self.n_chans, n_chans_pretrained, 1, bias=False)
|
|
258
|
-
register_parametrization(
|
|
259
|
-
conv, "weight", MaxNormParametrize(chan_proj_max_norm)
|
|
260
|
-
)
|
|
261
|
-
self.channel_projection = conv
|
|
262
|
-
encoder_n_chans = n_chans_pretrained
|
|
263
|
-
else:
|
|
264
|
-
self.channel_projection = None
|
|
265
|
-
|
|
266
312
|
self.encoder = _ConvEncoderBENDR(
|
|
267
|
-
in_features=
|
|
313
|
+
in_features=self.n_chans,
|
|
268
314
|
encoder_h=encoder_h,
|
|
269
315
|
dropout=drop_prob,
|
|
270
316
|
projection_head=projection_head,
|
|
@@ -308,8 +354,6 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
308
354
|
self._build_head(n_outputs)
|
|
309
355
|
|
|
310
356
|
def forward(self, x, return_features=False):
|
|
311
|
-
if self.channel_projection is not None:
|
|
312
|
-
x = self.channel_projection(x)
|
|
313
357
|
encoded = self.encoder(x)
|
|
314
358
|
# encoded: [batch_size, encoder_h, n_encoded_times]
|
|
315
359
|
|
|
@@ -552,3 +596,23 @@ class _BENDRContextualizer(nn.Module):
|
|
|
552
596
|
# x: [batch_size, in_features, seq_len + 1]
|
|
553
597
|
|
|
554
598
|
return x
|
|
599
|
+
|
|
600
|
+
|
|
601
|
+
# -----------------------------------------------------------------------------
|
|
602
|
+
# InterpolatedBENDR — experimental channel-interpolation variant of BENDR
|
|
603
|
+
# -----------------------------------------------------------------------------
|
|
604
|
+
# Wraps :class:`BENDR` with an MNE-backed channel-interpolation layer that
|
|
605
|
+
# projects arbitrary user ``chs_info`` to the canonical 20-channel BENDR
|
|
606
|
+
# input (:data:`_BENDR_TARGET_CHS_INFO` — the 19 pre-training EEG channels
|
|
607
|
+
# plus a ``SCALE`` placeholder at the centroid of those 19 positions).
|
|
608
|
+
# Frozen by default; set ``trainable=True`` to fine-tune the projection.
|
|
609
|
+
#
|
|
610
|
+
# NOTE: the ``SCALE`` target has no physical position, so the row of the
|
|
611
|
+
# interpolation matrix that produces it is a spatial spline of the user's
|
|
612
|
+
# EEG channels — *not* the dn3 ``MappingDeep1010`` RMS statistic the
|
|
613
|
+
# checkpoint saw during pre-training. Expect degraded zero-shot transfer
|
|
614
|
+
# from the SCALE channel; downstream fine-tuning should still work.
|
|
615
|
+
|
|
616
|
+
from braindecode.models.interpolated import InterpolatedModel # noqa: E402
|
|
617
|
+
|
|
618
|
+
InterpolatedBENDR = InterpolatedModel(BENDR, _BENDR_TARGET_CHS_INFO)
|
|
@@ -1,12 +1,57 @@
|
|
|
1
1
|
import math
|
|
2
2
|
from warnings import warn
|
|
3
3
|
|
|
4
|
+
import numpy as np
|
|
4
5
|
import torch
|
|
5
6
|
import torch.nn as nn
|
|
6
7
|
from linear_attention_transformer import LinearAttentionTransformer
|
|
7
8
|
|
|
8
9
|
from braindecode.models.base import EEGModuleMixin
|
|
9
10
|
|
|
11
|
+
# -----------------------------------------------------------------------------
|
|
12
|
+
# Canonical channel order for InterpolatedBIOT — the 18-channel TCP bipolar
|
|
13
|
+
# montage used by BIOT's shhs-prest and six-datasets pretrained checkpoints.
|
|
14
|
+
# Source: https://github.com/ycq091044/BIOT (README + datasets/TUAB/process.py
|
|
15
|
+
# + datasets/SHHS/process.py). Indices 0-15 are the TCP 16-channel bipolar
|
|
16
|
+
# derivations; indices 16-17 are SHHS differential channels.
|
|
17
|
+
#
|
|
18
|
+
# The `loc` values are only used to build an MNE interpolation matrix for
|
|
19
|
+
# InterpolatedBIOT. All entries are bipolar / differential derivations.
|
|
20
|
+
# TODO: positions are stored as the midpoint of the two constituent
|
|
21
|
+
# electrodes. This is a simplification — a bipolar signal V(A)-V(B) cannot
|
|
22
|
+
# be faithfully recovered by spatial interpolation at the midpoint. Revisit
|
|
23
|
+
# in a follow-up PR (e.g. a dedicated BipolarDerivationLayer).
|
|
24
|
+
# -----------------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
# fmt: off
|
|
27
|
+
_BIOT_TARGET_CHS_TUPLES: list[tuple[str, tuple[float, float, float]]] = [
|
|
28
|
+
("FP1-F7", (-0.04984980, 0.06319570, -0.00920500)),
|
|
29
|
+
("F7-T7", (-0.07721200, 0.01322780, -0.01038300)),
|
|
30
|
+
("T7-P7", (-0.07829770, -0.04473570, -0.00591650)),
|
|
31
|
+
("P7-O1", (-0.05092385, -0.09295085, 0.00317600)),
|
|
32
|
+
("FP2-F8", (0.05145770, 0.06465880, -0.00954000)),
|
|
33
|
+
("F8-T8", (0.07906150, 0.01470070, -0.01074500)),
|
|
34
|
+
("T8-P8", (0.07906780, -0.04404430, -0.00601500)),
|
|
35
|
+
("P8-O2", (0.05144915, -0.09261215, 0.00313000)),
|
|
36
|
+
("FP1-F3", (-0.03984025, 0.06851415, 0.01760100)),
|
|
37
|
+
("F3-C3", (-0.05780095, 0.02073975, 0.05327500)),
|
|
38
|
+
("C3-P3", (-0.05918270, -0.04520975, 0.06014900)),
|
|
39
|
+
("P3-O1", (-0.04121035, -0.09561840, 0.03238950)),
|
|
40
|
+
("FP2-F4", (0.04085425, 0.06960035, 0.01686700)),
|
|
41
|
+
("F4-C4", (0.05947705, 0.02170225, 0.05219700)),
|
|
42
|
+
("C4-P4", (0.06139230, -0.04473025, 0.06007050)),
|
|
43
|
+
("P4-O2", (0.04275465, -0.09535810, 0.03268050)),
|
|
44
|
+
("C3-A2", (0.01021790, -0.01832050, -0.00183650)),
|
|
45
|
+
("C4-A1", (-0.00947910, -0.01794500, -0.00220300)),
|
|
46
|
+
]
|
|
47
|
+
# fmt: on
|
|
48
|
+
|
|
49
|
+
_BIOT_TARGET_CHS_INFO = [
|
|
50
|
+
{"ch_name": ch, "kind": "eeg", "loc": np.asarray(loc, dtype=float)}
|
|
51
|
+
for ch, loc in _BIOT_TARGET_CHS_TUPLES
|
|
52
|
+
]
|
|
53
|
+
BIOT_CHANNEL_ORDER = [ch for ch, _ in _BIOT_TARGET_CHS_TUPLES]
|
|
54
|
+
|
|
10
55
|
|
|
11
56
|
class BIOT(EEGModuleMixin, nn.Module):
|
|
12
57
|
r"""BIOT from Yang et al (2023) [Yang2023]_
|
|
@@ -439,7 +484,9 @@ class _BIOTEncoder(nn.Module):
|
|
|
439
484
|
self.channel_tokens = nn.Embedding(
|
|
440
485
|
num_embeddings=n_chans, embedding_dim=emb_size
|
|
441
486
|
)
|
|
442
|
-
self.register_buffer(
|
|
487
|
+
self.register_buffer(
|
|
488
|
+
"index", torch.arange(n_chans, dtype=torch.long), persistent=False
|
|
489
|
+
)
|
|
443
490
|
|
|
444
491
|
def stft(self, sample):
|
|
445
492
|
"""
|
|
@@ -553,3 +600,16 @@ class _BIOTEncoder(nn.Module):
|
|
|
553
600
|
# (batch_size, emb)
|
|
554
601
|
emb = self.transformer(emb).mean(dim=1)
|
|
555
602
|
return emb
|
|
603
|
+
|
|
604
|
+
|
|
605
|
+
# -----------------------------------------------------------------------------
# InterpolatedBIOT — experimental channel-interpolation variant of BIOT
# -----------------------------------------------------------------------------
# Wraps :class:`BIOT` with an MNE-backed channel-interpolation layer that
# projects arbitrary user ``chs_info`` to the canonical 18-channel BIOT
# montage (:data:`_BIOT_TARGET_CHS_INFO`). Frozen by default; set
# ``trainable=True`` to fine-tune the projection matrix.

# NOTE(review): import is deliberately mid-file (hence ``noqa: E402``) —
# presumably to sidestep an import cycle with the models package; confirm
# before moving it to the top of the file.
from braindecode.models.interpolated import InterpolatedModel  # noqa: E402

# ``InterpolatedModel`` is a class factory, so ``InterpolatedBIOT`` is a
# genuine subclass of :class:`BIOT` (a type, not an instance).
InterpolatedBIOT = InterpolatedModel(BIOT, _BIOT_TARGET_CHS_INFO)
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
# Authors: Pierre Guetschel
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
from typing import Literal, Optional, Type
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch.nn as nn
|
|
10
|
+
|
|
11
|
+
from braindecode.modules.interpolation import ChannelInterpolationLayer
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def InterpolatedModel(
    model_cls: Type,
    target_chs_info: list[dict],
    name: Optional[str] = None,
) -> Type:
    """Return a subclass of ``model_cls`` that interpolates channels to ``target_chs_info``.

    .. warning:: Experimental. Public API may change without a deprecation cycle.

    Parameters
    ----------
    model_cls : type
        A braindecode model class (subclass of
        :class:`~braindecode.models.base.EEGModuleMixin`).
    target_chs_info : list of dict
        The canonical channel set the backbone expects internally. Every
        instance of the returned class projects its input channels to
        this set via :class:`~braindecode.modules.ChannelInterpolationLayer`.
    name : str, optional
        ``__name__`` to assign to the returned class. Defaults to
        ``f"Interpolated{model_cls.__name__}"``.

    Returns
    -------
    type
        A new subclass of ``model_cls`` whose ``__init__`` accepts
        arbitrary user ``chs_info`` and automatically inserts a frozen
        (by default) channel-interpolation layer before the backbone.
        The subclass additionally accepts ``interpolation_method``,
        ``interpolation_mode`` and ``trainable`` keyword arguments,
        which are forwarded to the interpolation layer.
    """
    # nn.Sequential subclasses are handled differently below: their
    # forward() is driven by _modules insertion order, so the layer must
    # be registered first instead of applied in an overridden forward().
    _is_sequential = issubclass(model_cls, nn.Sequential)

    class _Interpolated(model_cls):
        # Canonical channel set shared by every instance of this subclass.
        _TARGET_CHS_INFO = target_chs_info

        def __init__(
            self,
            chs_info,
            n_outputs=None,
            n_times=None,
            input_window_seconds=None,
            sfreq=None,
            n_chans=None,
            interpolation_method: str = "spline",
            interpolation_mode: Literal["always", "name_match"] = "name_match",
            trainable: bool = False,
            **kwargs,
        ):
            # Signal-related params are declared EXPLICITLY here so that
            # skorch's ``EEGClassifier._set_signal_args`` (which inspects
            # the ``__init__`` signature via ``inspect.signature``) can
            # auto-forward them from the training dataset. They would not
            # be discoverable if they were collapsed into ``**kwargs``.
            # ``n_chans`` is declared for the same discoverability reason
            # but intentionally ignored: the backbone must see
            # ``len(target_chs_info)`` derived from ``chs_info``.
            del n_chans
            # Backbone init uses the target channels. During this call,
            # some backbones run a dummy forward (e.g. to size the head);
            # ``self.interpolation_layer`` does not exist yet — the
            # ``forward`` override below falls back to pass-through when
            # the attribute is absent. Assigning an ``nn.Identity()``
            # before ``super()`` is impossible: ``nn.Module.__setattr__``
            # requires ``nn.Module.__init__`` to have run, and the chain
            # would wipe ``self._modules`` when it reaches it again.
            super().__init__(
                chs_info=target_chs_info,
                n_outputs=n_outputs,
                n_times=n_times,
                input_window_seconds=input_window_seconds,
                sfreq=sfreq,
                **kwargs,
            )

            # Maps the user's channels onto the backbone's canonical set;
            # frozen unless ``trainable=True``.
            layer = ChannelInterpolationLayer(
                src_chs_info=chs_info,
                tgt_chs_info=target_chs_info,
                mode=interpolation_mode,
                method=interpolation_method,
                trainable=trainable,
            )
            if _is_sequential:
                # For nn.Sequential subclasses, prepend the interpolation
                # layer so that nn.Sequential.forward runs it first.
                # Registering via attribute assignment appends to _modules;
                # instead we rebuild _modules with the layer first.
                old_modules = list(self._modules.items())  # type: ignore[has-type]
                self._modules.clear()  # type: ignore[has-type]
                self._modules["interpolation_layer"] = layer  # type: ignore[index]
                for k, v in old_modules:
                    self._modules[k] = v  # type: ignore[index]
            else:
                self.interpolation_layer = layer

            # Rebind private attrs so the user-facing view (.chs_info,
            # .n_chans, .input_shape, build_model_config) reflects the
            # user's channels. Properties are NOT overridden — we mutate
            # the private attrs the base-class properties read from.
            self._chs_info = chs_info
            self._n_chans = len(chs_info)

        # Conditionally defined in the class body: Sequential backbones
        # get the layer prepended to _modules instead, and zero-argument
        # ``super()`` requires this to be lexically inside the class.
        if not _is_sequential:

            def forward(self, x, *args, **kwargs):
                # During super().__init__() the interpolation_layer attr
                # does not exist yet; any dummy forward call (e.g. from
                # get_output_shape) must pass through unchanged so the
                # backbone sees its expected target-shape input.
                # Forward *args / **kwargs so backbone-specific flags
                # like ``return_features`` keep working through the
                # wrapper.
                interp = getattr(self, "interpolation_layer", None)
                if interp is not None:
                    x = interp(x)
                return super().forward(x, *args, **kwargs)

    _Interpolated.__name__ = name or f"Interpolated{model_cls.__name__}"
    _Interpolated.__qualname__ = _Interpolated.__name__
    # Propagate the backbone docstring so Sphinx and the categorization tests
    # can read the class badges. Prepend a short header so the class shows
    # up clearly in documentation as distinct from the backbone.
    backbone_doc = model_cls.__doc__ or ""
    _Interpolated.__doc__ = (
        f"Channel-interpolating wrapper around :class:`{model_cls.__name__}`.\n\n"
        ":bdg-dark-line:`Channel`\n\n"
        f"Accepts arbitrary user ``chs_info`` and projects them to the\n"
        f"backbone's canonical channel set via\n"
        f":class:`~braindecode.modules.ChannelInterpolationLayer`.\n\n"
        f"For all other parameters and behavior see the backbone\n"
        f"documentation reproduced below.\n\n" + backbone_doc
    )
    return _Interpolated
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _build_chs_info_from_montage(names: list[str], montage: str) -> list[dict]:
    """Build a ``list[dict]`` ``chs_info`` from channel names + an MNE montage.

    Each returned dict has ``ch_name``, ``kind="eeg"``, and ``loc`` (shape
    ``(3,)``). Used by braindecode's shipped ``Interpolated*`` variants to
    turn a bare list of canonical channel names into the dict form
    ``ChannelInterpolationLayer`` expects.

    Parameters
    ----------
    names : list of str
        Channel names in the desired order.
    montage : str
        Name of an MNE standard montage (e.g. ``"standard_1005"``).

    Returns
    -------
    list of dict

    Raises
    ------
    ValueError
        If a name is not found in the montage.
    """
    # Imported lazily so mne is only required when this helper is used.
    import mne

    montage_obj = mne.channels.make_standard_montage(montage)
    positions = montage_obj.get_positions()["ch_pos"]
    # Validate every requested name before building anything, raising on
    # the first unknown channel exactly as a single-pass loop would.
    for n in names:
        if n not in positions:
            raise ValueError(f"Channel {n!r} not found in montage {montage!r}.")
    return [
        {"ch_name": n, "kind": "eeg", "loc": np.asarray(positions[n], dtype=float)}
        for n in names
    ]
|