braindecode-1.3.0.dev177069446-py3-none-any.whl

Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
@@ -0,0 +1,337 @@
+ # Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+ from einops.layers.torch import Rearrange
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Chomp1d, MaxNormLinear
+
+
+ class EEGTCNet(EEGModuleMixin, nn.Module):
+     r"""EEGTCNet model from Ingolfsson et al. (2020) [ingolfsson2020]_.
+
+     :bdg-success:`Convolution` :bdg-secondary:`Recurrent`
+
+     .. figure:: https://braindecode.org/dev/_static/model/eegtcnet.jpg
+        :align: center
+        :alt: EEGTCNet Architecture
+
+     Combining EEGNet and TCN blocks.
+
+     Parameters
+     ----------
+     activation : nn.Module, optional
+         Activation function class to use. Default is ``nn.ELU``.
+     depth_multiplier : int, optional
+         Depth multiplier for the depthwise convolution. Default is 2.
+     filter_1 : int, optional
+         Number of temporal filters in the first convolutional layer. Default is 8.
+     kern_length : int, optional
+         Length of the temporal kernel in the first convolutional layer. Default is 64.
+     drop_prob : float, optional
+         Dropout probability. Default is 0.5.
+     depth : int, optional
+         Number of residual blocks in the TCN. Default is 2.
+     kernel_size : int, optional
+         Size of the temporal convolutional kernel in the TCN. Default is 4.
+     filters : int, optional
+         Number of filters in the TCN convolutional layers. Default is 12.
+     max_norm_const : float
+         Maximum L2-norm constraint imposed on weights of the last
+         fully-connected layer. Defaults to 0.25.
+
+     References
+     ----------
+     .. [ingolfsson2020] Ingolfsson, T. M., Hersche, M., Wang, X., Kobayashi, N.,
+        Cavigelli, L., & Benini, L. (2020). EEG-TCNet: An accurate temporal
+        convolutional network for embedded motor-imagery brain–machine interfaces.
+        https://doi.org/10.48550/arXiv.2006.00622
+     """
+
+     def __init__(
+         self,
+         # Signal related parameters
+         n_chans=None,
+         n_outputs=None,
+         n_times=None,
+         chs_info=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         activation: type[nn.Module] = nn.ELU,
+         depth_multiplier: int = 2,
+         filter_1: int = 8,
+         kern_length: int = 64,
+         drop_prob: float = 0.5,
+         depth: int = 2,
+         kernel_size: int = 4,
+         filters: int = 12,
+         max_norm_const: float = 0.25,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         self.activation = activation
+         self.drop_prob = drop_prob
+         self.depth_multiplier = depth_multiplier
+         self.filter_1 = filter_1
+         self.kern_length = kern_length
+         self.depth = depth
+         self.kernel_size = kernel_size
+         self.filters = filters
+         self.max_norm_const = max_norm_const
+         self.filter_2 = self.filter_1 * self.depth_multiplier
+
+         self.arrange_dim_input = Rearrange(
+             "batch nchans ntimes -> batch 1 ntimes nchans"
+         )
+         # EEGNet_TC Block
+         self.eegnet_tc = _EEGNetTC(
+             n_chans=self.n_chans,
+             filter_1=self.filter_1,
+             kern_length=self.kern_length,
+             depth_multiplier=self.depth_multiplier,
+             drop_prob=self.drop_prob,
+             activation=self.activation,
+         )
+         self.arrange_dim_eegnet = Rearrange(
+             "batch filter2 rtimes 1 -> batch rtimes filter2"
+         )
+
+         # TCN Block
+         self.tcn_block = _TCNBlock(
+             input_dimension=self.filter_2,
+             depth=self.depth,
+             kernel_size=self.kernel_size,
+             filters=self.filters,
+             drop_prob=self.drop_prob,
+             activation=self.activation,
+         )
+
+         # Classification Block
+         self.final_layer = MaxNormLinear(
+             in_features=self.filters,
+             out_features=self.n_outputs,
+             max_norm_val=self.max_norm_const,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the EEGTCNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor of shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor of shape (batch_size, n_outputs).
+         """
+         # x shape: (batch_size, n_chans, n_times)
+         x = self.arrange_dim_input(x)  # (batch_size, 1, n_times, n_chans)
+         x = self.eegnet_tc(x)  # (batch_size, filter_2, reduced_time, 1)
+
+         x = self.arrange_dim_eegnet(x)  # (batch_size, reduced_time, filter_2)
+         x = self.tcn_block(x)  # (batch_size, reduced_time, filters)
+
+         # Select the last time step
+         x = x[:, -1, :]  # (batch_size, filters)
+
+         x = self.final_layer(x)  # (batch_size, n_outputs)
+
+         return x
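
For orientation, here is a minimal forward-pass sketch of the model above (a hedged example: it assumes EEGTCNet is re-exported from braindecode.models, as the package's models/__init__.py in the file list suggests; the random tensor stands in for real EEG):

    import torch
    from braindecode.models import EEGTCNet

    model = EEGTCNet(n_chans=22, n_outputs=4, n_times=1000)
    x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
    out = model(x)                # (8, 4) class scores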
+
+
+ class _EEGNetTC(nn.Module):
+     r"""EEGNet feature-extraction block used inside EEG-TCNet.
+
+     The main difference from our :class:`EEGNet` (braindecode) implementation
+     is the kernel and dimension ordering. Because of this, we decided to keep
+     this implementation separate; in a future issue we will re-evaluate
+     whether it is necessary to maintain it.
+
+     Parameters
+     ----------
+     n_chans : int
+         Number of EEG channels.
+     filter_1 : int
+         Number of temporal filters in the first convolutional layer.
+     kern_length : int
+         Length of the temporal kernel in the first convolutional layer.
+     depth_multiplier : int
+         Depth multiplier for the depthwise convolution.
+     drop_prob : float
+         Dropout probability.
+     activation : nn.Module
+         Activation function class.
+     """
+
+     def __init__(
+         self,
+         n_chans: int,
+         filter_1: int = 8,
+         kern_length: int = 64,
+         depth_multiplier: int = 2,
+         drop_prob: float = 0.5,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         super().__init__()
+         self.activation = activation()
+         self.drop_prob = drop_prob
+         self.n_chans = n_chans
+         self.filter_1 = filter_1
+         self.filter_2 = self.filter_1 * depth_multiplier
+
+         # First Conv2D Layer
+         self.conv1 = nn.Conv2d(
+             in_channels=1,
+             out_channels=self.filter_1,
+             kernel_size=(kern_length, 1),
+             padding=(kern_length // 2, 0),
+             bias=False,
+         )
+         self.bn1 = nn.BatchNorm2d(self.filter_1)
+
+         # Depthwise Convolution
+         self.depthwise_conv = nn.Conv2d(
+             in_channels=self.filter_1,
+             out_channels=self.filter_2,
+             kernel_size=(1, n_chans),
+             groups=self.filter_1,
+             bias=False,
+         )
+         self.bn2 = nn.BatchNorm2d(self.filter_2)
+         self.pool1 = nn.AvgPool2d(kernel_size=(8, 1))
+         self.drop1 = nn.Dropout(p=drop_prob)
+
+         # Separable Convolution (Depthwise + Pointwise)
+         self.separable_conv_depthwise = nn.Conv2d(
+             in_channels=self.filter_2,
+             out_channels=self.filter_2,
+             kernel_size=(self.filter_2, 1),
+             groups=self.filter_2,
+             padding=(self.filter_2 // 2, 0),
+             bias=False,
+         )
+         self.separable_conv_pointwise = nn.Conv2d(
+             in_channels=self.filter_2,
+             out_channels=self.filter_2,
+             kernel_size=(1, 1),
+             bias=False,
+         )
+         self.bn3 = nn.BatchNorm2d(self.filter_2)
+         self.pool2 = nn.AvgPool2d(kernel_size=(self.filter_1, 1))
+         self.drop2 = nn.Dropout(p=drop_prob)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x shape: (batch_size, 1, n_times, n_chans)
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.activation(x)
+
+         x = self.depthwise_conv(x)
+         x = self.bn2(x)
+         x = self.activation(x)
+         x = self.pool1(x)
+         x = self.drop1(x)
+
+         x = self.separable_conv_depthwise(x)
+         x = self.separable_conv_pointwise(x)
+         x = self.bn3(x)
+         x = self.activation(x)
+         x = self.pool2(x)
+         x = self.drop2(x)
+
+         return x  # Shape: (batch_size, filter_2, reduced_time, 1)
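
A quick way to sanity-check the block above is to push a dummy tensor through it and read off the reduced time dimension (a sketch with default arguments; the exact size follows from the two (8, 1) average-pooling stages):

    import torch

    block = _EEGNetTC(n_chans=22)
    x = torch.randn(8, 1, 1000, 22)  # (batch_size, 1, n_times, n_chans)
    print(block(x).shape)            # e.g. torch.Size([8, 16, 15, 1])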
+
+
+ class _TCNBlock(nn.Module):
+     r"""Temporal Convolutional Network (TCN) block of EEG-TCNet.
+
+     This differs in several ways from our temporal block (braindecode)
+     implementation. Because of this, we decided to keep this implementation
+     separate; in a future issue we will re-evaluate whether it is necessary
+     to maintain it.
+     """
+
+     def __init__(
+         self,
+         input_dimension: int,
+         depth: int,
+         kernel_size: int,
+         filters: int,
+         drop_prob: float,
+         activation: type[nn.Module] = nn.ELU,
+     ):
+         super().__init__()
+         self.activation = activation()
+         self.drop_prob = drop_prob
+         self.depth = depth
+         self.filters = filters
+         self.kernel_size = kernel_size
+
+         self.layers = nn.ModuleList()
+         self.downsample = (
+             nn.Conv1d(input_dimension, filters, kernel_size=1, bias=False)
+             if input_dimension != filters
+             else None
+         )
+
+         for i in range(depth):
+             dilation = 2**i
+             padding = (kernel_size - 1) * dilation
+             conv_block = nn.Sequential(
+                 nn.Conv1d(
+                     in_channels=input_dimension if i == 0 else filters,
+                     out_channels=filters,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                     padding=padding,
+                     bias=False,
+                 ),
+                 Chomp1d(padding),
+                 self.activation,
+                 nn.Dropout(self.drop_prob),
+                 nn.Conv1d(
+                     in_channels=filters,
+                     out_channels=filters,
+                     kernel_size=kernel_size,
+                     dilation=dilation,
+                     padding=padding,
+                     bias=False,
+                 ),
+                 Chomp1d(padding),
+                 self.activation,
+                 nn.Dropout(self.drop_prob),
+             )
+             self.layers.append(conv_block)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x shape: (batch_size, time_steps, input_dimension)
+         x = x.permute(0, 2, 1)  # (batch_size, input_dimension, time_steps)
+
+         res = x if self.downsample is None else self.downsample(x)
+         for layer in self.layers:
+             out = layer(x)
+             out = out + res
+             out = self.activation(out)
+             res = out  # Update residual
+             x = out  # Update input for next layer
+
+         out = out.permute(0, 2, 1)  # (batch_size, time_steps, filters)
+         return out
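
The dilation in the block above doubles with each residual block, and every block applies two causal convolutions, so the receptive field of the TCN over the reduced-time sequence is 1 + 2 * (kernel_size - 1) * (2**depth - 1). A quick check for the default settings (plain arithmetic, no braindecode imports needed):

    kernel_size, depth = 4, 2
    receptive_field = 1 + 2 * (kernel_size - 1) * (2**depth - 1)
    print(receptive_field)  # 19 samples of the reduced-time sequence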
@@ -0,0 +1,225 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     Conv2dWithConstraint,
+     FilterBankLayer,
+     LinearWithConstraint,
+     LogVarLayer,
+     MaxLayer,
+     MeanLayer,
+     StdLayer,
+     VarLayer,
+ )
+
+ _valid_layers = {
+     "VarLayer": VarLayer,
+     "StdLayer": StdLayer,
+     "LogVarLayer": LogVarLayer,
+     "MeanLayer": MeanLayer,
+     "MaxLayer": MaxLayer,
+ }
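
The aggregators in this mapping each collapse one axis of the input into a single statistic. As a rough sketch of the idea behind LogVarLayer (not braindecode's actual module, just the underlying computation, with an assumed clamp for numerical stability):

    import torch

    def log_var(x: torch.Tensor, dim: int) -> torch.Tensor:
        # Log-variance along `dim`, clamped so the log stays finite.
        return torch.log(torch.clamp(x.var(dim=dim, keepdim=True), min=1e-6))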
+
+
+ class FBCNet(EEGModuleMixin, nn.Module):
+     r"""FBCNet from Mane, R. et al. (2021) [fbcnet2021]_.
+
+     :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
+     .. figure:: https://raw.githubusercontent.com/ravikiran-mane/FBCNet/refs/heads/master/FBCNet-V2.png
+        :align: center
+        :alt: FBCNet Architecture
+
+     The FBCNet model applies spatial convolution and variance calculation along
+     the time axis, inspired by the Filter Bank Common Spatial Pattern (FBCSP)
+     algorithm.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been checked
+     by the original authors; it has only been reimplemented from the paper
+     description and source code [fbcnetcode2021]_. There is a difference in the
+     activation function: the paper uses ELU, but the original code uses SiLU.
+     We followed the code.
+
+     Parameters
+     ----------
+     n_bands : int or None or list[tuple[int, int]], default=9
+         Number of frequency bands. Can also be a list of ``(low_hz, high_hz)``
+         tuples defining each band explicitly.
+     n_filters_spat : int, default=32
+         Number of spatial filters for the first convolution.
+     n_dim : int, default=3
+         Dimension along which the temporal aggregation layer is applied.
+     temporal_layer : str, default='LogVarLayer'
+         Type of temporal aggregator layer. Options: 'VarLayer', 'StdLayer',
+         'LogVarLayer', 'MeanLayer', 'MaxLayer'.
+     stride_factor : int, default=4
+         Stride factor for reshaping the time axis before temporal aggregation.
+     activation : nn.Module, default=nn.SiLU
+         Activation function class to apply in the spatial convolution block.
+     cnn_max_norm : float, default=2.0
+         Maximum norm for the spatial convolution layer.
+     linear_max_norm : float, default=0.5
+         Maximum norm for the final linear layer.
+     filter_parameters : dict, default=None
+         Dictionary of parameters to use for the FilterBankLayer.
+         If None, a default Chebyshev Type II filter with a transition bandwidth
+         of 2 Hz and a stop-band ripple of 30 dB will be used.
+
+     References
+     ----------
+     .. [fbcnet2021] Mane, R., Chew, E., Chua, K., Ang, K. K., Robinson, N.,
+        Vinod, A. P., ... & Guan, C. (2021). FBCNet: A multi-view convolutional
+        neural network for brain-computer interface. preprint arXiv:2104.01233.
+     .. [fbcnetcode2021] Link to source code:
+        https://github.com/ravikiran-mane/FBCNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # Model parameters
+         n_bands=9,
+         n_filters_spat: int = 32,
+         temporal_layer: str = "LogVarLayer",
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         activation: type[nn.Module] = nn.SiLU,
+         linear_max_norm: float = 0.5,
+         cnn_max_norm: float = 2.0,
+         filter_parameters: dict[Any, Any] | None = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.activation = activation
+         self.filter_parameters = filter_parameters or {}
+
+         # Checkers
+         if temporal_layer not in _valid_layers:
+             raise NotImplementedError(
+                 f"Temporal layer '{temporal_layer}' is not implemented."
+             )
+
+         if self.n_times % self.stride_factor != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" stride_factor ({self.stride_factor}). Input will be padded.",
+                 UserWarning,
+             )
+
+         # Layers
+         # Following paper nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=False,
+             **self.filter_parameters,
+         )
+         # As we have an internal process to create the bands,
+         # we get the values from the filter bank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # Spatial Convolution Block (SCB)
+         self.spatial_conv = nn.Sequential(
+             Conv2dWithConstraint(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat * self.n_bands,
+                 kernel_size=(self.n_chans, 1),
+                 groups=self.n_bands,
+                 max_norm=cnn_max_norm,
+                 padding=0,
+             ),
+             nn.BatchNorm2d(self.n_filters_spat * self.n_bands),
+             self.activation(),
+         )
+
+         # Padding layer
+         if self.n_times % self.stride_factor != 0:
+             self.padding_size = stride_factor - (self.n_times % stride_factor)
+             self.n_times_padded = self.n_times + self.padding_size
+             self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
+         else:
+             self.padding_layer = nn.Identity()
+             self.n_times_padded = self.n_times
+
+         # Temporal aggregator
+         self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore
+
+         # Flatten layer
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # Final fully connected layer
+         self.final_layer = LinearWithConstraint(
+             in_features=self.n_filters_spat * self.n_bands * self.stride_factor,
+             out_features=self.n_outputs,
+             max_norm=linear_max_norm,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the FBCNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         # input: (batch_size, n_chans, n_times)
+         x = self.spectral_filtering(x)
+
+         # shape: (batch_size, n_bands, n_chans, n_times)
+         x = self.spatial_conv(x)
+         batch_size, channels, _, _ = x.shape
+
+         # shape: (batch_size, n_filters_spat * n_bands, 1, n_times)
+         x = self.padding_layer(x)
+
+         # shape: (batch_size, n_filters_spat * n_bands, 1, n_times_padded)
+         x = x.view(
+             batch_size,
+             channels,
+             self.stride_factor,
+             self.n_times_padded // self.stride_factor,
+         )
+         # shape: (batch_size, n_filters_spat * n_bands, stride, n_times_padded/stride)
+         x = self.temporal_layer(x)  # type: ignore[operator]
+
+         # shape: (batch_size, n_filters_spat * n_bands, stride, 1)
+         x = self.flatten_layer(x)
+
+         # shape: (batch_size, n_filters_spat * n_bands * stride)
+         x = self.final_layer(x)
+         # shape: (batch_size, n_outputs)
+         return x
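
As with EEGTCNet above, a minimal forward-pass sketch (hedged: it assumes FBCNet is re-exported from braindecode.models; sfreq must be given so the filter bank can construct its bands, and an n_times divisible by stride_factor avoids the padding path):

    import torch
    from braindecode.models import FBCNet

    model = FBCNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
    x = torch.randn(8, 22, 1000)  # (batch_size, n_chans, n_times)
    out = model(x)                # (8, 4)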