braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
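
The file list shows shared building blocks moving out of braindecode/models into new subpackages (braindecode/modules, braindecode/functional), while braindecode/models/modules.py and braindecode/models/functions.py are removed. A minimal, hypothetical import sketch of what that reorganization looks like from user code; only LogActivation is confirmed by the new sccnet.py shown below, the rest of the names are assumptions:

from braindecode.models.base import EEGModuleMixin  # unchanged location
from braindecode.modules import LogActivation        # new home of shared layers (see sccnet.py below)
# New model files added in this release live under braindecode/models/,
# e.g. braindecode/models/msvtnet.py and braindecode/models/sccnet.py shown below;
# whether they are re-exported from braindecode.models.__init__ is an assumption here.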
braindecode/models/msvtnet.py
@@ -0,0 +1,375 @@
+ # Authors: Tao Yang <sheeptao@outlook.com>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ #
+ from typing import Type, Union
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+
+ from braindecode.models.base import EEGModuleMixin
+
+
+ class MSVTNet(EEGModuleMixin, nn.Module):
+     """MSVTNet model from Liu K et al. (2024) [msvt2024]_.
+
+     This model implements a multi-scale convolutional transformer network
+     for EEG signal classification, as described in [msvt2024]_.
+
+     .. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
+        :align: center
+        :alt: MSVTNet Architecture
+
+     Parameters
+     ----------
+     n_filters_list : tuple[int, ...], optional
+         Number of filters for each TSConv block, by default (9, 9, 9, 9).
+     conv1_kernels_size : tuple[int, ...], optional
+         Kernel sizes for the first convolution in each TSConv block,
+         by default (15, 31, 63, 125).
+     conv2_kernel_size : int, optional
+         Kernel size for the second convolution in TSConv blocks, by default 15.
+     depth_multiplier : int, optional
+         Depth multiplier for the depthwise convolution, by default 2.
+     pool1_size : int, optional
+         Pooling size for the first pooling layer in TSConv blocks, by default 8.
+     pool2_size : int, optional
+         Pooling size for the second pooling layer in TSConv blocks, by default 7.
+     drop_prob : float, optional
+         Dropout probability for the convolutional layers, by default 0.3.
+     num_heads : int, optional
+         Number of attention heads in the transformer encoder, by default 8.
+     feedforward_ratio : float, optional
+         Ratio used to compute the feedforward dimension in the transformer, by default 1.
+     drop_prob_trans : float, optional
+         Dropout probability for the transformer, by default 0.5.
+     num_layers : int, optional
+         Number of transformer encoder layers, by default 2.
+     activation : Type[nn.Module], optional
+         Activation function class to use, by default nn.ELU.
+     return_features : bool, optional
+         Whether to return the predictions of the branch classifiers, by default False.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it is a reimplementation based on the
+     original code [msvt2024code]_.
+
+     References
+     ----------
+     .. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
+        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
+        IEEE Journal of Biomedical and Health Informatics.
+     .. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
+        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
+        Source Code: https://github.com/SheepTAO/MSVTNet
+     """
+
+     def __init__(
+         self,
+         # braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         chs_info=None,
+         # Model's parameters
+         n_filters_list: tuple[int, ...] = (9, 9, 9, 9),
+         conv1_kernels_size: tuple[int, ...] = (15, 31, 63, 125),
+         conv2_kernel_size: int = 15,
+         depth_multiplier: int = 2,
+         pool1_size: int = 8,
+         pool2_size: int = 7,
+         drop_prob: float = 0.3,
+         num_heads: int = 8,
+         feedforward_ratio: float = 1,
+         drop_prob_trans: float = 0.5,
+         num_layers: int = 2,
+         activation: Type[nn.Module] = nn.ELU,
+         return_features: bool = False,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.return_features = return_features
+         assert len(n_filters_list) == len(conv1_kernels_size), (
+             "The length of n_filters_list and conv1_kernels_size should be equal."
+         )
+
+         self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
+         self.mstsconv = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     _TSConv(
+                         self.n_chans,
+                         n_filters_list[b],
+                         conv1_kernels_size[b],
+                         conv2_kernel_size,
+                         depth_multiplier,
+                         pool1_size,
+                         pool2_size,
+                         drop_prob,
+                         activation,
+                     ),
+                     Rearrange("batch channels 1 time -> batch time channels"),
+                 )
+                 for b in range(len(n_filters_list))
+             ]
+         )
+         branch_linear_in = self._forward_flatten(cat=False)
+         self.branch_head = nn.ModuleList(
+             [
+                 _DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
+                 for b in range(len(n_filters_list))
+             ]
+         )
+
+         seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
+         self.transformer = _Transformer(
+             seq_len,
+             d_model,
+             num_heads,
+             feedforward_ratio,
+             drop_prob_trans,
+             num_layers,
+         )
+
+         linear_in = self._forward_flatten().shape[1]  # type: ignore
+         self.flatten_layer = nn.Flatten()
+         self.final_layer = nn.Linear(linear_in, self.n_outputs)
+
+     def _forward_mstsconv(
+         self, cat: bool = True
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         x = torch.randn(1, 1, self.n_chans, self.n_times)
+         x = [tsconv(x) for tsconv in self.mstsconv]
+         if cat:
+             x = torch.cat(x, dim=2)
+         return x
+
+     def _forward_flatten(
+         self, cat: bool = True
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         x = self._forward_mstsconv(cat)
+         if cat:
+             x = self.transformer(x)
+             x = x.flatten(start_dim=1, end_dim=-1)
+         else:
+             x = [_.flatten(start_dim=1, end_dim=-1) for _ in x]
+         return x
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x with shape: (batch, n_chans, n_times)
+         x = self.ensure_dim(x)
+         # x with shape: (batch, 1, n_chans, n_times)
+         x_list = [tsconv(x) for tsconv in self.mstsconv]
+         # x_list contains one tensor per branch, each of shape: [batch_size, seq_len, embed_dim]
+         branch_preds = [
+             branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
+         ]
+         # branch_preds contains one tensor per branch, each of shape: [batch_size, num_classes]
+         x = torch.stack(x_list, dim=2)
+         x = x.view(x.size(0), x.size(1), -1)
+         # x shape after concatenation: [batch_size, seq_len, total_embed_dim]
+         x = self.transformer(x)
+         # x shape after transformer: [batch_size, total_embed_dim]
+
+         x = self.final_layer(x)
+         if self.return_features:
+             # x shape after final layer: [batch_size, num_classes]
+             # each branch prediction has shape: [batch_size, num_classes]
+             return torch.stack(branch_preds)
+         return x
+
+
+ class _TSConv(nn.Sequential):
+     """
+     Time-Distributed Separable Convolution block.
+
+     The architecture consists of:
+
+     - **Temporal Convolution**
+     - **Batch Normalization**
+     - **Depthwise Spatial Convolution**
+     - **Batch Normalization**
+     - **Activation Function**
+     - **First Pooling Layer**
+     - **Dropout**
+     - **Depthwise Temporal Convolution**
+     - **Batch Normalization**
+     - **Activation Function**
+     - **Second Pooling Layer**
+     - **Dropout**
+
+     Parameters
+     ----------
+     n_channels : int
+         Number of input channels (EEG channels).
+     n_filters : int
+         Number of filters for the convolution layers.
+     conv1_kernel_size : int
+         Kernel size for the first convolution layer.
+     conv2_kernel_size : int
+         Kernel size for the second convolution layer.
+     depth_multiplier : int
+         Depth multiplier for depthwise convolution.
+     pool1_size : int
+         Kernel size for the first pooling layer.
+     pool2_size : int
+         Kernel size for the second pooling layer.
+     drop_prob : float
+         Dropout probability.
+     activation : Type[nn.Module], optional
+         Activation function class to use, by default nn.ELU.
+     """
+
+     def __init__(
+         self,
+         n_channels: int,
+         n_filters: int,
+         conv1_kernel_size: int,
+         conv2_kernel_size: int,
+         depth_multiplier: int,
+         pool1_size: int,
+         pool2_size: int,
+         drop_prob: float,
+         activation: Type[nn.Module] = nn.ELU,
+     ):
+         super().__init__(
+             nn.Conv2d(
+                 in_channels=1,
+                 out_channels=n_filters,
+                 kernel_size=(1, conv1_kernel_size),
+                 padding="same",
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters),
+             nn.Conv2d(
+                 in_channels=n_filters,
+                 out_channels=n_filters * depth_multiplier,
+                 kernel_size=(n_channels, 1),
+                 groups=n_filters,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters * depth_multiplier),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool1_size)),
+             nn.Dropout(drop_prob),
+             nn.Conv2d(
+                 in_channels=n_filters * depth_multiplier,
+                 out_channels=n_filters * depth_multiplier,
+                 kernel_size=(1, conv2_kernel_size),
+                 padding="same",
+                 groups=n_filters * depth_multiplier,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(n_filters * depth_multiplier),
+             activation(),
+             nn.AvgPool2d(kernel_size=(1, pool2_size)),
+             nn.Dropout(drop_prob),
+         )
+
+
+ class _PositionalEncoding(nn.Module):
+     """
+     Positional encoding module that adds learnable positional embeddings.
+
+     Parameters
+     ----------
+     seq_length : int
+         Sequence length.
+     d_model : int
+         Dimensionality of the model.
+     """
+
+     def __init__(self, seq_length: int, d_model: int) -> None:
+         super().__init__()
+         self.seq_length = seq_length
+         self.d_model = d_model
+         self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = x + self.pe
+         return x
+
+
+ class _Transformer(nn.Module):
+     """
+     Transformer encoder module with learnable class token and positional encoding.
+
+     Parameters
+     ----------
+     seq_length : int
+         Sequence length of the input.
+     d_model : int
+         Dimensionality of the model.
+     num_heads : int
+         Number of heads in the multihead attention.
+     feedforward_ratio : float
+         Ratio to compute the dimension of the feedforward network.
+     drop_prob : float, optional
+         Dropout probability, by default 0.5.
+     num_layers : int, optional
+         Number of transformer encoder layers, by default 4.
+     """
+
+     def __init__(
+         self,
+         seq_length: int,
+         d_model: int,
+         num_heads: int,
+         feedforward_ratio: float,
+         drop_prob: float = 0.5,
+         num_layers: int = 4,
+     ) -> None:
+         super().__init__()
+         self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
+         self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)
+
+         dim_ff = int(d_model * feedforward_ratio)
+         self.dropout = nn.Dropout(drop_prob)
+         self.trans = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(
+                 d_model,
+                 num_heads,
+                 dim_ff,
+                 drop_prob,
+                 batch_first=True,
+                 norm_first=True,
+             ),
+             num_layers,
+             norm=nn.LayerNorm(d_model),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         batch_size = x.shape[0]
+         x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
+         x = self.pos_embedding(x)
+         x = self.dropout(x)
+         return self.trans(x)[:, 0]
+
+
+ class _DenseLayers(nn.Sequential):
+     """
+     Final classification layers.
+
+     Parameters
+     ----------
+     linear_in : int
+         Input dimension to the linear layer.
+     n_classes : int
+         Number of output classes.
+     """
+
+     def __init__(self, linear_in: int, n_classes: int):
+         super().__init__(
+             nn.Flatten(),
+             nn.Linear(linear_in, n_classes),
+         )
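
A minimal smoke-test sketch for the MSVTNet class above. The channel, class, and sample counts are illustrative assumptions, not values taken from the package's docs or tests; the module path matches the file added in this release:

import torch

from braindecode.models.msvtnet import MSVTNet

# Illustrative shapes: 22 channels, 4 classes, 1000 time samples per window.
model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000)
x = torch.randn(8, 22, 1000)          # (batch, n_chans, n_times)
scores = model(x)                     # -> (8, 4)

# With return_features=True, forward() as written above returns the stacked
# per-branch predictions instead, shape (n_branches, batch, n_outputs).
model_branches = MSVTNet(n_chans=22, n_outputs=4, n_times=1000, return_features=True)
branch_scores = model_branches(x)     # -> (4, 8, 4)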
braindecode/models/sccnet.py
@@ -0,0 +1,207 @@
+ # Authors: Chun-Shu Wei
+ #          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
+ #
+ # License: BSD (3-clause)
+
+ import math
+ from warnings import warn
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import LogActivation
+
+
+ class SCCNet(EEGModuleMixin, nn.Module):
+     """SCCNet from Wei, C. S. (2019) [sccnet]_.
+
+     Spatial component-wise convolutional network (SCCNet) for motor-imagery EEG
+     classification.
+
+     .. figure:: https://dt5vp8kor0orz.cloudfront.net/6e3ec5d729cd51fe8acc5a978db27d02a5df9e05/2-Figure1-1.png
+        :align: center
+        :alt: Spatial component-wise convolutional network
+
+     1. **Spatial Component Analysis**: Performs convolutional spatial filtering
+        across all EEG channels to extract spatial components, effectively
+        reducing the channel dimension.
+     2. **Spatio-Temporal Filtering**: Applies convolution across the spatial
+        components and the temporal domain to capture spatio-temporal patterns.
+     3. **Temporal Smoothing (Pooling)**: Uses average pooling over time to smooth the
+        features and reduce the temporal dimension, focusing on longer-term patterns.
+     4. **Classification**: Flattens the features and applies a fully connected
+        layer.
+
+     Parameters
+     ----------
+     n_spatial_filters : int, optional
+         Number of spatial filters in the first convolutional layer. Default is 22.
+     n_spatial_filters_smooth : int, optional
+         Number of filters in the second (spatio-temporal) convolutional layer.
+         Default is 20.
+     drop_prob : float, optional
+         Dropout probability. Default is 0.5.
+     activation : nn.Module, optional
+         Activation function after the second convolutional layer. Default is
+         the logarithm activation.
+     batch_norm_momentum : float, optional
+         Momentum for the batch normalization layers. Default is 0.1.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it is a reimplementation from the paper
+     description and from source code that has not been tested [sccnetcode]_.
+
+     References
+     ----------
+     .. [sccnet] Wei, C. S., Koike-Akino, T., & Wang, Y. (2019, March). Spatial
+        component-wise convolutional network (SCCNet) for motor-imagery EEG
+        classification. In 2019 9th International IEEE/EMBS Conference on
+        Neural Engineering (NER) (pp. 328-331). IEEE.
+     .. [sccnetcode] Hsieh, C. Y., Chou, J. L., Chang, Y. H., & Wei, C. S.
+        XBrainLab: An Open-Source Software for Explainable Artificial
+        Intelligence-Based EEG Analysis. In NeurIPS 2023 AI for
+        Science Workshop.
+     """
+
+     def __init__(
+         self,
+         # Signal related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model related parameters
+         n_spatial_filters: int = 22,
+         n_spatial_filters_smooth: int = 20,
+         drop_prob: float = 0.5,
+         activation: nn.Module = LogActivation,
+         batch_norm_momentum: float = 0.1,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+         # Parameters
+         self.n_spatial_filters = n_spatial_filters
+         self.n_spatial_filters_smooth = n_spatial_filters_smooth
+         self.drop_prob = drop_prob
+
+         # Original logic for SCCNet
+         conv_kernel_time = 0.1  # 100 ms
+         pool_kernel_time = 0.5  # 500 ms
+
+         # Calculate sample-based sizes from time durations
+         conv_kernel_samples = int(math.floor(self.sfreq * conv_kernel_time))
+         pool_kernel_samples = int(math.floor(self.sfreq * pool_kernel_time))
+
+         # If the input window is too short for the default kernel sizes,
+         # scale them down proportionally.
+         total_kernel_samples = conv_kernel_samples + pool_kernel_samples
+
+         if self.n_times < total_kernel_samples:
+             warning_msg = (
+                 f"Input window seconds ({self.input_window_seconds:.2f}s) is smaller than the "
+                 f"model's combined kernel sizes ({(total_kernel_samples / self.sfreq):.2f}s). "
+                 "Scaling temporal parameters down proportionally."
+             )
+             warn(warning_msg, UserWarning, stacklevel=2)
+
+             scaling_factor = self.n_times / total_kernel_samples
+             conv_kernel_samples = int(math.floor(conv_kernel_samples * scaling_factor))
+             pool_kernel_samples = int(math.floor(pool_kernel_samples * scaling_factor))
+
+         # Ensure kernels are at least 1 sample wide
+         self.samples_100ms = max(1, conv_kernel_samples)
+         self.kernel_size_pool = max(1, pool_kernel_samples)
+
+         num_features = self._calc_num_features()
+
+         # Layers
+         self.ensure_dim = Rearrange("batch nchan times -> batch 1 nchan times")
+
+         self.activation = LogActivation() if activation is None else activation()
+
+         self.spatial_conv = nn.Conv2d(
+             in_channels=1,
+             out_channels=self.n_spatial_filters,
+             kernel_size=(self.n_chans, 1),
+         )
+
+         self.spatial_batch_norm = nn.BatchNorm2d(
+             self.n_spatial_filters, momentum=batch_norm_momentum
+         )
+
+         self.permute = Rearrange(
+             "batch filspat nchans time -> batch nchans filspat time"
+         )
+
+         self.spatial_filt_conv = nn.Conv2d(
+             in_channels=1,
+             out_channels=self.n_spatial_filters_smooth,
+             kernel_size=(self.n_spatial_filters, self.samples_100ms),
+             bias=False,
+         )
+         self.batch_norm = nn.BatchNorm2d(
+             self.n_spatial_filters_smooth, momentum=batch_norm_momentum
+         )
+
+         self.dropout = nn.Dropout(self.drop_prob)
+         self.temporal_smoothing = nn.AvgPool2d(
+             kernel_size=(1, self.kernel_size_pool),
+             stride=(1, self.samples_100ms),
+         )
+
+         self.final_layer = nn.Linear(num_features, self.n_outputs)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Shape: (batch_size, n_chans, n_times)
+         x = self.ensure_dim(x)
+         # Shape: (batch_size, 1, n_chans, n_times)
+         x = self.spatial_conv(x)
+         # Shape: (batch_size, n_filters, 1, n_times)
+         x = self.spatial_batch_norm(x)
+         # Shape: (batch_size, n_filters, 1, n_times)
+         x = self.permute(x)
+         # Shape: (batch_size, 1, n_filters, n_times)
+         x = self.spatial_filt_conv(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = self.batch_norm(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = torch.pow(x, 2)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = self.dropout(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced)
+         x = self.temporal_smoothing(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced_avg_pool)
+         x = self.activation(x)
+         # Shape: (batch_size, n_filters_filt, 1, n_times_reduced_avg_pool)
+         x = x.view(x.size(0), -1)
+         # Shape: (batch_size, n_filters_filt*n_times_reduced_avg_pool)
+         x = self.final_layer(x)
+         # Shape: (batch_size, n_outputs)
+         return x
+
+     def _calc_num_features(self) -> int:
+         # Compute the number of features for the final linear layer
+         w_out_conv2 = (
+             self.n_times - self.samples_100ms + 1  # After second conv layer
+         )
+         w_out_pool = (
+             (w_out_conv2 - self.kernel_size_pool) // self.samples_100ms + 1
+             # After pooling layer
+         )
+         num_features = self.n_spatial_filters_smooth * w_out_pool
+         return num_features
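
A minimal smoke-test sketch for the SCCNet class above, with a worked-through feature count. The shapes (22 channels, 4 classes, 4 s windows at 250 Hz) are illustrative assumptions, not values taken from the package's docs or tests; the module path matches the file added in this release:

import torch

from braindecode.models.sccnet import SCCNet

# At sfreq=250 Hz the temporal kernels become 25 samples (100 ms) and
# 125 samples (500 ms); with n_times=1000 no rescaling warning is triggered.
model = SCCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250.0)

# _calc_num_features: (1000 - 25 + 1) = 976 samples after the second conv,
# (976 - 125) // 25 + 1 = 35 after pooling, 20 * 35 = 700 flattened features
# feeding the final linear layer.
x = torch.randn(8, 22, 1000)   # (batch, n_chans, n_times)
out = model(x)                 # -> (8, 4)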