braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/sparcnet.py
@@ -0,0 +1,378 @@
+ from __future__ import annotations
+
+ from collections import OrderedDict
+ from math import floor, log2
+ from typing import Any
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SPARCNet(EEGModuleMixin, nn.Module):
+     """Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al. (2023) [jing2023]_.
+
+     This is a temporal CNN model for biosignal classification based on the
+     DenseNet architecture.
+
+     The model is based on the unofficial implementation [Code2023]_.
+
+     .. versionadded:: 0.9
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors.
+
+     Parameters
+     ----------
+     block_layers : int, optional
+         Number of layers per dense block. Default is 4.
+     growth_rate : int, optional
+         Growth rate of the DenseNet. Default is 16.
+     bottleneck_size : int, optional
+         Bottleneck size. Default is 16.
+     drop_prob : float, optional
+         Dropout rate. Default is 0.5.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     References
+     ----------
+     .. [jing2023] Jing, J., Ge, W., Hong, S., Fernandes, M. B., Lin, Z.,
+         Yang, C., ... & Westover, M. B. (2023). Development of expert-level
+         classification of seizures and rhythmic and periodic
+         patterns during EEG interpretation. Neurology, 100(17), e1750-e1762.
+     .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
+         Biosignal Transformer for Cross-data Learning in the Wild.
+         GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)
+
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_times=None,
+         n_outputs=None,
+         # Neural network parameters
+         block_layers: int = 4,
+         growth_rate: int = 16,
+         bottleneck_size: int = 16,
+         drop_prob: float = 0.5,
+         conv_bias: bool = True,
+         batch_norm: bool = True,
+         activation: nn.Module = nn.ELU,
+         # EEGModuleMixin parameters
+         # (another way to present the same parameters)
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+
+         # Add the initial convolutional layer. The number of output channels
+         # is the smallest power of 2 that is greater than the number of
+         # input channels, e.g. 16 input channels -> 32 output channels.
+         out_channels = 2 ** (floor(log2(self.n_chans)) + 1)
+         first_conv = OrderedDict(
+             [
+                 (
+                     "conv0",
+                     nn.Conv1d(
+                         in_channels=self.n_chans,
+                         out_channels=out_channels,
+                         kernel_size=7,
+                         stride=2,
+                         padding=3,
+                         bias=conv_bias,
+                     ),
+                 )
+             ]
+         )
+         first_conv["norm0"] = nn.BatchNorm1d(out_channels)
+         first_conv["act_layer"] = activation()
+         first_conv["pool0"] = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+         self.encoder = nn.Sequential(first_conv)
+
+         n_channels = out_channels
+
+         # Add dense blocks, each followed by a transition layer that
+         # halves the channel count and the temporal length
+         for n_layer in range(floor(log2(self.n_times // 4))):
+             block = _DenseBlock(
+                 num_layers=block_layers,
+                 in_channels=n_channels,
+                 growth_rate=growth_rate,
+                 bottleneck_size=bottleneck_size,
+                 drop_prob=drop_prob,
+                 conv_bias=conv_bias,
+                 batch_norm=batch_norm,
+                 activation=activation,
+             )
+             self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
+             # update the number of channels after each dense block
+             n_channels = n_channels + block_layers * growth_rate
+
+             trans = _TransitionLayer(
+                 in_channels=n_channels,
+                 out_channels=n_channels // 2,
+                 conv_bias=conv_bias,
+                 batch_norm=batch_norm,
+                 activation=activation,
+             )
+             self.encoder.add_module("transition%d" % (n_layer + 1), trans)
+             # update the number of channels after each transition layer
+             n_channels = n_channels // 2
+
+         # add the final classification layer
+         self.final_layer = nn.Sequential(
+             activation(),
+             nn.Linear(n_channels, self.n_outputs),
+         )
+
+         self._init_weights()
+
+     def _init_weights(self):
+         """
+         Initialize the weights of the model.
+
+         Official init from the torch repo: kaiming_normal for conv layers,
+         unit weight and zero bias for batch norm layers, and zero bias for
+         linear layers.
+         """
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(m.weight.data)
+             elif isinstance(m, nn.BatchNorm1d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.Linear):
+                 m.bias.data.zero_()
+
+     def forward(self, X: torch.Tensor):
+         """
+         Forward pass of the model.
+
+         Parameters
+         ----------
+         X : torch.Tensor
+             The input tensor of the model with shape (batch_size, n_channels, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             The output tensor of the model with shape (batch_size, n_outputs).
+         """
+         emb = self.encoder(X).squeeze(-1)
+         out = self.final_layer(emb)
+         return out
+
+
+ class _DenseLayer(nn.Sequential):
+     """
+     A densely connected layer with batch normalization and dropout.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     growth_rate : int
+         Rate of growth of channels in this layer.
+     bottleneck_size : int
+         Multiplicative factor for the bottleneck layer (does not affect the output size).
+     drop_prob : float, optional
+         Dropout rate. Default is 0.5.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Examples
+     --------
+     >>> x = torch.randn(128, 5, 1000)
+     >>> batch, channels, length = x.shape
+     >>> model = _DenseLayer(channels, 5, 2)
+     >>> y = model(x)
+     >>> y.shape
+     torch.Size([128, 10, 1000])
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         growth_rate: int,
+         bottleneck_size: int,
+         drop_prob: float = 0.5,
+         conv_bias: bool = True,
+         batch_norm: bool = True,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         if batch_norm:
+             self.add_module("norm1", nn.BatchNorm1d(in_channels))
+
+         self.add_module("elu1", activation())
+         self.add_module(
+             "conv1",
+             nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=bottleneck_size * growth_rate,
+                 kernel_size=1,
+                 stride=1,
+                 bias=conv_bias,
+             ),
+         )
+         if batch_norm:
+             self.add_module("norm2", nn.BatchNorm1d(bottleneck_size * growth_rate))
+         self.add_module("elu2", activation())
+         self.add_module(
+             "conv2",
+             nn.Conv1d(
+                 in_channels=bottleneck_size * growth_rate,
+                 out_channels=growth_rate,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=conv_bias,
+             ),
+         )
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Manually pass through each submodule
+         out = x
+         for layer in self:
+             out = layer(out)
+         # apply dropout using the functional API
+         out = F.dropout(out, p=self.drop_prob, training=self.training)
+         # concatenate input and new features
+         return torch.cat([x, out], dim=1)
+
+
+ class _DenseBlock(nn.Sequential):
+     """
+     A densely connected block that uses DenseLayers.
+
+     Parameters
+     ----------
+     num_layers : int
+         Number of layers in this block.
+     in_channels : int
+         Number of input channels.
+     growth_rate : int
+         Rate of growth of channels in this layer.
+     bottleneck_size : int
+         Multiplicative factor for the bottleneck layer (does not affect the output size).
+     drop_prob : float, optional
+         Dropout rate. Default is 0.5.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Examples
+     --------
+     >>> x = torch.randn(128, 5, 1000)
+     >>> batch, channels, length = x.shape
+     >>> model = _DenseBlock(3, channels, 5, 2)
+     >>> y = model(x)
+     >>> y.shape
+     torch.Size([128, 20, 1000])
+     """
+
+     def __init__(
+         self,
+         num_layers,
+         in_channels,
+         growth_rate,
+         bottleneck_size,
+         drop_prob=0.5,
+         conv_bias=True,
+         batch_norm=True,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         for idx_layer in range(num_layers):
+             layer = _DenseLayer(
+                 in_channels=in_channels + idx_layer * growth_rate,
+                 growth_rate=growth_rate,
+                 bottleneck_size=bottleneck_size,
+                 drop_prob=drop_prob,
+                 conv_bias=conv_bias,
+                 batch_norm=batch_norm,
+                 activation=activation,
+             )
+             self.add_module(f"denselayer{idx_layer + 1}", layer)
+
+
+ class _TransitionLayer(nn.Sequential):
+     """
+     A pooling transition layer.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Examples
+     --------
+     >>> x = torch.randn(128, 5, 1000)
+     >>> model = _TransitionLayer(5, 18)
+     >>> y = model(x)
+     >>> y.shape
+     torch.Size([128, 18, 500])
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         conv_bias=True,
+         batch_norm=True,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         if batch_norm:
+             self.add_module("norm", nn.BatchNorm1d(in_channels))
+         self.add_module("elu", activation())
+         self.add_module(
+             "conv",
+             nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 bias=conv_bias,
+             ),
+         )
+         self.add_module("pool", nn.AvgPool1d(kernel_size=2, stride=2))
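
As a quick sanity check of the new model file, here is a minimal smoke test. This is an editorial sketch, not part of the diff; it assumes SPARCNet is re-exported from braindecode.models, as the summary table in the next hunk suggests, and reuses that table's reference configuration.

    import torch
    from braindecode.models import SPARCNet

    # Reference configuration from summary.csv: 16 channels, 2000 samples, 6 classes.
    model = SPARCNet(n_chans=16, n_times=2000, n_outputs=6)
    x = torch.randn(8, 16, 2000)  # (batch_size, n_chans, n_times)
    y = model(x)
    print(y.shape)  # expected: torch.Size([8, 6])

With n_times=2000, the encoder stacks floor(log2(2000 // 4)) = 8 dense blocks, and each transition layer halves the temporal length, so the feature map collapses to length 1 before the linear head.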
braindecode/models/summary.csv
@@ -0,0 +1,41 @@
+ Model,Paradigm,Type,Freq(Hz),Hyperparameters,#Parameters,get_#Parameters
+ ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)"
+ AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)"
+ BDTCN,Normal/Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)"
+ BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)"
+ ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)"
+ CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, heads=2, emb_size=16)"
+ Deep4Net,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)"
+ DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)"
+ EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)"
+ EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)"
+ EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)"
+ EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)"
+ EEGNetv1,General,Classification,128,"n_chans, n_outputs, n_times",3052,"EEGNetv1(n_chans=22, n_outputs=4, n_times=512)"
+ EEGNetv4,General,Classification,128,"n_chans, n_outputs, n_times",2484,"EEGNetv4(n_chans=22, n_outputs=4, n_times=512)"
+ EEGNeX,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",55940,"EEGNeX(n_chans=22, n_outputs=4, n_times=500)"
+ EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)"
+ EEGResNet,General,Classification,250,"n_chans, n_outputs, n_times",247484,"EEGResNet(n_chans=22, n_outputs=4, n_times=1000)"
+ EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)"
+ EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)"
+ Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494,"MSVTNet(n_chans=22, n_outputs=4, n_times=1000)"
+ SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)"
+ SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])"
+ SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])"
+ SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)"
+ SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)"
+ SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ ShallowFBCSPNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
+ SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)"
+ SleepStagerEldele2021,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"SleepStagerEldele2021(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
+ SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)"
+ SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)"
+ TSceptionV1,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSceptionV1(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)"
+ TIDNet,General,Classification,250,"n_chans, n_outputs, n_times",240404,"TIDNet(n_chans=22, n_outputs=4, n_times=1000)"
+ USleep,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",2482011,"USleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
+ FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",11812,"FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
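
The table above ships inside the wheel as braindecode/models/summary.csv (item 60 in the file list), so it can be inspected at runtime. A minimal sketch, assuming the file is installed as package data alongside braindecode.models:

    import pandas as pd
    from importlib import resources

    # Read the model summary bundled with the installed package.
    path = resources.files("braindecode.models").joinpath("summary.csv")
    with path.open() as f:
        summary = pd.read_csv(f)
    print(summary[["Model", "Paradigm", "#Parameters"]].head())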
braindecode/models/syncnet.py
@@ -0,0 +1,232 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+ from numpy import arange, ceil
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SyncNet(EEGModuleMixin, nn.Module):
+     """Synchronization Network (SyncNet) from Li et al. (2017) [Li2017]_.
+
+     .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
+        :align: center
+        :alt: SyncNet Architecture
+
+     SyncNet uses parameterized 1-dimensional convolutional filters inspired by
+     the Morlet wavelet to extract features from EEG signals. The filters are
+     dynamically generated based on learnable parameters that control the
+     oscillation and decay characteristics.
+
+     The filter for channel ``c`` and filter ``k`` is defined as:
+
+     .. math::
+
+         f_c^{(k)}(\\tau) = amplitude_c^{(k)} \\cos(\\omega^{(k)} \\tau + \\phi_c^{(k)}) \\exp(-\\beta^{(k)} \\tau^2)
+
+     where:
+
+     - :math:`amplitude_c^{(k)}` is the amplitude parameter (channel-specific).
+     - :math:`\\omega^{(k)}` is the frequency parameter (shared across channels).
+     - :math:`\\phi_c^{(k)}` is the phase shift (channel-specific).
+     - :math:`\\beta^{(k)}` is the decay parameter (shared across channels).
+     - :math:`\\tau` is the time index.
+
+     Parameters
+     ----------
+     num_filters : int, optional
+         Number of filters in the convolutional layer. Default is 1.
+     filter_width : int, optional
+         Width of the convolutional filters. Default is 40.
+     pool_size : int, optional
+         Size of the pooling window. Default is 40.
+     activation : nn.Module, optional
+         Activation function to apply after pooling. Default is ``nn.ReLU``.
+     ampli_init_values : tuple of float, optional
+         The initialization range for the amplitude parameters, using a
+         uniform distribution. Default is (-0.05, 0.05).
+     omega_init_values : tuple of float, optional
+         The initialization range for the omega parameters, using a uniform
+         distribution. Default is (0, 1).
+     beta_init_values : tuple of float, optional
+         The initialization range for the beta parameters, using a uniform
+         distribution. Default is (0, 0.05).
+     phase_init_values : tuple of float, optional
+         The initialization values (mean, std) for the phase parameters,
+         using a normal distribution. Default is (0, 0.05).
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct; it has not been
+     checked by the original authors. It is based on code derived from
+     [CodeICASSP2025]_.
+
+     References
+     ----------
+     .. [Li2017] Li, Y., Dzirasa, K., Carin, L., & Carlson, D. E. (2017).
+        Targeting EEG/LFP synchrony with neural nets. Advances in Neural
+        Information Processing Systems, 30.
+     .. [CodeICASSP2025] Code from the baselines for the EEG-Music Emotion
+        Recognition Grand Challenge at ICASSP 2025.
+        https://github.com/SalvoCalcagno/eeg-music-challenge-icassp-2025-baselines
+
+     """
+
+     def __init__(
+         self,
+         # braindecode convention
+         n_chans=None,
+         n_times=None,
+         n_outputs=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         num_filters=1,
+         filter_width=40,
+         pool_size=40,
+         activation: nn.Module = nn.ReLU,
+         ampli_init_values: tuple[float, float] = (-0.05, 0.05),
+         omega_init_values: tuple[float, float] = (0.0, 1.0),
+         beta_init_values: tuple[float, float] = (0.0, 0.05),
+         phase_init_values: tuple[float, float] = (0.0, 0.05),
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_times=n_times,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.num_filters = num_filters
+         self.filter_width = filter_width
+         self.pool_size = pool_size
+         self.activation = activation()
+         self.ampli_init_values = ampli_init_values
+         self.omega_init_values = omega_init_values
+         self.beta_init_values = beta_init_values
+         self.phase_init_values = phase_init_values
+
+         # Initialize parameters
+         self.amplitude = nn.Parameter(
+             torch.FloatTensor(1, 1, self.n_chans, self.num_filters).uniform_(
+                 self.ampli_init_values[0], self.ampli_init_values[1]
+             )
+         )
+         self.omega = nn.Parameter(
+             torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
+                 self.omega_init_values[0], self.omega_init_values[1]
+             )
+         )
+
+         self.bias = nn.Parameter(torch.zeros(self.num_filters))
+
+         # Calculate the output size after pooling
+         self.classifier_input_size = int(
+             ceil(float(self.n_times) / float(self.pool_size)) * self.num_filters
+         )
+
+         # Create time vector t
+         if self.filter_width % 2 == 0:
+             t_range = arange(-int(self.filter_width / 2), int(self.filter_width / 2))
+         else:
+             t_range = arange(
+                 -int((self.filter_width - 1) / 2), int((self.filter_width - 1) / 2) + 1
+             )
+
+         t_np = t_range.reshape(1, self.filter_width, 1, 1)
+         self.t = nn.Parameter(torch.FloatTensor(t_np))
+         # Phase shift (channel-specific), initialized from a normal distribution
+         self.phi_ini = nn.Parameter(
+             torch.FloatTensor(1, 1, self.n_chans, self.num_filters).normal_(
+                 self.phase_init_values[0], self.phase_init_values[1]
+             )
+         )
+         # Decay (shared across channels), initialized from a uniform distribution
+         self.beta = nn.Parameter(
+             torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
+                 self.beta_init_values[0], self.beta_init_values[1]
+             )
+         )
+
+         self.padding = self._compute_padding(filter_width=self.filter_width)
+         self.pad_input = nn.ConstantPad1d(self.padding, 0.0)
+         self.pad_res = nn.ConstantPad1d(self.padding, 0.0)
+
+         # Define pooling and classifier layers
+         self.pool = nn.MaxPool2d((1, self.pool_size), stride=(1, self.pool_size))
+
+         self.ensuredim = Rearrange("batch ch time -> batch ch 1 time")
+
+         self.final_layer = nn.Linear(self.classifier_input_size, self.n_outputs)
+
+     def forward(self, x):
+         """Forward pass of the SyncNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         out : torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+
+         """
+         # Ensure the input tensor has shape (batch_size, n_chans, 1, n_times)
+         x = self.ensuredim(x)
+
+         # Compute the oscillatory component
+         W_osc = self.amplitude * torch.cos(self.t * self.omega + self.phi_ini)
+         # W_osc has shape (1, filter_width, n_chans, num_filters)
+
+         # Compute the decay component
+         t_squared = torch.pow(self.t, 2)  # Shape: (1, filter_width, 1, 1)
+         t_squared_beta = t_squared * self.beta  # Shape: (1, filter_width, 1, num_filters)
+         W_decay = torch.exp(-t_squared_beta)
+         # W_decay has shape (1, filter_width, 1, num_filters)
+
+         # Combine oscillatory and decay components
+         W = W_osc * W_decay
+         # W has shape (1, filter_width, n_chans, num_filters)
+
+         # Reshape into conv2d weights of shape
+         # (num_filters, n_chans, 1, filter_width)
+         W = W.view(self.num_filters, self.n_chans, 1, self.filter_width)
+
+         # Apply convolution
+         x_padded = self.pad_input(x.float())
+
+         res = F.conv2d(x_padded, W.float(), bias=self.bias, stride=1)
+
+         # Apply padding to the convolution result
+         res_padded = self.pad_res(res)
+         res_pooled = self.pool(res_padded)
+
+         # Flatten the result
+         res_flat = res_pooled.view(-1, self.classifier_input_size)
+
+         # Ensure beta remains non-negative
+         self.beta.data.clamp_(min=0)
+
+         # Apply activation
+         out = self.activation(res_flat)
+         # Apply classifier
+         out = self.final_layer(out)
+
+         return out
+
+     @staticmethod
+     def _compute_padding(filter_width):
+         # Compute (left, right) zero-padding for the convolution
+         P = filter_width - 2
+         if P % 2 == 0:
+             padding = (P // 2, P // 2 + 1)
+         else:
+             padding = (P // 2, P // 2)
+         return padding
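
As with SPaRCNet, a minimal smoke test for the new model (an editorial sketch; it assumes SyncNet is exported from braindecode.models and reuses the configuration from its summary.csv row):

    import torch
    from braindecode.models import SyncNet

    # Reference configuration from summary.csv: 62 channels, 5120 samples, 3 classes.
    model = SyncNet(n_chans=62, n_times=5120, n_outputs=3)
    x = torch.randn(4, 62, 5120)  # (batch_size, n_chans, n_times)
    print(model(x).shape)  # expected: torch.Size([4, 3])

With the defaults (num_filters=1, pool_size=40), the classifier input size is ceil(5120 / 40) * 1 = 128 features per window.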