braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/models/sparcnet.py
@@ -0,0 +1,424 @@
+ from __future__ import annotations
+
+ from collections import OrderedDict
+ from math import floor, log2
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SPARCNet(EEGModuleMixin, nn.Module):
+     """Seizures, Periodic and Rhythmic pattern Continuum Neural Network (SPaRCNet) from Jing et al. (2023) [jing2023]_.
+
+     This is a temporal CNN model for biosignal classification based on the
+     DenseNet architecture.
+
+     The model is based on the unofficial implementation [Code2023]_.
+
+     .. versionadded:: 0.9
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct, as it has not been
+     checked by the original authors.
+
+     Parameters
+     ----------
+     block_layers : int, optional
+         Number of layers per dense block. Default is 4.
+     growth_rate : int, optional
+         Growth rate of the DenseNet. Default is 16.
+     bottleneck_size : int, optional
+         Bottleneck size. Default is 16.
+     drop_prob : float, optional
+         Dropout rate. Default is 0.5.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     References
+     ----------
+     .. [jing2023] Jing, J., Ge, W., Hong, S., Fernandes, M. B., Lin, Z.,
+        Yang, C., ... & Westover, M. B. (2023). Development of expert-level
+        classification of seizures and rhythmic and periodic
+        patterns during EEG interpretation. Neurology, 100(17), e1750-e1762.
+     .. [Code2023] Yang, C., Westover, M.B. and Sun, J., 2023. BIOT:
+        Biosignal Transformer for Cross-data Learning in the Wild.
+        GitHub https://github.com/ycq091044/BIOT (accessed 2024-02-13)
+
+     """
+
+     def __init__(
+         self,
+         n_chans=None,
+         n_times=None,
+         n_outputs=None,
+         # Neural network parameters
+         block_layers: int = 4,
+         growth_rate: int = 16,
+         bottleneck_size: int = 16,
+         drop_prob: float = 0.5,
+         conv_bias: bool = True,
+         batch_norm: bool = True,
+         activation: nn.Module = nn.ELU,
+         kernel_size_conv0: int = 7,
+         kernel_size_conv1: int = 1,
+         kernel_size_conv2: int = 3,
+         kernel_size_pool: int = 3,
+         stride_pool: int = 2,
+         stride_conv0: int = 2,
+         stride_conv1: int = 1,
+         stride_conv2: int = 1,
+         padding_pool: int = 1,
+         padding_conv0: int = 3,
+         padding_conv2: int = 1,
+         kernel_size_trans: int = 2,
+         stride_trans: int = 2,
+         # EEGModuleMixin parameters
+         # (another way to present the same parameters)
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+
+         # add initial convolutional layer
+         # the number of output channels is the smallest power of 2
+         # that is greater than the number of input channels
+         out_channels = 2 ** (floor(log2(self.n_chans)) + 1)
+         first_conv = OrderedDict(
+             [
+                 (
+                     "conv0",
+                     nn.Conv1d(
+                         in_channels=self.n_chans,
+                         out_channels=out_channels,
+                         kernel_size=kernel_size_conv0,
+                         stride=stride_conv0,
+                         padding=padding_conv0,
+                         bias=conv_bias,
+                     ),
+                 )
+             ]
+         )
+         first_conv["norm0"] = nn.BatchNorm1d(out_channels)
+         first_conv["act_layer"] = activation()
+         first_conv["pool0"] = nn.MaxPool1d(
+             kernel_size=kernel_size_pool,
+             stride=stride_pool,
+             padding=padding_pool,
+         )
+
+         self.encoder = nn.Sequential(first_conv)
+
+         n_channels = out_channels
+
+         # Adding dense blocks
+         for n_layer in range(floor(log2(self.n_times // 4))):
+             block = _DenseBlock(
+                 num_layers=block_layers,
+                 in_channels=n_channels,
+                 growth_rate=growth_rate,
+                 bottleneck_size=bottleneck_size,
+                 drop_prob=drop_prob,
+                 conv_bias=conv_bias,
+                 batch_norm=batch_norm,
+                 activation=activation,
+                 kernel_size_conv1=kernel_size_conv1,
+                 kernel_size_conv2=kernel_size_conv2,
+                 stride_conv1=stride_conv1,
+                 stride_conv2=stride_conv2,
+                 padding_conv2=padding_conv2,
+             )
+             self.encoder.add_module("denseblock%d" % (n_layer + 1), block)
+             # update the number of channels after each dense block
+             n_channels = n_channels + block_layers * growth_rate
+
+             trans = _TransitionLayer(
+                 in_channels=n_channels,
+                 out_channels=n_channels // 2,
+                 conv_bias=conv_bias,
+                 batch_norm=batch_norm,
+                 activation=activation,
+                 kernel_size_trans=kernel_size_trans,
+                 stride_trans=stride_trans,
+             )
+             self.encoder.add_module("transition%d" % (n_layer + 1), trans)
+             # update the number of channels after each transition layer
+             n_channels = n_channels // 2
+
+         self.adaptative_pool = nn.AdaptiveAvgPool1d(1)
+         self.activation_layer = activation()
+         self.flatten_layer = nn.Flatten()
+
+         # add final linear classification layer
+         self.final_layer = nn.Linear(n_channels, self.n_outputs)
+
+         self._init_weights()
+
+     def _init_weights(self):
+         """
+         Initialize the weights of the model.
+
+         Official init from the torch repo: kaiming_normal for conv layers,
+         constant 1/0 for batch-norm weights/biases, and zero for linear
+         biases.
+
+         """
+         for m in self.modules():
+             if isinstance(m, nn.Conv1d):
+                 nn.init.kaiming_normal_(m.weight.data)
+             elif isinstance(m, nn.BatchNorm1d):
+                 m.weight.data.fill_(1)
+                 m.bias.data.zero_()
+             elif isinstance(m, nn.Linear):
+                 m.bias.data.zero_()
+
+     def forward(self, X: torch.Tensor):
+         """
+         Forward pass of the model.
+
+         Parameters
+         ----------
+         X : torch.Tensor
+             The input tensor of the model with shape (batch_size, n_channels, n_times)
+
+         Returns
+         -------
+         torch.Tensor
+             The output tensor of the model with shape (batch_size, n_outputs)
+         """
+         emb = self.encoder(X)
+         emb = self.adaptative_pool(emb)
+         emb = self.activation_layer(emb)
+         emb = self.flatten_layer(emb)
+         out = self.final_layer(emb)
+         return out
+
+
+ class _DenseLayer(nn.Sequential):
+     """
+     A densely connected layer with batch normalization and dropout.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     growth_rate : int
+         Rate of growth of channels in this layer.
+     bottleneck_size : int
+         Multiplicative factor for the bottleneck layer (does not affect the output size).
+     drop_prob : float, optional
+         Dropout rate. Default is 0.5.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Examples
+     --------
+     >>> x = torch.randn(128, 5, 1000)
+     >>> batch, channels, length = x.shape
+     >>> model = _DenseLayer(channels, 5, 2)
+     >>> y = model(x)
+     >>> y.shape
+     torch.Size([128, 10, 1000])
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         growth_rate: int,
+         bottleneck_size: int,
+         drop_prob: float = 0.5,
+         conv_bias: bool = True,
+         batch_norm: bool = True,
+         activation: nn.Module = nn.ELU,
+         kernel_size_conv1: int = 1,
+         kernel_size_conv2: int = 3,
+         stride_conv1: int = 1,
+         stride_conv2: int = 1,
+         padding_conv2: int = 1,
+     ):
+         super().__init__()
+         if batch_norm:
+             self.add_module("norm1", nn.BatchNorm1d(in_channels))
+
+         self.add_module("elu1", activation())
+         self.add_module(
+             "conv1",
+             nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=bottleneck_size * growth_rate,
+                 kernel_size=kernel_size_conv1,
+                 stride=stride_conv1,
+                 bias=conv_bias,
+             ),
+         )
+         if batch_norm:
+             self.add_module("norm2", nn.BatchNorm1d(bottleneck_size * growth_rate))
+         self.add_module("elu2", activation())
+         self.add_module(
+             "conv2",
+             nn.Conv1d(
+                 in_channels=bottleneck_size * growth_rate,
+                 out_channels=growth_rate,
+                 kernel_size=kernel_size_conv2,
+                 stride=stride_conv2,
+                 padding=padding_conv2,
+                 bias=conv_bias,
+             ),
+         )
+         self.drop_prob = drop_prob
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Manually pass through each submodule
+         out = x
+         for layer in self:
+             out = layer(out)
+         # apply dropout using the functional API
+         out = F.dropout(out, p=self.drop_prob, training=self.training)
+         # concatenate input and new features
+         return torch.cat([x, out], dim=1)
+
+
+ class _DenseBlock(nn.Sequential):
+     """
+     A densely connected block that uses DenseLayers.
+
+     Parameters
+     ----------
+     num_layers : int
+         Number of layers in this block.
+     in_channels : int
+         Number of input channels.
+     growth_rate : int
+         Rate of growth of channels in this layer.
+     bottleneck_size : int
+         Multiplicative factor for the bottleneck layer (does not affect the output size).
+     drop_prob : float, optional
+         Dropout rate. Default is 0.5.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Examples
+     --------
+     >>> x = torch.randn(128, 5, 1000)
+     >>> batch, channels, length = x.shape
+     >>> model = _DenseBlock(3, channels, 5, 2)
+     >>> y = model(x)
+     >>> y.shape
+     torch.Size([128, 20, 1000])
+     """
+
+     def __init__(
+         self,
+         num_layers,
+         in_channels,
+         growth_rate,
+         bottleneck_size,
+         drop_prob=0.5,
+         conv_bias=True,
+         batch_norm=True,
+         activation: nn.Module = nn.ELU,
+         kernel_size_conv1: int = 1,
+         kernel_size_conv2: int = 3,
+         stride_conv1: int = 1,
+         stride_conv2: int = 1,
+         padding_conv2: int = 1,
+     ):
+         super().__init__()
+         for idx_layer in range(num_layers):
+             layer = _DenseLayer(
+                 in_channels=in_channels + idx_layer * growth_rate,
+                 growth_rate=growth_rate,
+                 bottleneck_size=bottleneck_size,
+                 drop_prob=drop_prob,
+                 conv_bias=conv_bias,
+                 batch_norm=batch_norm,
+                 activation=activation,
+                 kernel_size_conv1=kernel_size_conv1,
+                 kernel_size_conv2=kernel_size_conv2,
+                 stride_conv1=stride_conv1,
+                 stride_conv2=stride_conv2,
+                 padding_conv2=padding_conv2,
+             )
+             self.add_module(f"denselayer{idx_layer + 1}", layer)
+
+
+ class _TransitionLayer(nn.Sequential):
+     """
+     A pooling transition layer.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input channels.
+     out_channels : int
+         Number of output channels.
+     conv_bias : bool, optional
+         Whether to use bias in convolutional layers. Default is True.
+     batch_norm : bool, optional
+         Whether to use batch normalization. Default is True.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Examples
+     --------
+     >>> x = torch.randn(128, 5, 1000)
+     >>> model = _TransitionLayer(5, 18)
+     >>> y = model(x)
+     >>> y.shape
+     torch.Size([128, 18, 500])
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         conv_bias=True,
+         batch_norm=True,
+         activation: nn.Module = nn.ELU,
+         kernel_size_trans: int = 2,
+         stride_trans: int = 2,
+     ):
+         super().__init__()
+         if batch_norm:
+             self.add_module("norm", nn.BatchNorm1d(in_channels))
+         self.add_module("elu", activation())
+         self.add_module(
+             "conv",
+             nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+                 stride=1,
+                 bias=conv_bias,
+             ),
+         )
+         self.add_module(
+             "pool", nn.AvgPool1d(kernel_size=kernel_size_trans, stride=stride_trans)
+         )
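
The new module above is self-contained apart from EEGModuleMixin. As a quick smoke test of the added class, a minimal sketch (assuming SPARCNet is re-exported from braindecode.models, as the updated models/__init__.py suggests, and reusing the configuration listed for it in summary.csv below):

    import torch
    from braindecode.models import SPARCNet

    # Configuration copied from the SPARCNet row of models/summary.csv
    model = SPARCNet(n_chans=16, n_outputs=6, n_times=2000)
    x = torch.randn(8, 16, 2000)  # (batch_size, n_chans, n_times)
    out = model(x)
    print(out.shape)  # expected: torch.Size([8, 6])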
braindecode/models/summary.csv
@@ -0,0 +1,41 @@
+ Model,Paradigm,Type,Freq(Hz),Hyperparameters,#Parameters,get_#Parameters
+ ATCNet,General,Classification,250,"n_chans, n_outputs, n_times",113732,"ATCNet(n_chans=22, n_outputs=4, n_times=1000)"
+ AttentionBaseNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",3692,"AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)"
+ BDTCN,Normal/Abnormal,Classification,100,"n_chans, n_outputs, n_times",456502,"BDTCN(n_chans=21, n_outputs=2, n_times=6000, n_blocks=5, n_filters=55, kernel_size=16)"
+ BIOT,"Sleep Staging, Epilepsy",Classification,200,"n_chans, n_outputs",3183879,"BIOT(n_chans=2, n_outputs=5, n_times=6000)"
+ ContraWR,Sleep Staging,"Classification, Embedding",125,"n_chans, n_outputs, sfreq",1160165,"ContraWR(n_chans=2, n_outputs=5, n_times=3750, emb_size=256, sfreq=125)"
+ CTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",26900,"CTNet(n_chans=22, n_outputs=4, n_times=1000, n_filters_time=8, kernel_size=16, heads=2, emb_size=16)"
+ Deep4Net,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",282879,"Deep4Net(n_chans=22, n_outputs=4, n_times=1000)"
+ DeepSleepNet,Sleep Staging,Classification,256,"n_chans, n_outputs",24744837,"DeepSleepNet(n_chans=1, n_outputs=5, n_times=7680, sfreq=256)"
+ EEGConformer,General,Classification,250,"n_chans, n_outputs, n_times",789572,"EEGConformer(n_chans=22, n_outputs=4, n_times=1000)"
+ EEGInceptionERP,"ERP, SSVEP",Classification,128,"n_chans, n_outputs",14926,"EEGInceptionERP(n_chans=8, n_outputs=2, n_times=128, sfreq=128)"
+ EEGInceptionMI,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",558028,"EEGInceptionMI(n_chans=22, n_outputs=4, n_times=1000, n_convs=5, n_filters=12)"
+ EEGITNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",5212,"EEGITNet(n_chans=22, n_outputs=4, n_times=500)"
+ EEGNetv1,General,Classification,128,"n_chans, n_outputs, n_times",3052,"EEGNetv1(n_chans=22, n_outputs=4, n_times=512)"
+ EEGNetv4,General,Classification,128,"n_chans, n_outputs, n_times",2484,"EEGNetv4(n_chans=22, n_outputs=4, n_times=512)"
+ EEGNeX,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times",55940,"EEGNeX(n_chans=22, n_outputs=4, n_times=500)"
+ EEGMiner,Emotion Recognition,Classification,128,"n_chans, n_outputs, n_times, sfreq",7572,"EEGMiner(n_chans=62, n_outputs=2, n_times=2560, sfreq=128)"
+ EEGResNet,General,Classification,250,"n_chans, n_outputs, n_times",247484,"EEGResNet(n_chans=22, n_outputs=4, n_times=1000)"
+ EEGSimpleConv,Motor Imagery,Classification,80,"n_chans, n_outputs, sfreq",730404,"EEGSimpleConv(n_chans=22, n_outputs=4, n_times=320, sfreq=80)"
+ EEGTCNet,Motor Imagery,Classification,250,"n_chans, n_outputs",4516,"EEGTCNet(n_chans=22, n_outputs=4, n_times=1000, kern_length=32)"
+ Labram,General,"Classification, Embedding",200,"n_chans, n_outputs, n_times",5866180,"Labram(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ MSVTNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",75494,"MSVTNet(n_chans=22, n_outputs=4, n_times=1000)"
+ SCCNet,Motor Imagery,Classification,125,"n_chans, n_outputs, n_times, sfreq",12070,"SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=125)"
+ SignalJEPA,"Motor Imagery, ERP, SSVEP",Embedding,128,"n_times, chs_info",3456882,"SignalJEPA(n_times=512, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])"
+ SignalJEPA_Contextual,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",3459184,"SignalJEPA_Contextual(n_outputs=2, input_window_seconds=4.19, sfreq=128, chs_info=Lee2019_MI().get_data(subjects=[1])[1]['0']['1train'].info[""chs""][:62])"
+ SignalJEPA_PostLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_chans, n_outputs, n_times",16142,"SignalJEPA_PostLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)"
+ SignalJEPA_PreLocal,"Motor Imagery, ERP, SSVEP",Classification,128,"n_outputs, n_times, chs_info",16142,"SignalJEPA_PreLocal(n_chans=62, n_outputs=2, input_window_seconds=4.19, sfreq=128)"
+ SincShallowNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",21892,"SincShallowNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ ShallowFBCSPNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times",46084,"ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ SleepStagerBlanco2020,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times",2845,"SleepStagerBlanco2020(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
+ SleepStagerChambon2018,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",5835,"SleepStagerChambon2018(n_chans=2, n_outputs=5, n_times=3840, sfreq=128)"
+ SleepStagerEldele2021,Sleep Staging,Classification,100,"n_chans, n_outputs, n_times, sfreq",719925,"SleepStagerEldele2021(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
+ SPARCNet,Epilepsy,Classification,200,"n_chans, n_outputs, n_times",1141921,"SPARCNet(n_chans=16, n_outputs=6, n_times=2000, sfreq=200)"
+ SyncNet,"Emotion Recognition, Alcoholism",Classification,256,"n_chans, n_outputs, n_times",554,"SyncNet(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)"
+ TSceptionV1,Emotion Recognition,Classification,256,"n_chans, n_outputs, n_times, sfreq",2187206,"TSceptionV1(n_chans=62, n_outputs=3, n_times=5120, sfreq=256)"
+ TIDNet,General,Classification,250,"n_chans, n_outputs, n_times",240404,"TIDNet(n_chans=22, n_outputs=4, n_times=1000)"
+ USleep,Sleep Staging,Classification,128,"n_chans, n_outputs, n_times, sfreq",2482011,"USleep(n_chans=2, n_outputs=5, n_times=3000, sfreq=100)"
+ FBCNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",11812,"FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ FBMSNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",16231,"FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ FBLightConvNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",6596,"FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
+ IFNet,Motor Imagery,Classification,250,"n_chans, n_outputs, n_times, sfreq",9860,"IFNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)"
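
Each row's final column gives a constructor call intended to reproduce the value in the #Parameters column. A hedged way to verify one row (here ShallowFBCSPNet; this assumes the listed figure counts trainable parameters):

    from braindecode.models import ShallowFBCSPNet

    # Constructor call copied from the ShallowFBCSPNet row above
    model = ShallowFBCSPNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(n_params)  # summary.csv lists 46084 for this configuration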
braindecode/models/syncnet.py
@@ -0,0 +1,232 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+ from numpy import arange, ceil
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class SyncNet(EEGModuleMixin, nn.Module):
+     """Synchronization Network (SyncNet) from Li et al. (2017) [Li2017]_.
+
+     .. figure:: https://braindecode.org/dev/_static/model/SyncNet.png
+        :align: center
+        :alt: SyncNet Architecture
+
+     SyncNet uses parameterized 1-dimensional convolutional filters inspired by
+     the Morlet wavelet to extract features from EEG signals. The filters are
+     dynamically generated based on learnable parameters that control the
+     oscillation and decay characteristics.
+
+     The filter for channel ``c`` and filter ``k`` is defined as:
+
+     .. math::
+
+         f_c^{(k)}(\\tau) = amplitude_c^{(k)} \\cos(\\omega^{(k)} \\tau + \\phi_c^{(k)}) \\exp(-\\beta^{(k)} \\tau^2)
+
+     where:
+
+     - :math:`amplitude_c^{(k)}` is the amplitude parameter (channel-specific).
+     - :math:`\\omega^{(k)}` is the frequency parameter (shared across channels).
+     - :math:`\\phi_c^{(k)}` is the phase shift (channel-specific).
+     - :math:`\\beta^{(k)}` is the decay parameter (shared across channels).
+     - :math:`\\tau` is the time index.
+
+     Parameters
+     ----------
+     num_filters : int, optional
+         Number of filters in the convolutional layer. Default is 1.
+     filter_width : int, optional
+         Width of the convolutional filters. Default is 40.
+     pool_size : int, optional
+         Size of the pooling window. Default is 40.
+     activation : nn.Module, optional
+         Activation function to apply after pooling. Default is ``nn.ReLU``.
+     ampli_init_values : tuple of float, optional
+         The initialization range for the amplitude parameters, drawn from a
+         uniform distribution. Default is (-0.05, 0.05).
+     omega_init_values : tuple of float, optional
+         The initialization range for the omega parameters, drawn from a
+         uniform distribution. Default is (0, 1).
+     beta_init_values : tuple of float, optional
+         The initialization range for the beta parameters, drawn from a
+         uniform distribution. Default is (0, 0.05).
+     phase_init_values : tuple of float, optional
+         The initialization parameters (mean, std) for the phase parameters,
+         drawn from a normal distribution. Default is (0, 0.05).
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct, as it has not been
+     checked by the original authors. The modifications are based on code
+     derived from [CodeICASSP2025]_.
+
+     References
+     ----------
+     .. [Li2017] Li, Y., Dzirasa, K., Carin, L., & Carlson, D. E. (2017).
+        Targeting EEG/LFP synchrony with neural nets. Advances in Neural
+        Information Processing Systems, 30.
+     .. [CodeICASSP2025] Code from Baselines for EEG-Music Emotion Recognition
+        Grand Challenge at ICASSP 2025.
+        https://github.com/SalvoCalcagno/eeg-music-challenge-icassp-2025-baselines
+
+     """
+
+     def __init__(
+         self,
+         # braindecode convention
+         n_chans=None,
+         n_times=None,
+         n_outputs=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         num_filters=1,
+         filter_width=40,
+         pool_size=40,
+         activation: nn.Module = nn.ReLU,
+         ampli_init_values: tuple[float, float] = (-0.05, 0.05),
+         omega_init_values: tuple[float, float] = (0.0, 1.0),
+         beta_init_values: tuple[float, float] = (0.0, 0.05),
+         phase_init_values: tuple[float, float] = (0.0, 0.05),
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_times=n_times,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.num_filters = num_filters
+         self.filter_width = filter_width
+         self.pool_size = pool_size
+         self.activation = activation()
+         self.ampli_init_values = ampli_init_values
+         self.omega_init_values = omega_init_values
+         self.beta_init_values = beta_init_values
+         self.phase_init_values = phase_init_values
+
+         # Initialize parameters
+         self.amplitude = nn.Parameter(
+             torch.FloatTensor(1, 1, self.n_chans, self.num_filters).uniform_(
+                 self.ampli_init_values[0], self.ampli_init_values[1]
+             )
+         )
+         self.omega = nn.Parameter(
+             torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
+                 self.omega_init_values[0], self.omega_init_values[1]
+             )
+         )
+
+         self.bias = nn.Parameter(torch.zeros(self.num_filters))
+
+         # Calculate the output size after pooling
+         self.classifier_input_size = int(
+             ceil(float(self.n_times) / float(self.pool_size)) * self.num_filters
+         )
+
+         # Create time vector t
+         if self.filter_width % 2 == 0:
+             t_range = arange(-int(self.filter_width / 2), int(self.filter_width / 2))
+         else:
+             t_range = arange(
+                 -int((self.filter_width - 1) / 2), int((self.filter_width - 1) / 2) + 1
+             )
+
+         t_np = t_range.reshape(1, self.filter_width, 1, 1)
+         self.t = nn.Parameter(torch.FloatTensor(t_np))
+         # Phase shift (channel-specific), drawn from a normal distribution
+         self.phi_ini = nn.Parameter(
+             torch.FloatTensor(1, 1, self.n_chans, self.num_filters).normal_(
+                 self.phase_init_values[0], self.phase_init_values[1]
+             )
+         )
+         # Decay (shared across channels), drawn from a uniform distribution
+         self.beta = nn.Parameter(
+             torch.FloatTensor(1, 1, 1, self.num_filters).uniform_(
+                 self.beta_init_values[0], self.beta_init_values[1]
+             )
+         )
+
+         self.padding = self._compute_padding(filter_width=self.filter_width)
+         self.pad_input = nn.ConstantPad1d(self.padding, 0.0)
+         self.pad_res = nn.ConstantPad1d(self.padding, 0.0)
+
+         # Define pooling and classifier layers
+         self.pool = nn.MaxPool2d((1, self.pool_size), stride=(1, self.pool_size))
+
+         self.ensuredim = Rearrange("batch ch time -> batch ch 1 time")
+
+         self.final_layer = nn.Linear(self.classifier_input_size, self.n_outputs)
+
+     def forward(self, x):
+         """Forward pass of the SyncNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times)
+
+         Returns
+         -------
+         out : torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+
+         """
+         # Ensure input tensor has shape (batch_size, n_chans, 1, n_times)
+         x = self.ensuredim(x)
+
+         # Compute the oscillatory component
+         W_osc = self.amplitude * torch.cos(self.t * self.omega + self.phi_ini)
+         # W_osc has shape (1, filter_width, n_chans, num_filters)
+
+         # Compute the decay component
+         t_squared = torch.pow(self.t, 2)  # shape: (1, filter_width, 1, 1)
+         t_squared_beta = t_squared * self.beta  # shape: (1, filter_width, 1, num_filters)
+         W_decay = torch.exp(-t_squared_beta)
+         # W_decay has shape (1, filter_width, 1, num_filters)
+
+         # Combine oscillatory and decay components
+         W = W_osc * W_decay
+         # W has shape (1, filter_width, n_chans, num_filters)
+
+         # Reinterpret as conv2d weights of shape
+         # (num_filters, n_chans, 1, filter_width)
+         W = W.view(self.num_filters, self.n_chans, 1, self.filter_width)
+
+         # Apply convolution
+         x_padded = self.pad_input(x.float())
+
+         res = F.conv2d(x_padded, W.float(), bias=self.bias, stride=1)
+
+         # Apply padding to the convolution result
+         res_padded = self.pad_res(res)
+         res_pooled = self.pool(res_padded)
+
+         # Flatten the result
+         res_flat = res_pooled.view(-1, self.classifier_input_size)
+
+         # Ensure beta remains non-negative
+         self.beta.data.clamp_(min=0)
+
+         # Apply activation
+         out = self.activation(res_flat)
+         # Apply classifier
+         out = self.final_layer(out)
+
+         return out
+
+     @staticmethod
+     def _compute_padding(filter_width):
+         # Compute left/right padding for the input and the convolution result
+         P = filter_width - 2
+         if P % 2 == 0:
+             padding = (P // 2, P // 2 + 1)
+         else:
+             padding = (P // 2, P // 2)
+         return padding
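
The filter construction in SyncNet.forward follows the docstring equation directly. The following standalone sketch (hypothetical sizes, plain PyTorch, not part of the package) reproduces the broadcasting the model relies on:

    import torch

    # f_c^(k)(tau) = a_c^(k) * cos(omega^(k) * tau + phi_c^(k)) * exp(-beta^(k) * tau^2)
    n_chans, num_filters, filter_width = 4, 2, 40
    tau = torch.arange(-filter_width // 2, filter_width // 2, dtype=torch.float32)
    tau = tau.view(1, filter_width, 1, 1)
    amplitude = torch.empty(1, 1, n_chans, num_filters).uniform_(-0.05, 0.05)  # channel-specific
    omega = torch.rand(1, 1, 1, num_filters)  # frequency, shared across channels
    phi = torch.empty(1, 1, n_chans, num_filters).normal_(0.0, 0.05)  # channel-specific phase
    beta = torch.empty(1, 1, 1, num_filters).uniform_(0.0, 0.05)  # decay, shared across channels
    filters = amplitude * torch.cos(omega * tau + phi) * torch.exp(-beta * tau ** 2)
    print(filters.shape)  # torch.Size([1, 40, 4, 2]) == (1, filter_width, n_chans, num_filters)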