braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/summary.csv
@@ -0,0 +1,47 @@
+ Model,Application,Type,Sampling Frequency (Hz),Hyperparameters,#Parameters,get_#Parameters,Categorization
+ ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention/Transformer"
+ AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
+ BDTCN,Normal Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)","Convolution,Recurrent"
+ BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)","Foundation Model"
+ ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)",Convolution
+ CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, num_heads=2, embed_dim=16)","Convolution,Attention/Transformer"
+ Deep4Net,General,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
+ DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)","Convolution,Recurrent"
+ EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
+ EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)","Convolution"
+ EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)","Convolution"
+ EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)","Convolution,Recurrent"
+ EEGNet,General,Classification,128,"n_chans, n_outputs, n_times",2484,"EEGNet(n_chans=22, n_outputs=4, n_times=512)","Convolution"
+ EEGNeX,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",55940,"EEGNeX(n_chans=22, n_outputs=4, n_times=500)","Convolution"
+ EEGSym,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",299218,"EEGSym(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Channel"
+ EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)","Convolution,Interpretability"
+ EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)","Convolution"
+ EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)","Convolution,Recurrent"
+ Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Foundation Model"
+ MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494,"MSVTNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Recurrent,Attention/Transformer"
+ SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)","Convolution"
+ SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Foundation Model"
+ SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])","Convolution,Channel,Foundation Model"
+ SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Foundation Model"
+ SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)","Convolution,Channel,Foundation Model"
+ SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,Interpretability"
+ ShallowFBCSPNet,General,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution"
+ SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
+ SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)","Convolution"
+ AttnSleep,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"AttnSleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution,Attention/Transformer"
+ SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)","Convolution"
+ SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Interpretability"
+ TSception,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSception(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)","Convolution"
+ TIDNet,General,Classification,250,"n_chans, n_outputs, n_times",240404,"TIDNet(n_chans=22, n_outputs=4, n_times=1000)","Convolution"
+ USleep,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",2482011,"USleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)","Convolution"
+ FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",11812,"FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
+ FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
+ FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
+ IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Convolution,FilterBank"
+ BrainModule,Speech Decoding,Classification,250,"n_chans, n_outputs, n_times, sfreq",6186909,"BrainModule(n_chans=64, n_outputs=29, n_times=160, sfreq=1000)","Convolution"
+ PBT,General,Classification,250,"n_chans, n_outputs, n_times",818948,"PBT(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)","Foundation Model"
+ SSTDPN,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",19502,"SSTDPN(n_chans=22, n_outputs=4, n_times=1000)","Convolution,Attention/Transformer"
+ BENDR,General,"Classification, Embedding",250,"n_chans, n_times, n_outputs",157141049,"BENDR(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
+ LUNA,General,"Classification, Embedding",128,"n_chans, n_times, sfreq, chs_info",7100731,"LUNA(n_chans=22, n_times=512, sfreq=128)","Convolution,Channel,Foundation Model"
+ MEDFormer,General,Classification,250,"n_chans, n_outputs, n_times",5313924,"MEDFormer(n_chans=22, n_outputs=4, n_times=1000)","Foundation Model,Convolution"
+ REVE,General,Classification,200,"n_outputs, n_times, n_chans",69481476,"REVE(n_times=1000, n_outputs=4, n_chans=19)","Foundation Model,Attention/Transformer"
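
The get_#Parameters column records the exact constructor call used to produce the #Parameters count. As a quick sanity check, here is a minimal sketch (assuming braindecode and torch are installed) that reproduces the count for the ShallowFBCSPNet row:

    # Sketch: reproduce the "#Parameters" column for one row of summary.csv.
    # The expected count (46084) comes from the ShallowFBCSPNet row above.
    from braindecode.models import ShallowFBCSPNet

    model = ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(n_params)  # expected: 46084 per the table

The same pattern works for any row: paste the call from get_#Parameters and sum the element counts of the trainable parameters.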
braindecode/models/syncnet.py
@@ -0,0 +1,234 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+ from numpy import arange, ceil
+ 
+ from braindecode.models.base import EEGModuleMixin
+ 
+ 
+ class SyncNet(EEGModuleMixin, nn.Module):
+     r"""Synchronization Network (SyncNet) from Li, Y. et al. (2017) [Li2017]_.
+ 
+     :bdg-warning:`Interpretability`
+ 
+     .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
+         :align: center
+         :alt: SyncNet Architecture
+ 
+     SyncNet uses parameterized 1-dimensional convolutional filters inspired by
+     the Morlet wavelet to extract features from EEG signals. The filters are
+     generated dynamically from learnable parameters that control their
+     oscillation and decay characteristics.
+ 
+     The filter for channel ``c`` and filter ``k`` is defined as:
+ 
+     .. math::
+ 
+         f_c^{(k)}(\tau) = amplitude_c^{(k)} \cos(\omega^{(k)} \tau + \phi_c^{(k)}) \exp(-\beta^{(k)} \tau^2)
+ 
+     where:
+ 
+     - :math:`amplitude_c^{(k)}` is the amplitude parameter (channel-specific).
+     - :math:`\omega^{(k)}` is the frequency parameter (shared across channels).
+     - :math:`\phi_c^{(k)}` is the phase shift (channel-specific).
+     - :math:`\beta^{(k)}` is the decay parameter (shared across channels).
+     - :math:`\tau` is the time index.
+ 
+     Parameters
+     ----------
+     num_filters : int, optional
+         Number of filters in the convolutional layer. Default is 1.
+     filter_width : int, optional
+         Width of the convolutional filters. Default is 40.
+     pool_size : int, optional
+         Size of the pooling window. Default is 40.
+     activation : nn.Module, optional
+         Activation function to apply after pooling. Default is ``nn.ReLU``.
+     ampli_init_values : tuple of float, optional
+         The initialization range for the amplitude parameters, drawn from a
+         uniform distribution. Default is (-0.05, 0.05).
+     omega_init_values : tuple of float, optional
+         The initialization range for the omega parameters, drawn from a
+         uniform distribution. Default is (0, 1).
+     beta_init_values : tuple of float, optional
+         The initialization range for the beta parameters, drawn from a
+         uniform distribution. Default is (0, 0.05).
+     phase_init_values : tuple of float, optional
+         The mean and standard deviation for the phase parameters, drawn
+         from a normal distribution. Default is (0, 0.05).
+ 
+ 
+     Notes
+     -----
+     This implementation is not guaranteed to be correct! It has not been
+     checked by the original authors. The modifications are based on code
+     derived from [CodeICASSP2025]_.
+ 
+ 
+     References
+     ----------
+     .. [Li2017] Li, Y., Dzirasa, K., Carin, L., & Carlson, D. E. (2017).
+         Targeting EEG/LFP synchrony with neural nets. Advances in Neural
+         Information Processing Systems, 30.
+     .. [CodeICASSP2025] Code from Baselines for EEG-Music Emotion Recognition
+         Grand Challenge at ICASSP 2025.
+         https://github.com/SalvoCalcagno/eeg-music-challenge-icassp-2025-baselines
+ 
+     """
+ 
+     def __init__(
+         self,
+         # braindecode convention
+         n_chans=None,
+         n_times=None,
+         n_outputs=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         num_filters=1,
+         filter_width=40,
+         pool_size=40,
+         activation: type[nn.Module] = nn.ReLU,
+         ampli_init_values: tuple[float, float] = (-0.05, 0.05),
+         omega_init_values: tuple[float, float] = (0.0, 1.0),
+         beta_init_values: tuple[float, float] = (0.0, 0.05),
+         phase_init_values: tuple[float, float] = (0.0, 0.05),
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_times=n_times,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+ 
+         self.num_filters = num_filters
+         self.filter_width = filter_width
+         self.pool_size = pool_size
+         self.activation = activation()
+         self.ampli_init_values = ampli_init_values
+         self.omega_init_values = omega_init_values
+         self.beta_init_values = beta_init_values
+         self.phase_init_values = phase_init_values
+ 
+         # Initialize parameters
+         self.amplitude = nn.Parameter(
+             torch.FloatTensor(1, 1, self.n_chans, self.num_filters).uniform_(
+                 self.ampli_init_values[0], self.ampli_init_values[1]
+             )
+         )
+         self.omega = nn.Parameter(
+             torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
+                 self.omega_init_values[0], self.omega_init_values[1]
+             )
+         )
+ 
+         self.bias = nn.Parameter(torch.zeros(self.num_filters))
+ 
+         # Calculate the output size after pooling
+         self.classifier_input_size = int(
+             ceil(float(self.n_times) / float(self.pool_size)) * self.num_filters
+         )
+ 
+         # Create time vector t
+         if self.filter_width % 2 == 0:
+             t_range = arange(-int(self.filter_width / 2), int(self.filter_width / 2))
+         else:
+             t_range = arange(
+                 -int((self.filter_width - 1) / 2), int((self.filter_width - 1) / 2) + 1
+             )
+ 
+         t_np = t_range.reshape(1, self.filter_width, 1, 1)
+         self.t = nn.Parameter(torch.FloatTensor(t_np))
+         # Phase shift (channel-specific), drawn from a normal distribution
+         self.phi_ini = nn.Parameter(
+             torch.FloatTensor(1, 1, self.n_chans, self.num_filters).normal_(
+                 self.phase_init_values[0], self.phase_init_values[1]
+             )
+         )
+         # Decay (shared across channels), drawn from a uniform distribution
+         self.beta = nn.Parameter(
+             torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
+                 self.beta_init_values[0], self.beta_init_values[1]
+             )
+         )
+ 
+         self.padding = self._compute_padding(filter_width=self.filter_width)
+         self.pad_input = nn.ConstantPad1d(self.padding, 0.0)
+         self.pad_res = nn.ConstantPad1d(self.padding, 0.0)
+ 
+         # Define pooling and classifier layers
+         self.pool = nn.MaxPool2d((1, self.pool_size), stride=(1, self.pool_size))
+ 
+         self.ensuredim = Rearrange("batch ch time -> batch ch 1 time")
+ 
+         self.final_layer = nn.Linear(self.classifier_input_size, self.n_outputs)
+ 
+     def forward(self, x):
+         """Forward pass of the SyncNet model.
+ 
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+ 
+         Returns
+         -------
+         out : torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+ 
+         """
+         # Ensure input tensor has shape (batch_size, n_chans, 1, n_times)
+         x = self.ensuredim(x)
+ 
+         # Compute the oscillatory component
+         W_osc = self.amplitude * torch.cos(self.t * self.omega + self.phi_ini)
+         # W_osc is (1, filter_width, n_chans, num_filters)
+ 
+         # Compute the decay component
+         t_squared = torch.pow(self.t, 2)  # Shape: (1, filter_width, 1, 1)
+         t_squared_beta = t_squared * self.beta  # (1, filter_width, 1, num_filters)
+         W_decay = torch.exp(-t_squared_beta)
+         # W_decay is (1, filter_width, 1, num_filters)
+ 
+         # Combine oscillatory and decay components
+         W = W_osc * W_decay
+         # W is (1, filter_width, n_chans, num_filters)
+ 
+         # Rearrange into the conv2d weight layout
+         # (num_filters, n_chans, 1, filter_width); permute rather than
+         # view, so each filter keeps its intended Morlet shape
+         W = W.permute(3, 2, 0, 1).contiguous()
+ 
+         # Apply convolution
+         x_padded = self.pad_input(x.float())
+ 
+         res = F.conv2d(x_padded, W.float(), bias=self.bias, stride=1)
+ 
+         # Apply padding to the convolution result
+         res_padded = self.pad_res(res)
+         res_pooled = self.pool(res_padded)
+ 
+         # Flatten the result
+         res_flat = res_pooled.view(-1, self.classifier_input_size)
+ 
+         # Ensure beta remains non-negative
+         self.beta.data.clamp_(min=0)
+ 
+         # Apply activation
+         out = self.activation(res_flat)
+         # Apply classifier
+         out = self.final_layer(out)
+ 
+         return out
+ 
+     @staticmethod
+     def _compute_padding(filter_width):
+         # Compute padding
+         P = filter_width - 2
+         if P % 2 == 0:
+             padding = (P // 2, P // 2 + 1)
+         else:
+             padding = (P // 2, P // 2)
+         return padding
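
A minimal forward-pass sketch for the SyncNet module above, using the configuration from its summary.csv row (this assumes SyncNet is re-exported from braindecode.models, as the file list suggests):

    # Sketch: shape check for SyncNet with the summary.csv configuration.
    import torch
    from braindecode.models import SyncNet

    model = SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)
    x = torch.randn(8, 62, 5120)  # (batch_size, n_chans, n_times)
    out = model(x)
    print(out.shape)  # torch.Size([8, 3])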
braindecode/models/tcn.py
@@ -0,0 +1,275 @@
+ # Authors: Patryk Chrabaszcz
+ #          Lukas Gemein <l.gemein@gmail.com>
+ #
+ # License: BSD-3
+ import torch
+ from torch import nn
+ from torch.nn import init
+ from torch.nn.utils.parametrizations import weight_norm
+ 
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Chomp1d, Ensure4d, SqueezeFinalOutput
+ 
+ 
+ class BDTCN(EEGModuleMixin, nn.Module):
+     r"""Braindecode TCN from Gemein, L. et al. (2020) [gemein2020]_.
+ 
+     :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+ 
+     .. figure:: https://ars.els-cdn.com/content/image/1-s2.0-S1053811920305073-gr3_lrg.jpg
+         :align: center
+         :alt: Braindecode TCN Architecture
+ 
+     See [gemein2020]_ for details.
+ 
+     Parameters
+     ----------
+     n_filters: int
+         number of output filters of each convolution
+     n_blocks: int
+         number of temporal blocks in the network
+     kernel_size: int
+         kernel size of the convolutions
+     drop_prob: float
+         dropout probability
+     activation: nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+ 
+     References
+     ----------
+     .. [gemein2020] Gemein, L. A., Schirrmeister, R. T., Chrabąszcz, P., Wilson, D.,
+         Boedecker, J., Schulze-Bonhage, A., ... & Ball, T. (2020). Machine-learning-based
+         diagnostics of EEG pathology. NeuroImage, 220, 117021.
+     """
+ 
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         sfreq=None,
+         input_window_seconds=None,
+         # Model's parameters
+         n_blocks=3,
+         n_filters=30,
+         kernel_size=5,
+         drop_prob=0.5,
+         activation: type[nn.Module] = nn.ReLU,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+ 
+         self.base_tcn = TCN(
+             n_chans=self.n_chans,
+             n_outputs=self.n_outputs,
+             n_blocks=n_blocks,
+             n_filters=n_filters,
+             kernel_size=kernel_size,
+             drop_prob=drop_prob,
+             activation=activation,
+         )
+ 
+         self.final_layer = torch.nn.Sequential(
+             torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
+         )
+ 
+     def forward(self, x):
+         x = self.base_tcn(x)
+         x = self.final_layer(x)
+         return x
+ 
+ 
+ class TCN(nn.Module):
+     r"""Temporal Convolutional Network (TCN) from Bai et al. 2018 [Bai2018]_.
+ 
+     See [Bai2018]_ for details.
+ 
+     Code adapted from https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
+ 
+     Parameters
+     ----------
+     n_filters: int
+         number of output filters of each convolution
+     n_blocks: int
+         number of temporal blocks in the network
+     kernel_size: int
+         kernel size of the convolutions
+     drop_prob: float
+         dropout probability
+     activation: nn.Module, default=nn.ReLU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+ 
+     References
+     ----------
+     .. [Bai2018] Bai, S., Kolter, J. Z., & Koltun, V. (2018).
+         An empirical evaluation of generic convolutional and recurrent networks
+         for sequence modeling.
+         arXiv preprint arXiv:1803.01271.
+     """
+ 
+     def __init__(
+         self,
+         n_chans=None,
+         n_outputs=None,
+         n_blocks=3,
+         n_filters=30,
+         kernel_size=5,
+         drop_prob=0.5,
+         activation: type[nn.Module] = nn.ReLU,
+     ):
+         super().__init__()
+         self.mapping = {
+             "fc.weight": "final_layer.fc.weight",
+             "fc.bias": "final_layer.fc.bias",
+         }
+         self.ensuredims = Ensure4d()
+         t_blocks = nn.Sequential()
+         for i in range(n_blocks):
+             n_inputs = n_chans if i == 0 else n_filters
+             dilation_size = 2**i
+             t_blocks.add_module(
+                 "temporal_block_{:d}".format(i),
+                 _TemporalBlock(
+                     n_inputs=n_inputs,
+                     n_outputs=n_filters,
+                     kernel_size=kernel_size,
+                     stride=1,
+                     dilation=dilation_size,
+                     padding=(kernel_size - 1) * dilation_size,
+                     drop_prob=drop_prob,
+                     activation=activation,
+                 ),
+             )
+         self.temporal_blocks = t_blocks
+ 
+         self.final_layer = _FinalLayer(
+             in_features=n_filters,
+             out_features=n_outputs,
+         )
+         # Minimum input length: one sample plus the receptive-field growth
+         # of the dilated blocks (two dilated convolutions per block)
+         self.min_len = 1
+         for i in range(n_blocks):
+             dilation = 2**i
+             self.min_len += 2 * (kernel_size - 1) * dilation
+ 
+         # start in training mode (the PyTorch nn.Module default)
+         self.train()
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass.
+ 
+         Parameters
+         ----------
+         x: torch.Tensor
+             Batch of EEG windows of shape (batch_size, n_channels, n_times).
+         """
+         x = self.ensuredims(x)
+         # x is in format: B x C x T x 1
+         (batch_size, _, time_size, _) = x.size()
+         assert time_size >= self.min_len
+         # remove empty trailing dimension
+         x = x.squeeze(3)
+         x = self.temporal_blocks(x)
+         # Convert to: B x T x C
+         x = x.transpose(1, 2).contiguous()
+ 
+         out = self.final_layer(x, batch_size, time_size, self.min_len)
+ 
+         return out
+ 
+ 
+ class _FinalLayer(nn.Module):
+     def __init__(self, in_features, out_features):
+         super().__init__()
+ 
+         self.fc = nn.Linear(in_features=in_features, out_features=out_features)
+ 
+         self.out_fun = nn.Identity()
+ 
+         self.squeeze = SqueezeFinalOutput()
+ 
+     def forward(
+         self, x: torch.Tensor, batch_size: int, time_size: int, min_len: int
+     ) -> torch.Tensor:
+         # Apply the linear layer to every time step at once
+         fc_out = self.fc(x.view(batch_size * time_size, x.size(2)))
+         fc_out = self.out_fun(fc_out)
+         fc_out = fc_out.view(batch_size, time_size, fc_out.size(1))
+ 
+         # Keep only the time steps with a full receptive field
+         out_size = 1 + max(0, time_size - min_len)
+         out = fc_out[:, -out_size:, :].transpose(1, 2)
+         # re-add 4th dimension for compatibility with braindecode
+         return self.squeeze(out[:, :, :, None])
+ 
+ 
+ class _TemporalBlock(nn.Module):
+     def __init__(
+         self,
+         n_inputs,
+         n_outputs,
+         kernel_size,
+         stride,
+         dilation,
+         padding,
+         drop_prob,
+         activation: type[nn.Module] = nn.ReLU,
+     ):
+         super().__init__()
+         self.conv1 = weight_norm(
+             nn.Conv1d(
+                 n_inputs,
+                 n_outputs,
+                 kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 dilation=dilation,
+             )
+         )
+         self.chomp1 = Chomp1d(padding)
+         self.relu1 = activation()
+         self.dropout1 = nn.Dropout2d(drop_prob)
+ 
+         self.conv2 = weight_norm(
+             nn.Conv1d(
+                 n_outputs,
+                 n_outputs,
+                 kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 dilation=dilation,
+             )
+         )
+         self.chomp2 = Chomp1d(padding)
+         self.relu2 = activation()
+         self.dropout2 = nn.Dropout2d(drop_prob)
+ 
+         # 1x1 convolution matches channel counts for the residual connection
+         self.downsample = (
+             nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None
+         )
+         self.relu = activation()
+ 
+         init.normal_(self.conv1.weight, 0, 0.01)
+         init.normal_(self.conv2.weight, 0, 0.01)
+         if self.downsample is not None:
+             init.normal_(self.downsample.weight, 0, 0.01)
+ 
+     def forward(self, x):
+         out = self.conv1(x)
+         out = self.chomp1(out)
+         out = self.relu1(out)
+         out = self.dropout1(out)
+         out = self.conv2(out)
+         out = self.chomp2(out)
+         out = self.relu2(out)
+         out = self.dropout2(out)
+         # Residual connection around the two dilated convolutions
+         res = x if self.downsample is None else self.downsample(x)
+         return self.relu(out + res)
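
TCN.forward asserts that the input holds at least min_len samples, where min_len = 1 + sum over blocks of 2 * (kernel_size - 1) * 2**i. A small worked check of that formula, against both the TCN defaults and the BDTCN row in summary.csv:

    # Sketch: the minimum input length enforced by TCN.forward, computed
    # with the same formula as TCN.__init__.
    def min_input_length(n_blocks, kernel_size):
        return 1 + sum(2 * (kernel_size - 1) * 2**i for i in range(n_blocks))

    print(min_input_length(n_blocks=3, kernel_size=5))   # 57: TCN defaults
    print(min_input_length(n_blocks=5, kernel_size=16))  # 931: BDTCN row in summary.csv

So the BDTCN configuration from summary.csv (n_blocks=5, kernel_size=16) needs windows of at least 931 samples, which its 6000-sample inputs easily satisfy.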