braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
Model,Application,Type,Sampling Frequency (Hz),Hyperparameters,#Parameters,get_#Parameters,Categorization
|
|
2
|
+
ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention/Transformer"
|
|
3
|
+
AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
|
|
4
|
+
BDTCN,Normal Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)","Convolution,Recurrent"
|
|
5
|
+
BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Foundation Model"
|
|
6
|
+
ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)",Convolution
|
|
7
|
+
CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, num_heads=2, embed_dim=16)","Convolution,Attention/Transformer"
|
|
8
|
+
Deep4Net,General,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
|
|
9
|
+
DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)","Convolution,Recurrent"
|
|
10
|
+
EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
|
|
11
|
+
EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)","Convolution"
|
|
12
|
+
EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)","Convolution"
|
|
13
|
+
EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)","Convolution,Recurrent"
|
|
14
|
+
EEGNet,General,Classification,128,"n_chans, n_outputs, n_times",2484,"EEGNet(n_chans=22, n_outputs=4, n_times=512)","Convolution"
|
|
15
|
+
EEGNeX,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",55940,"EEGNeX(n_chans=22, n_outputs=4, n_times=500)","Convolution"
|
|
16
|
+
EEGSym,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",299218,"EEGSym(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Channel"
|
|
17
|
+
EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)","Convolution,Interpretability"
|
|
18
|
+
EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)","Convolution"
|
|
19
|
+
EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)","Convolution,Recurrent"
|
|
20
|
+
Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Foundation Model"
|
|
21
|
+
MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494,"MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention/Transformer"
|
|
22
|
+
SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)","Convolution"
|
|
23
|
+
SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Foundation Model"
|
|
24
|
+
SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Foundation Model"
|
|
25
|
+
SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Foundation Model"
|
|
26
|
+
SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Foundation Model"
|
|
27
|
+
SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Interpretability"
|
|
28
|
+
ShallowFBCSPNet,General,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution"
|
|
29
|
+
SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
|
|
30
|
+
SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)","Convolution"
|
|
31
|
+
AttnSleep,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"AttnSleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution,Attention/Transformer"
|
|
32
|
+
SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)","Convolution"
|
|
33
|
+
SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Interpretability"
|
|
34
|
+
TSception,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSception(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Convolution"
|
|
35
|
+
TIDNet,General,Classification,250,"n_chans, n_outputs, n_times",240404,"TIDNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
|
|
36
|
+
USleep,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",2482011,"USleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
|
|
37
|
+
FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",11812,"FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
38
|
+
FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
39
|
+
FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
40
|
+
IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
|
|
41
|
+
BrainModule,Speech Decoding,Classification,250,"n_chans, n_outputs, n_times, sfreq",6186909,"BrainModule(n_chans=64, n_outputs=29, n_times=160, sfreq=1000)","Convolution"
|
|
42
|
+
PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Foundation Model"
|
|
43
|
+
SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
|
|
44
|
+
BENDR,General,"Classification,Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
|
|
45
|
+
LUNA,General,"Classification,Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Foundation Model"
|
|
46
|
+
MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
|
|
47
|
+
REVE,General,Classification,200,"n_outputs, n_times, n_chans",69481476,"REVE(n_times=1000, n_outputs=4, n_chans=19)","Foundation Model,Attention/Transformer"
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from einops.layers.torch import Rearrange
|
|
5
|
+
from numpy import arange, ceil
|
|
6
|
+
|
|
7
|
+
from braindecode.models.base import EEGModuleMixin
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SyncNet(EEGModuleMixin, nn.Module):
    r"""Synchronization Network (SyncNet) from Li, Y et al (2017) [Li2017]_.

    :bdg-warning:`Interpretability`

    .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
       :align: center
       :alt: SyncNet Architecture

    SyncNet uses parameterized 1-dimensional convolutional filters inspired by
    the Morlet wavelet to extract features from EEG signals. The filters are
    dynamically generated based on learnable parameters that control the
    oscillation and decay characteristics.

    The filter for channel ``c`` and filter ``k`` is defined as:

    .. math::

        f_c^{(k)}(\tau) = amplitude_c^{(k)} \cos(\omega^{(k)} \tau + \phi_c^{(k)}) \exp(-\beta^{(k)} \tau^2)

    where:

    - :math:`amplitude_c^{(k)}` is the amplitude parameter (channel-specific).
    - :math:`\omega^{(k)}` is the frequency parameter (shared across channels).
    - :math:`\phi_c^{(k)}` is the phase shift (channel-specific).
    - :math:`\beta^{(k)}` is the decay parameter (shared across channels).
    - :math:`\tau` is the time index.

    Parameters
    ----------
    num_filters : int, optional
        Number of filters in the convolutional layer. Default is 1.
    filter_width : int, optional
        Width of the convolutional filters. Default is 40.
    pool_size : int, optional
        Size of the pooling window. Default is 40.
    activation : nn.Module, optional
        Activation function to apply after pooling. Default is ``nn.ReLU``.
    ampli_init_values : tuple of float, optional
        The initialization range for amplitude parameter using uniform
        distribution. Default is (-0.05, 0.05).
    omega_init_values : tuple of float, optional
        The initialization range for omega parameters using uniform
        distribution. Default is (0, 1).
    beta_init_values : tuple of float, optional
        The initialization range for beta parameters using uniform
        distribution. Default is (0, 0.05).
    phase_init_values : tuple of float, optional
        The ``(mean, std)`` of the `normal` distribution used to initialize
        the phase parameters. Default is (0, 0.05).

    Notes
    -----
    This implementation is not guaranteed to be correct! It has not been
    checked by the original authors. The modifications are based on code
    derived from [CodeICASSP2025]_.

    References
    ----------
    .. [Li2017] Li, Y., Dzirasa, K., Carin, L., & Carlson, D. E. (2017).
       Targeting EEG/LFP synchrony with neural nets. Advances in neural
       information processing systems, 30.
    .. [CodeICASSP2025] Code from Baselines for EEG-Music Emotion Recognition
       Grand Challenge at ICASSP 2025.
       https://github.com/SalvoCalcagno/eeg-music-challenge-icassp-2025-baselines

    """

    def __init__(
        self,
        # braindecode convention
        n_chans=None,
        n_times=None,
        n_outputs=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # model parameters
        num_filters=1,
        filter_width=40,
        pool_size=40,
        activation: type[nn.Module] = nn.ReLU,
        ampli_init_values: tuple[float, float] = (-0.05, 0.05),
        omega_init_values: tuple[float, float] = (0.0, 1.0),
        beta_init_values: tuple[float, float] = (0.0, 0.05),
        phase_init_values: tuple[float, float] = (0.0, 0.05),
    ):
        super().__init__(
            n_chans=n_chans,
            n_times=n_times,
            n_outputs=n_outputs,
            chs_info=chs_info,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.num_filters = num_filters
        self.filter_width = filter_width
        self.pool_size = pool_size
        self.activation = activation()
        self.ampli_init_values = ampli_init_values
        self.omega_init_values = omega_init_values
        self.beta_init_values = beta_init_values
        self.phase_init_values = phase_init_values

        # Parameters are created in the same order as before so that the RNG
        # stream and the state-dict layout are unchanged.

        # Channel-specific amplitude, uniform init: (1, 1, n_chans, num_filters)
        self.amplitude = nn.Parameter(
            torch.FloatTensor(1, 1, self.n_chans, self.num_filters).uniform_(
                self.ampli_init_values[0], self.ampli_init_values[1]
            )
        )
        # Frequency shared across channels, uniform init: (1, 1, 1, num_filters)
        self.omega = nn.Parameter(
            torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
                self.omega_init_values[0], self.omega_init_values[1]
            )
        )

        self.bias = nn.Parameter(torch.zeros(self.num_filters))

        # Integer time grid tau of length filter_width, centred on 0
        # (-20..19 for width 40, -20..20 for width 41): (1, filter_width, 1, 1)
        half_width = self.filter_width // 2
        t_np = arange(-half_width, self.filter_width - half_width).reshape(
            1, self.filter_width, 1, 1
        )
        # NOTE(review): tau is registered as a learnable nn.Parameter (and is
        # counted in the model's parameter total), so the time grid itself can
        # be changed by the optimizer. A fixed buffer may be what is intended;
        # left as-is to keep the parameter count and state dict unchanged.
        self.t = nn.Parameter(torch.FloatTensor(t_np))

        # Channel-specific phase shift, normal init with
        # (mean, std) = phase_init_values: (1, 1, n_chans, num_filters)
        self.phi_ini = nn.Parameter(
            torch.FloatTensor(1, 1, self.n_chans, self.num_filters).normal_(
                self.phase_init_values[0], self.phase_init_values[1]
            )
        )
        # Decay shared across channels, uniform init in beta_init_values:
        # (1, 1, 1, num_filters)
        self.beta = nn.Parameter(
            torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
                self.beta_init_values[0], self.beta_init_values[1]
            )
        )

        # "Same" zero padding (total filter_width - 1) so the filter convolution
        # keeps n_times samples; the same padding is applied again before pooling.
        self.padding = self._compute_padding(filter_width=self.filter_width)
        self.pad_input = nn.ConstantPad1d(self.padding, 0.0)
        self.pad_res = nn.ConstantPad1d(self.padding, 0.0)

        # Max pooling over time with kernel = stride = pool_size.
        self.pool = nn.MaxPool2d((1, self.pool_size), stride=(1, self.pool_size))

        # Pooling sees n_times + filter_width - 1 samples, giving
        # floor((n_times + filter_width - 1) / pool_size) steps, i.e.
        # ceil((n_times + filter_width - pool_size) / pool_size). This reduces to
        # ceil(n_times / pool_size) when filter_width == pool_size (the default).
        n_pooled = int(
            ceil(
                float(self.n_times + self.filter_width - self.pool_size)
                / float(self.pool_size)
            )
        )
        if n_pooled < 1:
            raise ValueError(
                f"n_times={self.n_times} is too short for pool_size="
                f"{self.pool_size} with filter_width={self.filter_width}: "
                "no pooled time step would remain."
            )
        self.classifier_input_size = n_pooled * self.num_filters

        self.ensuredim = Rearrange("batch ch time -> batch ch 1 time")

        self.final_layer = nn.Linear(self.classifier_input_size, self.n_outputs)

    def forward(self, x):
        """Forward pass of the SyncNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times)

        Returns
        -------
        out : torch.Tensor
            Output tensor of shape (batch_size, n_outputs).

        """
        # Keep the decay non-negative. This is done before the filters are built
        # so that the clamped value is the one used (and saved for backward) in
        # this pass, instead of being changed after the graph was recorded.
        self.beta.data.clamp_(min=0)

        # (batch_size, n_chans, n_times) -> (batch_size, n_chans, 1, n_times)
        x = self.ensuredim(x)

        # Oscillatory component: (1, filter_width, n_chans, num_filters)
        W_osc = self.amplitude * torch.cos(self.t * self.omega + self.phi_ini)

        # Gaussian decay: (1, filter_width, 1, num_filters)
        W_decay = torch.exp(-torch.pow(self.t, 2) * self.beta)

        # Filter bank: (1, filter_width, n_chans, num_filters)
        W = W_osc * W_decay

        # conv2d weight layout is (out=num_filters, in=n_chans, 1, filter_width).
        # permute moves the axes; a plain view would reinterpret the memory and
        # mix the time and channel axes.
        W = W.permute(3, 2, 0, 1).contiguous()

        # Same-length convolution: (batch_size, num_filters, 1, n_times)
        x_padded = self.pad_input(x.float())
        res = F.conv2d(x_padded, W.float(), bias=self.bias, stride=1)

        # Pad again, then max-pool over time.
        res_padded = self.pad_res(res)
        res_pooled = self.pool(res_padded)

        # (batch_size, num_filters * n_pooled). Flattening from dim 1 keeps the
        # batch dimension intact, so a size mismatch fails loudly in the Linear.
        res_flat = res_pooled.flatten(start_dim=1)

        out = self.activation(res_flat)
        out = self.final_layer(out)

        return out

    @staticmethod
    def _compute_padding(filter_width):
        """Return ``(left, right)`` zero padding for a same-length convolution.

        The total padding is ``filter_width - 1``. When that total is odd, the
        extra sample goes to the right (e.g. ``(19, 20)`` for width 40,
        ``(20, 20)`` for width 41).
        """
        total = filter_width - 1
        left = total // 2
        return (left, total - left)
|
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
# Authors: Patryk Chrabaszcz
|
|
2
|
+
# Lukas Gemein <l.gemein@gmail.com>
|
|
3
|
+
#
|
|
4
|
+
# License: BSD-3
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
from torch.nn import init
|
|
8
|
+
from torch.nn.utils.parametrizations import weight_norm
|
|
9
|
+
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BDTCN(EEGModuleMixin, nn.Module):
    r"""Braindecode TCN from Gemein, L et al (2020) [gemein2020]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent`

    .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
       :align: center
       :alt: Braindecode TCN Architecture

    Wraps a :class:`TCN` (stacked dilated temporal blocks). Its output is
    average-pooled over the last axis and flattened. See [gemein2020]_ for
    details.

    Parameters
    ----------
    n_filters : int
        Number of output filters of each convolution.
    n_blocks : int
        Number of temporal blocks in the network.
    kernel_size : int
        Kernel size of the convolutions.
    drop_prob : float
        Dropout probability.
    activation : type[nn.Module], default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [gemein2020] Gemein, L. A., Schirrmeister, R. T., Chrabąszcz, P., Wilson, D.,
       Boedecker, J., Schulze-Bonhage, A., ... & Ball, T. (2020). Machine-learning-based
       diagnostics of EEG pathology. NeuroImage, 220, 117021.
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        n_times=None,
        sfreq=None,
        input_window_seconds=None,
        # Model's parameters
        n_blocks=3,
        n_filters=30,
        kernel_size=5,
        drop_prob=0.5,
        activation: type[nn.Module] = nn.ReLU,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        # From here on only the ``self.*`` attributes provided by
        # ``EEGModuleMixin`` are used, never the raw arguments.
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        block_config = dict(
            n_blocks=n_blocks,
            n_filters=n_filters,
            kernel_size=kernel_size,
            drop_prob=drop_prob,
            activation=activation,
        )
        self.base_tcn = TCN(
            n_chans=self.n_chans, n_outputs=self.n_outputs, **block_config
        )

        # Average the TCN output over its last axis, then drop that axis.
        head_layers = [nn.AdaptiveAvgPool1d(1), nn.Flatten()]
        self.final_layer = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the TCN and pool its output.

        Parameters
        ----------
        x : torch.Tensor
            Batch of EEG windows, passed unchanged to :class:`TCN`.

        Returns
        -------
        torch.Tensor
            The pooled and flattened TCN output.
        """
        return self.final_layer(self.base_tcn(x))
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class TCN(nn.Module):
    r"""Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.

    See [Bai2018]_ for details.

    Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py

    Parameters
    ----------
    n_chans : int
        Number of input channels; the input width of the first temporal block.
    n_outputs : int
        Number of outputs of the final per-time-step linear layer.
    n_blocks : int
        Number of temporal blocks in the network. Block ``i`` uses a dilation
        of ``2**i``.
    n_filters : int
        Number of output filters of each convolution.
    kernel_size : int
        Kernel size of the convolutions.
    drop_prob : float
        Dropout probability.
    activation : type[nn.Module], default=nn.ReLU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.

    References
    ----------
    .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
       An empirical evaluation of generic convolutional and recurrent networks
       for sequence modeling.
       arXiv preprint arXiv:1803.01271.
    """

    def __init__(
        self,
        n_chans=None,
        n_outputs=None,
        n_blocks=3,
        n_filters=30,
        kernel_size=5,
        drop_prob=0.5,
        activation: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        # Old -> new state-dict key names for the final linear layer
        # (presumably used when loading legacy checkpoints; confirm with caller).
        self.mapping = {
            "fc.weight": "final_layer.fc.weight",
            "fc.bias": "final_layer.fc.bias",
        }
        self.ensuredims = Ensure4d()

        t_blocks = nn.Sequential()
        for i in range(n_blocks):
            dilation_size = 2**i
            t_blocks.add_module(
                "temporal_block_{:d}".format(i),
                _TemporalBlock(
                    n_inputs=n_chans if i == 0 else n_filters,
                    n_outputs=n_filters,
                    kernel_size=kernel_size,
                    stride=1,
                    dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size,
                    drop_prob=drop_prob,
                    activation=activation,
                ),
            )
        self.temporal_blocks = t_blocks

        self.final_layer = _FinalLayer(
            in_features=n_filters,
            out_features=n_outputs,
        )

        # Receptive field of the stack. Each block holds two convolutions with
        # dilation 2**i, and each convolution widens the field by
        # (kernel_size - 1) * 2**i samples.
        self.min_len = 1 + sum(
            2 * (kernel_size - 1) * 2**i for i in range(n_blocks)
        )

        # Modules are already in training mode after construction; the call
        # only makes the initial mode explicit.
        self.train()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Batch of EEG windows of shape (batch_size, n_channels, n_times).

        Returns
        -------
        torch.Tensor
            Final-layer outputs for the last ``1 + max(0, n_times - min_len)``
            time steps, as returned (and squeezed) by ``_FinalLayer``.

        Raises
        ------
        ValueError
            If the input has fewer than ``self.min_len`` time samples.
        """
        x = self.ensuredims(x)
        # x is in format: B x C x T x 1
        batch_size, _, time_size, _ = x.size()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``. Without the check, a too-short input would silently
        # yield a single output step.
        if time_size < self.min_len:
            raise ValueError(
                "Input has {} time samples, but TCN requires at least {} "
                "(its receptive field).".format(time_size, self.min_len)
            )
        # remove empty trailing dimension
        x = x.squeeze(3)
        x = self.temporal_blocks(x)
        # Convert to: B x T x C
        x = x.transpose(1, 2).contiguous()

        return self.final_layer(x, batch_size, time_size, self.min_len)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class _FinalLayer(nn.Module):
    """Per-time-step linear read-out that keeps only the trailing valid steps.

    Parameters
    ----------
    in_features : int
        Size of the feature axis of the input (the last axis).
    out_features : int
        Number of outputs produced for each kept time step.
    """

    def __init__(self, in_features, out_features):
        super().__init__()

        self.fc = nn.Linear(in_features=in_features, out_features=out_features)

        self.out_fun = nn.Identity()

        self.squeeze = SqueezeFinalOutput()

    def forward(
        self, x: torch.Tensor, batch_size: int, time_size: int, min_len: int
    ) -> torch.Tensor:
        """Apply ``fc`` at every time step and keep the last valid ones.

        ``x`` is viewed as ``(batch_size * time_size, n_features)``, so it must
        be contiguous with shape ``(batch_size, time_size, n_features)``.
        Only the final ``1 + max(0, time_size - min_len)`` steps are kept; they
        are returned as ``(batch, out_features, n_kept, 1)`` after passing
        through ``SqueezeFinalOutput``.
        """
        n_features = x.size(2)
        per_step = self.fc(x.view(batch_size * time_size, n_features))
        per_step = self.out_fun(per_step)
        n_classes = per_step.size(1)
        per_step = per_step.view(batch_size, time_size, n_classes)

        n_kept = 1 + max(0, time_size - min_len)
        # (batch, n_kept, out) -> (batch, out, n_kept)
        kept = per_step[:, -n_kept:, :].transpose(1, 2)
        # re-add 4th dimension for compatibility with braindecode
        return self.squeeze(kept.unsqueeze(-1))
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
class _TemporalBlock(nn.Module):
    """Residual block of two dilated, weight-normalised 1D convolutions.

    Each convolution is followed by ``Chomp1d(padding)``, the activation and
    channel-wise dropout. The block input is then added back, through a 1x1
    convolution when ``n_inputs != n_outputs``, before a final activation.

    Parameters
    ----------
    n_inputs : int
        Number of input channels.
    n_outputs : int
        Number of output channels of both convolutions.
    kernel_size : int
        Kernel size of both convolutions.
    stride : int
        Stride of both convolutions.
    dilation : int
        Dilation of both convolutions.
    padding : int
        Zero padding applied by each convolution on both sides. The same
        amount is passed to ``Chomp1d``, presumably to trim the trailing
        ``padding`` samples so the output length matches the input (confirm
        in ``braindecode.modules``).
    drop_prob : float
        Dropout probability.
    activation : type[nn.Module], default=nn.ReLU
        Activation class; instantiated separately for each of the three
        activation points.
    """

    def __init__(
        self,
        n_inputs,
        n_outputs,
        kernel_size,
        stride,
        dilation,
        padding,
        drop_prob,
        activation: type[nn.Module] = nn.ReLU,
    ):
        super().__init__()
        self.conv1 = weight_norm(
            nn.Conv1d(
                n_inputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        )
        self.chomp1 = Chomp1d(padding)
        self.relu1 = activation()
        # Inputs here are (batch, channels, time). Dropout1d zeroes whole
        # channels, which is the same operation Dropout2d performed on these 3D
        # inputs, but without PyTorch's deprecation warning for 3D Dropout2d.
        self.dropout1 = nn.Dropout1d(drop_prob)

        self.conv2 = weight_norm(
            nn.Conv1d(
                n_outputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        )
        self.chomp2 = Chomp1d(padding)
        self.relu2 = activation()
        self.dropout2 = nn.Dropout1d(drop_prob)

        # 1x1 projection so the residual matches the output channel count.
        self.downsample = (
            nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
        )
        self.relu = activation()

        # NOTE(review): with torch.nn.utils.parametrizations.weight_norm,
        # ``conv.weight`` is recomputed from the parametrization's original
        # tensors on access. An in-place init on it therefore likely leaves the
        # trainable tensors unchanged for conv1/conv2. The upstream locuslab
        # code had the same pattern. Kept as-is to preserve the existing
        # initialization; confirm before changing.
        init.normal_(self.conv1.weight, 0, 0.01)
        init.normal_(self.conv2.weight, 0, 0.01)
        if self.downsample is not None:
            init.normal_(self.downsample.weight, 0, 0.01)

    def forward(self, x):
        """Apply both conv stages and add the (possibly projected) residual."""
        out = self.conv1(x)
        out = self.chomp1(out)
        out = self.relu1(out)
        out = self.dropout1(out)
        out = self.conv2(out)
        out = self.chomp2(out)
        out = self.relu2(out)
        out = self.dropout2(out)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)
|