braindecode 1.3.2.dev168310904__tar.gz → 1.3.2.dev168517820__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {braindecode-1.3.2.dev168310904/braindecode.egg-info → braindecode-1.3.2.dev168517820}/PKG-INFO +1 -1
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/transforms.py +6 -4
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/base.py +153 -20
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/bendr.py +17 -6
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/biot.py +23 -5
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/cbramod.py +19 -11
- braindecode-1.3.2.dev168517820/braindecode/models/deepsleepnet.py +477 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eeginception_mi.py +2 -1
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegitnet.py +4 -3
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegpt.py +15 -2
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/labram.py +15 -1
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/luna.py +5 -16
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/reve.py +36 -20
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/signal_jepa.py +63 -13
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/util.py +1 -1
- braindecode-1.3.2.dev168517820/braindecode/version.py +1 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820/braindecode.egg-info}/PKG-INFO +1 -1
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/whats_new.rst +44 -0
- braindecode-1.3.2.dev168310904/braindecode/models/deepsleepnet.py +0 -417
- braindecode-1.3.2.dev168310904/braindecode/version.py +0 -1
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/LICENSE.txt +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/MANIFEST.in +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/NOTICE.txt +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/README.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/base.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/functional.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/classifier.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/base.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bbci.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bcicomp.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/datasets.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/format.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub_format.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub_io.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/hub_validation.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/bids/iterable.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/chb_mit.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/mne.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/moabb.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/nmt.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/registry.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/siena.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/sleep_physio_challe_18.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/sleep_physionet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/tuh.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/utils.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datasets/xy.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/channel_utils.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/hub_formats.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/serialization.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/datautil/util.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/eegneuralnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/functional/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/functional/functions.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/functional/initialization.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/atcnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/attentionbasenet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/attn_sleep.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/brainmodule.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/config.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/contrawr.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/ctnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/deep4.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/dgcnn.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegconformer.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eeginception_erp.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegminer.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegnex.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegsimpleconv.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegsym.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/eegtcnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/fbcnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/fblightconvnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/fbmsnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/hybrid.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/ifnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/medformer.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/msvtnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/patchedtransformer.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sccnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/shallow_fbcsp.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sinc_shallow.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sleep_stager_blanco_2020.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sleep_stager_chambon_2018.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sparcnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/sstdpn.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/summary.csv +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/syncnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/tcn.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/tidnet.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/tsinception.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/usleep.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/activation.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/attention.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/blocks.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/convolution.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/filter.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/layers.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/linear.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/parametrization.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/stats.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/util.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/modules/wrapper.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/eegprep_preprocess.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/mne_preprocess.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/preprocess.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/util.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/preprocessing/windowers.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/regressor.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/samplers/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/samplers/base.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/samplers/ssl.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/callbacks.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/losses.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/training/scoring.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/util.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/visualization/__init__.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/visualization/confusion_matrices.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/visualization/gradients.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/SOURCES.txt +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/dependency_links.txt +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/requires.txt +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode.egg-info/top_level.txt +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/Makefile +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/class.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/class_in_subdir.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/function.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/_templates/autosummary/function_in_subdir.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/api.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/cite.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/conf.py +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/help.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/index.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/install/install.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/install/install_pip.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/install/install_source.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/attention.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/channel.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/convolution.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/filterbank.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/gnn.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/interpretable.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/lbm.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/recurrent.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/categorization/spd.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models_categorization.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models_table.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/models/models_visualization.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/docs/sg_execution_times.rst +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/pyproject.toml +0 -0
- {braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/setup.cfg +0 -0
{braindecode-1.3.2.dev168310904/braindecode.egg-info → braindecode-1.3.2.dev168517820}/PKG-INFO
RENAMED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: braindecode
|
|
3
|
-
Version: 1.3.2.dev168310904
|
|
3
|
+
Version: 1.3.2.dev168517820
|
|
4
4
|
Summary: Deep learning software to decode EEG, ECG or MEG signals
|
|
5
5
|
Author-email: Robin Tibor Schirrmeister <robintibor@gmail.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Alexandre Gramfort <agramfort@meta.com>
|
|
6
6
|
Maintainer-email: Alexandre Gramfort <agramfort@meta.com>, Bruno Aristimunha Pinto <b.aristimunha@gmail.com>, Robin Tibor Schirrmeister <robintibor@gmail.com>
|
|
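{braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/augmentation/transforms.py
RENAMED
|
@@ -2,6 +2,7 @@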
|
|
|
2
2
|
# Alexandre Gramfort <alexandre.gramfort@inria.fr>
|
|
3
3
|
# Gustavo Rodrigues <gustavenrique01@gmail.com>
|
|
4
4
|
# Bruna Lopes <brunajaflopes@gmail.com>
|
|
5
|
+
# Sarthak Tayal <sarthaktayal2@gmail.com>
|
|
5
6
|
#
|
|
6
7
|
# License: BSD (3-clause)
|
|
7
8
|
|
|
@@ -557,8 +558,9 @@ class BandstopFilter(Transform):
|
|
|
557
558
|
f" Nyquist frequency ({nyq} Hz)."
|
|
558
559
|
f" Falling back to max_freq = {nyq}."
|
|
559
560
|
)
|
|
560
|
-
assert bandwidth < max_freq, (
|
|
561
|
-
f"`bandwidth` needs to be smaller than max_freq={max_freq}"
|
|
561
|
+
assert bandwidth < max_freq - 2, (
|
|
562
|
+
f"`bandwidth` needs to be smaller than max_freq - 2={max_freq - 2} "
|
|
563
|
+
f"to allow valid notch frequency sampling with 1 Hz transition bands."
|
|
562
564
|
)
|
|
563
565
|
|
|
564
566
|
# override bandwidth value when a magnitude is passed
|
|
@@ -600,8 +602,8 @@ class BandstopFilter(Transform):
|
|
|
600
602
|
|
|
601
603
|
# Prevents transitions from going below 0 and above max_freq
|
|
602
604
|
notched_freqs = self.rng.uniform(
|
|
603
|
-
low=1 +
|
|
604
|
-
high=self.max_freq - 1 -
|
|
605
|
+
low=1 + self.bandwidth / 2,
|
|
606
|
+
high=self.max_freq - 1 - self.bandwidth / 2,
|
|
605
607
|
size=X.shape[0],
|
|
606
608
|
)
|
|
607
609
|
return {
|
{braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/base.py
RENAMED
|
@@ -59,7 +59,25 @@ def deprecated_args(obj, *old_new_args):
|
|
|
59
59
|
return out_args
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta):
|
|
62
|
+
class _BraindecodeDocstringMeta(NumpyDocstringInheritanceInitMeta):
|
|
63
|
+
"""Defer ``__init__`` wrapping until after docstring inheritance.
|
|
64
|
+
|
|
65
|
+
``NumpyDocstringInheritanceInitMeta`` uses ``inspect.unwrap()``
|
|
66
|
+
internally, which bypasses ``@wraps`` wrappers. By wrapping
|
|
67
|
+
``__init__`` *after* docstring processing, the metaclass sees the
|
|
68
|
+
unwrapped function and correctly inherits ``cls.__doc__``.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(cls, class_name, class_bases, class_dict):
|
|
72
|
+
super().__init__(class_name, class_bases, class_dict)
|
|
73
|
+
# Only wrap subclass __init__s, not EEGModuleMixin itself.
|
|
74
|
+
# Wrapping the mixin would cause super().__init__() calls to
|
|
75
|
+
# overwrite _braindecode_init_kwargs captured by the subclass.
|
|
76
|
+
if any(isinstance(b, _BraindecodeDocstringMeta) for b in class_bases):
|
|
77
|
+
track_model_init_kwargs(cls)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class EEGModuleMixin(_BaseHubMixin, metaclass=_BraindecodeDocstringMeta):
|
|
63
81
|
"""
|
|
64
82
|
Mixin class for all EEG models in braindecode.
|
|
65
83
|
|
|
@@ -132,20 +150,46 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
|
|
|
132
150
|
# Load pretrained model
|
|
133
151
|
model = {name}.from_pretrained("username/my-{name_lower}-model")
|
|
134
152
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
153
|
+
# Load with a different number of outputs (head is rebuilt automatically)
|
|
154
|
+
model = {name}.from_pretrained("username/my-{name_lower}-model", n_outputs=4)
|
|
155
|
+
|
|
156
|
+
**Extracting features and replacing the head:**
|
|
157
|
+
|
|
158
|
+
.. code-block::
|
|
159
|
+
|
|
160
|
+
import torch
|
|
161
|
+
|
|
162
|
+
x = torch.randn(1, model.n_chans, model.n_times)
|
|
163
|
+
# Extract encoder features (consistent dict across all models)
|
|
164
|
+
out = model(x, return_features=True)
|
|
165
|
+
features = out["features"]
|
|
166
|
+
|
|
167
|
+
# Replace the classification head
|
|
168
|
+
model.reset_head(n_outputs=10)
|
|
169
|
+
|
|
170
|
+
**Saving and restoring full configuration:**
|
|
171
|
+
|
|
172
|
+
.. code-block::
|
|
173
|
+
|
|
174
|
+
import json
|
|
175
|
+
|
|
176
|
+
config = model.get_config() # all __init__ params
|
|
177
|
+
with open("config.json", "w") as f:
|
|
178
|
+
json.dump(config, f)
|
|
179
|
+
|
|
180
|
+
model2 = {name}.from_config(config) # reconstruct (no weights)
|
|
139
181
|
|
|
140
182
|
All model parameters (both EEG-specific and model-specific such as
|
|
141
183
|
dropout rates, activation functions, number of filters) are automatically
|
|
142
184
|
saved to the Hub and restored when loading.
|
|
185
|
+
|
|
186
|
+
See :ref:`load-pretrained-models` for a complete tutorial.
|
|
143
187
|
"""
|
|
144
188
|
|
|
145
189
|
def __init_subclass__(cls, **kwargs):
|
|
146
190
|
# Append model-specific Hub integration notes to the docstring.
|
|
147
|
-
# This runs
|
|
148
|
-
#
|
|
191
|
+
# This runs before the metaclass __init__, so the Hub notes will
|
|
192
|
+
# be included in the docstring that the metaclass processes.
|
|
149
193
|
if cls.__doc__ is not None:
|
|
150
194
|
hub_notes = cls._HUB_NOTES_TEMPLATE.format(
|
|
151
195
|
name=cls.__name__,
|
|
@@ -155,7 +199,6 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
|
|
|
155
199
|
|
|
156
200
|
if not HAS_HF_HUB:
|
|
157
201
|
super().__init_subclass__(**kwargs)
|
|
158
|
-
track_model_init_kwargs(cls)
|
|
159
202
|
return
|
|
160
203
|
|
|
161
204
|
base_tags = ["braindecode", cls.__name__]
|
|
@@ -195,7 +238,6 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
|
|
|
195
238
|
coders=coders,
|
|
196
239
|
**kwargs,
|
|
197
240
|
)
|
|
198
|
-
track_model_init_kwargs(cls)
|
|
199
241
|
|
|
200
242
|
def __init__(
|
|
201
243
|
self,
|
|
@@ -426,6 +468,34 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
|
|
|
426
468
|
resolve_type_kwargs(cls, config)
|
|
427
469
|
return cls(**config)
|
|
428
470
|
|
|
471
|
+
def reset_head(self, n_outputs):
|
|
472
|
+
"""Replace the classification head for a new number of outputs.
|
|
473
|
+
|
|
474
|
+
This is called automatically by :meth:`from_pretrained` when the
|
|
475
|
+
user passes an ``n_outputs`` that differs from the saved config.
|
|
476
|
+
Override in subclasses that need a model-specific head structure.
|
|
477
|
+
|
|
478
|
+
Parameters
|
|
479
|
+
----------
|
|
480
|
+
n_outputs : int
|
|
481
|
+
New number of output classes.
|
|
482
|
+
|
|
483
|
+
Examples
|
|
484
|
+
--------
|
|
485
|
+
>>> from braindecode.models import BENDR
|
|
486
|
+
>>> model = BENDR(n_chans=22, n_times=1000, n_outputs=4)
|
|
487
|
+
>>> model.reset_head(10)
|
|
488
|
+
>>> model.n_outputs
|
|
489
|
+
10
|
|
490
|
+
|
|
491
|
+
.. versionadded:: 1.4
|
|
492
|
+
"""
|
|
493
|
+
raise NotImplementedError(
|
|
494
|
+
f"{type(self).__name__} does not implement reset_head(). "
|
|
495
|
+
"Override this method to support changing n_outputs after "
|
|
496
|
+
"loading pretrained weights."
|
|
497
|
+
)
|
|
498
|
+
|
|
429
499
|
mapping: Optional[Dict[str, str]] = None
|
|
430
500
|
|
|
431
501
|
def load_state_dict(self, state_dict, *args, **kwargs):
|
|
@@ -625,15 +695,78 @@ class EEGModuleMixin(_BaseHubMixin, metaclass=NumpyDocstringInheritanceInitMeta)
|
|
|
625
695
|
**model_kwargs,
|
|
626
696
|
):
|
|
627
697
|
model_kwargs.pop("braindecode_version", None)
|
|
698
|
+
filename = model_kwargs.pop("filename", None)
|
|
628
699
|
resolve_type_kwargs(cls, model_kwargs)
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
700
|
+
|
|
701
|
+
# Read saved n_outputs from config.json to detect when the
|
|
702
|
+
# user wants a different number of outputs. Works for both
|
|
703
|
+
# local directories and Hub repo IDs.
|
|
704
|
+
saved_n_outputs = None
|
|
705
|
+
try:
|
|
706
|
+
if Path(model_id).is_dir():
|
|
707
|
+
config_file = Path(model_id) / "config.json"
|
|
708
|
+
else:
|
|
709
|
+
config_file = huggingface_hub.hf_hub_download(
|
|
710
|
+
repo_id=model_id,
|
|
711
|
+
filename="config.json",
|
|
712
|
+
revision=revision,
|
|
713
|
+
cache_dir=cache_dir,
|
|
714
|
+
force_download=force_download,
|
|
715
|
+
token=token,
|
|
716
|
+
local_files_only=local_files_only,
|
|
717
|
+
)
|
|
718
|
+
with open(config_file, "r") as f:
|
|
719
|
+
saved_n_outputs = json.load(f).get("n_outputs")
|
|
720
|
+
except (OSError, json.JSONDecodeError, KeyError):
|
|
721
|
+
pass # config unavailable; skip reset_head logic
|
|
722
|
+
|
|
723
|
+
requested_n_outputs = model_kwargs.get("n_outputs")
|
|
724
|
+
|
|
725
|
+
# If the user requests different n_outputs, load with the
|
|
726
|
+
# saved value first (so weights match), then swap the head.
|
|
727
|
+
if (
|
|
728
|
+
saved_n_outputs is not None
|
|
729
|
+
and requested_n_outputs is not None
|
|
730
|
+
and requested_n_outputs != saved_n_outputs
|
|
731
|
+
):
|
|
732
|
+
model_kwargs["n_outputs"] = saved_n_outputs
|
|
733
|
+
|
|
734
|
+
# If a custom filename is provided, temporarily override the
|
|
735
|
+
# HuggingFace constant so the parent class downloads the
|
|
736
|
+
# correct file (e.g. "LUNA_base.safetensors" instead of
|
|
737
|
+
# "model.safetensors").
|
|
738
|
+
hf_constants = huggingface_hub.constants
|
|
739
|
+
_orig_safetensors = hf_constants.SAFETENSORS_SINGLE_FILE
|
|
740
|
+
if filename is not None:
|
|
741
|
+
hf_constants.SAFETENSORS_SINGLE_FILE = filename
|
|
742
|
+
try:
|
|
743
|
+
model = super()._from_pretrained( # type: ignore
|
|
744
|
+
model_id=model_id,
|
|
745
|
+
revision=revision,
|
|
746
|
+
cache_dir=cache_dir,
|
|
747
|
+
force_download=force_download,
|
|
748
|
+
local_files_only=local_files_only,
|
|
749
|
+
token=token,
|
|
750
|
+
map_location=map_location,
|
|
751
|
+
strict=strict,
|
|
752
|
+
**model_kwargs,
|
|
753
|
+
)
|
|
754
|
+
finally:
|
|
755
|
+
hf_constants.SAFETENSORS_SINGLE_FILE = _orig_safetensors
|
|
756
|
+
|
|
757
|
+
if (
|
|
758
|
+
saved_n_outputs is not None
|
|
759
|
+
and requested_n_outputs is not None
|
|
760
|
+
and requested_n_outputs != saved_n_outputs
|
|
761
|
+
):
|
|
762
|
+
try:
|
|
763
|
+
model.reset_head(requested_n_outputs)
|
|
764
|
+
except NotImplementedError:
|
|
765
|
+
raise ValueError(
|
|
766
|
+
f"{type(model).__name__} does not support changing "
|
|
767
|
+
f"n_outputs after loading. Saved model has "
|
|
768
|
+
f"n_outputs={saved_n_outputs}, but "
|
|
769
|
+
f"n_outputs={requested_n_outputs} was requested."
|
|
770
|
+
) from None
|
|
771
|
+
|
|
772
|
+
return model
|
{braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/bendr.py
RENAMED
|
@@ -291,15 +291,23 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
291
291
|
else:
|
|
292
292
|
in_features = encoder_h
|
|
293
293
|
|
|
294
|
+
self._head_in_features = in_features
|
|
294
295
|
self.final_layer = None
|
|
295
296
|
if self.include_final_layer:
|
|
296
|
-
|
|
297
|
-
linear = nn.Linear(in_features=in_features, out_features=self.n_outputs)
|
|
298
|
-
self.final_layer = nn.utils.parametrizations.weight_norm(
|
|
299
|
-
linear, name="weight", dim=1
|
|
300
|
-
)
|
|
297
|
+
self._build_head(self.n_outputs)
|
|
301
298
|
|
|
302
|
-
def
|
|
299
|
+
def _build_head(self, n_outputs):
|
|
300
|
+
linear = nn.Linear(in_features=self._head_in_features, out_features=n_outputs)
|
|
301
|
+
self.final_layer = nn.utils.parametrizations.weight_norm(
|
|
302
|
+
linear, name="weight", dim=1
|
|
303
|
+
)
|
|
304
|
+
|
|
305
|
+
def reset_head(self, n_outputs):
|
|
306
|
+
self._n_outputs = n_outputs
|
|
307
|
+
self.include_final_layer = True
|
|
308
|
+
self._build_head(n_outputs)
|
|
309
|
+
|
|
310
|
+
def forward(self, x, return_features=False):
|
|
303
311
|
if self.channel_projection is not None:
|
|
304
312
|
x = self.channel_projection(x)
|
|
305
313
|
encoded = self.encoder(x)
|
|
@@ -328,6 +336,9 @@ class BENDR(EEGModuleMixin, nn.Module):
|
|
|
328
336
|
feature = context[:, :, 0]
|
|
329
337
|
# feature: [batch_size, encoder_h]
|
|
330
338
|
|
|
339
|
+
if return_features:
|
|
340
|
+
return {"features": feature, "cls_token": None}
|
|
341
|
+
|
|
331
342
|
if self.final_layer is not None:
|
|
332
343
|
feature = self.final_layer(feature)
|
|
333
344
|
# feature: [batch_size, n_outputs]
|
{braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/biot.py
RENAMED
|
@@ -183,13 +183,22 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
183
183
|
attn_layer_dropout=att_layer_drop_prob,
|
|
184
184
|
)
|
|
185
185
|
|
|
186
|
+
self._head_activation = activation
|
|
186
187
|
self.final_layer = _ClassificationHead(
|
|
187
188
|
emb_size=self.embed_dim,
|
|
188
189
|
n_outputs=self.n_outputs,
|
|
189
190
|
activation=activation,
|
|
190
191
|
)
|
|
191
192
|
|
|
192
|
-
def
|
|
193
|
+
def reset_head(self, n_outputs):
|
|
194
|
+
self._n_outputs = n_outputs
|
|
195
|
+
self.final_layer = _ClassificationHead(
|
|
196
|
+
emb_size=self.embed_dim,
|
|
197
|
+
n_outputs=n_outputs,
|
|
198
|
+
activation=self._head_activation,
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
def forward(self, x, return_features=False):
|
|
193
202
|
"""
|
|
194
203
|
Pass the input through the BIOT encoder, and then through the
|
|
195
204
|
classification head.
|
|
@@ -198,15 +207,24 @@ class BIOT(EEGModuleMixin, nn.Module):
|
|
|
198
207
|
----------
|
|
199
208
|
x: Tensor
|
|
200
209
|
(batch_size, n_channels, n_times)
|
|
210
|
+
return_features : bool
|
|
211
|
+
If True, return a dict with ``"features"`` and ``"cls_token"``
|
|
212
|
+
instead of the classification output.
|
|
201
213
|
|
|
202
214
|
Returns
|
|
203
215
|
-------
|
|
204
|
-
|
|
205
|
-
(batch_size, n_outputs)
|
|
206
|
-
|
|
207
|
-
(batch_size,
|
|
216
|
+
torch.Tensor or tuple or dict
|
|
217
|
+
Default: ``torch.Tensor`` of shape ``(batch_size, n_outputs)``.
|
|
218
|
+
If ``return_features=True``: ``dict`` with ``"features"``
|
|
219
|
+
``(batch_size, emb_size)`` and ``"cls_token"`` (``None``).
|
|
220
|
+
If legacy ``return_feature=True`` (init param):
|
|
221
|
+
``(out, emb)`` tuple (ignored when ``return_features=True``).
|
|
208
222
|
"""
|
|
209
223
|
emb = self.encoder(x)
|
|
224
|
+
|
|
225
|
+
if return_features:
|
|
226
|
+
return {"features": emb, "cls_token": None}
|
|
227
|
+
|
|
210
228
|
x = self.final_layer(emb)
|
|
211
229
|
|
|
212
230
|
if self.return_feature:
|
{braindecode-1.3.2.dev168310904 → braindecode-1.3.2.dev168517820}/braindecode/models/cbramod.py
RENAMED
|
@@ -205,20 +205,15 @@ class CBraMod(EEGModuleMixin, nn.Module):
|
|
|
205
205
|
self.encoder = _TransformerEncoder(encoder_layer, num_layers=n_layer)
|
|
206
206
|
self.proj_out = nn.Sequential(nn.Linear(d_model, emb_dim))
|
|
207
207
|
|
|
208
|
+
self._emb_dim = emb_dim
|
|
209
|
+
self._patch_size = patch_size
|
|
208
210
|
self._weights_init()
|
|
209
211
|
|
|
210
|
-
try:
|
|
211
|
-
n_times = self.n_times
|
|
212
|
-
n_chans = self.n_chans
|
|
213
|
-
except ValueError:
|
|
214
|
-
n_times = None
|
|
215
|
-
n_chans = None
|
|
216
|
-
|
|
217
212
|
if return_encoder_output:
|
|
218
213
|
self.final_layer = nn.Identity()
|
|
219
|
-
elif
|
|
220
|
-
n_patch =
|
|
221
|
-
flat_dim =
|
|
214
|
+
elif self._n_times is not None and self._n_chans is not None:
|
|
215
|
+
n_patch = self._n_times // patch_size
|
|
216
|
+
flat_dim = self._n_chans * n_patch * emb_dim
|
|
222
217
|
self.final_layer = nn.Sequential(
|
|
223
218
|
nn.Flatten(), nn.Linear(flat_dim, self.n_outputs)
|
|
224
219
|
)
|
|
@@ -227,6 +222,17 @@ class CBraMod(EEGModuleMixin, nn.Module):
|
|
|
227
222
|
nn.Flatten(), nn.LazyLinear(self.n_outputs)
|
|
228
223
|
)
|
|
229
224
|
|
|
225
|
+
def reset_head(self, n_outputs):
|
|
226
|
+
self._n_outputs = n_outputs
|
|
227
|
+
if self._n_times is not None and self._n_chans is not None:
|
|
228
|
+
n_patch = self._n_times // self._patch_size
|
|
229
|
+
flat_dim = self._n_chans * n_patch * self._emb_dim
|
|
230
|
+
self.final_layer = nn.Sequential(
|
|
231
|
+
nn.Flatten(), nn.Linear(flat_dim, n_outputs)
|
|
232
|
+
)
|
|
233
|
+
else:
|
|
234
|
+
self.final_layer = nn.Sequential(nn.Flatten(), nn.LazyLinear(n_outputs))
|
|
235
|
+
|
|
230
236
|
def _weights_init(self):
|
|
231
237
|
for m in self.modules():
|
|
232
238
|
if isinstance(m, nn.Linear):
|
|
@@ -237,11 +243,13 @@ class CBraMod(EEGModuleMixin, nn.Module):
|
|
|
237
243
|
nn.init.constant_(m.weight, 1)
|
|
238
244
|
nn.init.constant_(m.bias, 0)
|
|
239
245
|
|
|
240
|
-
def forward(self, x, mask=None):
|
|
246
|
+
def forward(self, x, mask=None, return_features=False):
|
|
241
247
|
x = self.rearrange(x)
|
|
242
248
|
patch_emb = self.patch_embedding(x, mask)
|
|
243
249
|
feats = self.encoder(patch_emb)
|
|
244
250
|
out = self.proj_out(feats)
|
|
251
|
+
if return_features:
|
|
252
|
+
return {"features": out, "cls_token": None}
|
|
245
253
|
return self.final_layer(out)
|
|
246
254
|
|
|
247
255
|
|