braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/models/fblightconvnet.py
@@ -0,0 +1,315 @@
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import (
+     FilterBankLayer,
+     LogVarLayer,
+ )
+
+
+ class FBLightConvNet(EEGModuleMixin, nn.Module):
+     r"""LightConvNet from Ma, X., et al. (2023) [lightconvnet]_.
+
+     :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
+     .. figure:: https://raw.githubusercontent.com/Ma-Xinzhi/LightConvNet/refs/heads/main/network_architecture.png
+         :align: center
+         :alt: LightConvNet Neural Network
+
+     A lightweight convolutional neural network incorporating temporal
+     dependency learning and attention mechanisms. The architecture is
+     designed to efficiently capture spatial and temporal features through
+     specialized convolutional layers and **multi-head attention**.
+
+     The network architecture consists of four main modules:
+
+     1. **Spatial and Spectral Information Learning**:
+        Applies filterbank and spatial convolutions.
+        This module is followed by batch normalization and
+        an activation function to enhance feature representation.
+
+     2. **Temporal Segmentation and Feature Extraction**:
+        Divides the processed data into non-overlapping temporal windows.
+        Within each window, a variance-based layer extracts discriminative
+        features, which are then log-transformed to stabilize variance before
+        being passed to the attention module.
+
+     3. **Temporal Attention Module**: Utilizes a multi-head attention
+        mechanism with depthwise separable convolutions to capture dependencies
+        across different temporal segments. The attention weights are normalized
+        using softmax and aggregated to form a comprehensive temporal
+        representation.
+
+     4. **Final Layer**: Flattens the aggregated features and passes them
+        through a linear layer to integrate features across channels and
+        generate the final output predictions.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it is a braindecode adaptation of the
+     PyTorch source code [lightconvnetcode]_.
+
+     Parameters
+     ----------
+     n_bands : int or None or list of tuple of int, default=9
+         Number of frequency bands or a list of frequency band tuples. If a
+         list of tuples is provided, each tuple defines the lower and upper
+         bounds of a frequency band.
+     n_filters_spat : int, default=32
+         Number of spatial filters in the depthwise convolutional layer.
+     n_dim : int, default=3
+         Number of dimensions for the temporal reduction layer.
+     stride_factor : int, default=4
+         Stride factor used for reshaping the temporal dimension.
+     win_len : int, default=250
+         Length, in samples, of the non-overlapping temporal windows.
+     heads : int, default=8
+         Number of attention heads in the multi-head attention mechanism.
+     weight_softmax : bool, default=True
+         If True, applies softmax to the attention weights.
+     bias : bool, default=False
+         If True, includes a bias term in the convolutional layers.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply after convolutional layers.
+     verbose : bool, default=False
+         If True, enables verbose output during filter creation using mne.
+     filter_parameters : dict, default=None
+         Additional parameters for the FilterBankLayer.
+
+     References
+     ----------
+     .. [lightconvnet] Ma, X., Chen, W., Pei, Z., Liu, J., Huang, B., & Chen, J.
+        (2023). A temporal dependency learning CNN with attention mechanism
+        for MI-EEG decoding. IEEE Transactions on Neural Systems and
+        Rehabilitation Engineering.
+     .. [lightconvnetcode] Link to source code:
+        https://github.com/Ma-Xinzhi/LightConvNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         n_bands=9,
+         n_filters_spat: int = 32,
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         win_len: int = 250,
+         heads: int = 8,
+         weight_softmax: bool = True,
+         bias: bool = False,
+         activation: type[nn.Module] = nn.ELU,
+         verbose: bool = False,
+         filter_parameters: Optional[dict] = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.win_len = win_len
+         self.activation = activation
+         self.heads = heads
+         self.weight_softmax = weight_softmax
+         self.bias = bias
+         self.filter_parameters = filter_parameters or {}
+
+         # Checks
+         self.n_times_truncated = self.n_times
+         if self.n_times % self.win_len != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" win_len ({self.win_len}). Input will be truncated by"
+                 f" {self.n_times % self.win_len} temporal points.",
+                 UserWarning,
+             )
+             self.n_times_truncated = self.n_times - (self.n_times % self.win_len)
+
+         # Layers
+         # Following the paper nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=verbose,
+             **self.filter_parameters,
+         )
+         # As the bands are created internally,
+         # we take their number from the filterbank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # The convolution here is different.
+         self.spatial_conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat,
+                 kernel_size=(self.n_chans, 1),
+             ),
+             nn.BatchNorm2d(self.n_filters_spat),
+             self.activation(),
+         )
+
+         # Temporal aggregator
+         self.temporal_layer = LogVarLayer(self.n_dim, False)
+
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # LightweightConv1d over the temporal windows
+         self.attn_conv = _LightweightConv1d(
+             self.n_filters_spat,
+             (self.n_times // self.win_len),
+             heads=self.heads,
+             weight_softmax=weight_softmax,
+             bias=bias,
+         )
+
+         self.final_layer = nn.Linear(
+             in_features=self.n_filters_spat,
+             out_features=self.n_outputs,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Forward pass of the FBLightConvNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         batch_size, _, _ = x.shape
+         # x.shape: (batch, n_chans, n_times)
+
+         x = self.spectral_filtering(x)
+         # x.shape: (batch, n_bands, n_chans, n_times)
+
+         x = self.spatial_conv(x)
+         # x.shape: (batch, n_filters_spat, 1, n_times)
+
+         x = x[:, :, :, : self.n_times_truncated]
+         # x.shape: (batch, n_filters_spat, 1, n_times_truncated)
+
+         x = x.reshape([batch_size, self.n_filters_spat, -1, self.win_len])
+         # x.shape: (batch, n_filters_spat, n_windows, win_len)
+         # where n_windows = n_times_truncated / win_len
+         # and win_len = 250 by default
+
+         x = self.temporal_layer(x)
+         # x.shape: (batch, n_filters_spat, n_windows)
+
+         x = self.attn_conv(x)
+         # x.shape: (batch, n_filters_spat, 1)
+
+         x = self.flatten_layer(x)
+         # x.shape: (batch, n_filters_spat)
+
+         x = self.final_layer(x)
+         # x.shape: (batch, n_outputs)
+         return x
+
+
+ class _LightweightConv1d(nn.Module):
+     r"""Lightweight 1D convolution module.
+
+     Applies a convolution operation with multiple heads, allowing for
+     parallel filter applications. Optionally applies a softmax normalization
+     to the convolution weights.
+
+     Parameters
+     ----------
+     input_size : int
+         Number of channels of the input and output.
+     kernel_size : int, optional
+         Size of the convolution kernel. Default is `1`.
+     padding : int, optional
+         Amount of zero-padding added to both sides of the input. Default is `0`.
+     heads : int, optional
+         Number of attention heads used. The weight has shape
+         `(heads, 1, kernel_size)`. Default is `1`.
+     weight_softmax : bool, optional
+         If `True`, normalizes the convolution weights with softmax before
+         applying the convolution. Default is `False`.
+     bias : bool, optional
+         If `True`, adds a learnable bias to the output. Default is `False`.
+     """
+
+     def __init__(
+         self,
+         input_size: int,
+         kernel_size: int = 1,
+         padding: int = 0,
+         heads: int = 1,
+         weight_softmax: bool = False,
+         bias: bool = False,
+     ):
+         super().__init__()
+         self.input_size = input_size
+         self.kernel_size = kernel_size
+         self.heads = heads
+         self.padding = padding
+         self.weight_softmax = weight_softmax
+         self.weight = nn.Parameter(torch.Tensor(heads, 1, kernel_size))
+
+         if bias:
+             self.bias = nn.Parameter(torch.Tensor(input_size))
+         else:
+             self.bias = None
+
+         self._init_parameters()
+
+     def _init_parameters(self):
+         nn.init.xavier_uniform_(self.weight)
+         if self.bias is not None:
+             nn.init.constant_(self.bias, 0.0)
+
+     def forward(self, input):
+         # input.shape: (batch, n_filters_spat, n_windows)
+         B, C, T = input.size()
+
+         H = self.heads
+
+         weight = self.weight
+         if self.weight_softmax:
+             weight = F.softmax(weight, dim=-1)
+         # weight.shape: (heads, 1, kernel_size)
+
+         # Reshape the input so each head becomes its own "batch":
+         # C = H * (C / H), so the view gives shape (B * C / H, H, T) and the
+         # grouped conv1d below applies one kernel per head.
+         input = input.view(-1, H, T)
+         output = F.conv1d(input, weight, padding=self.padding, groups=self.heads)
+         # output.shape: (B * C / H, H, 1) when kernel_size == T and padding == 0
+         output = output.view(B, C, -1)
+         # output.shape: (batch, n_filters_spat, 1)
+         if self.bias is not None:
+             # Add bias if it exists
+             output = output + self.bias.view(1, -1, 1)
+         # final shape: (batch, n_filters_spat, 1)
+         return output
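
For orientation, a minimal usage sketch of the model above (not part of the released file; it assumes FBLightConvNet is re-exported from braindecode.models as the other models are, and uses illustrative shapes):

    import torch
    from braindecode.models import FBLightConvNet

    # Illustrative values: 22 EEG channels, 4 classes, 4 s at 250 Hz (1000
    # samples), so n_times is divisible by the default win_len of 250.
    model = FBLightConvNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    y = model(x)                  # expected shape: (8, 4)
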
braindecode/models/fbmsnet.py
@@ -0,0 +1,338 @@
+ from __future__ import annotations
+
+ from typing import Any, Sequence
+
+ import torch
+ from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.models.fbcnet import _valid_layers
+ from braindecode.modules import (
+     Conv2dWithConstraint,
+     FilterBankLayer,
+     LinearWithConstraint,
+ )
+
+
+ class FBMSNet(EEGModuleMixin, nn.Module):
+     r"""FBMSNet from Liu et al. (2022) [fbmsnet]_.
+
+     :bdg-success:`Convolution` :bdg-primary:`Filterbank`
+
+     .. figure:: https://raw.githubusercontent.com/Want2Vanish/FBMSNet/refs/heads/main/FBMSNet.png
+         :align: center
+         :alt: FBMSNet Architecture
+
+     0. **FilterBank Layer**: Applies a filterbank to transform the input.
+
+     1. **Temporal Convolution Block**: Utilizes mixed depthwise convolution
+        (MixConv) to extract multiscale temporal features from multiview EEG
+        representations. The input is split into groups corresponding to
+        different views, each convolved with kernels of varying sizes.
+        Kernel sizes are set relative to the EEG sampling rate, with ratio
+        coefficients [0.5, 0.25, 0.125, 0.0625], dividing the input into
+        four groups.
+
+     2. **Spatial Convolution Block**: Applies depthwise convolution with a
+        kernel size of (n_chans, 1) to span all EEG channels, effectively
+        learning spatial filters. This is followed by batch normalization and
+        the Swish activation function. A maximum norm constraint of 2 is
+        imposed on the convolution weights to regularize the model.
+
+     3. **Temporal Log-Variance Block**: Computes the log-variance.
+
+     4. **Classification Layer**: A fully connected layer with a weight
+        constraint.
+
+     Notes
+     -----
+     This implementation is not guaranteed to be correct and has not been
+     checked by the original authors; it has only been reimplemented from the
+     paper description and source code [fbmsnetcode]_. An extra layer computes
+     the filterbank inside the forward pass rather than during data
+     preprocessing, which avoids data leakage and lets the model follow the
+     braindecode convention.
+
+     Parameters
+     ----------
+     n_bands : int, default=9
+         Number of frequency bands (input channels of the MixedConv2d layer).
+     n_filters_spat : int, default=36
+         Number of output channels from the MixedConv2d layer.
+     temporal_layer : str, default='LogVarLayer'
+         Temporal aggregation layer to use.
+     n_dim : int, default=3
+         Dimension of the temporal reduction layer.
+     stride_factor : int, default=4
+         Stride factor for temporal segmentation.
+     dilatability : int, default=8
+         Expansion factor for the spatial convolution block.
+     activation : nn.Module, default=nn.SiLU
+         Activation function class to apply.
+     kernels_weights : Sequence[int], default=(15, 31, 63, 125)
+         Kernel sizes for the MixedConv2d layer.
+     cnn_max_norm : float, default=2
+         Maximum norm constraint for the convolutional layers.
+     linear_max_norm : float, default=0.5
+         Maximum norm constraint for the linear layers.
+     verbose : bool, default=False
+         Verbose parameter to create the filter using mne.
+     filter_parameters : dict, default=None
+         Dictionary of parameters to use for the FilterBankLayer.
+         If None, a default Chebyshev Type II filter with a transition
+         bandwidth of 2 Hz and a stop-band ripple of 30 dB will be used.
+
+     References
+     ----------
+     .. [fbmsnet] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
+        FBMSNet: A filter-bank multi-scale convolutional neural network for
+        EEG-based motor imagery decoding. IEEE Transactions on Biomedical
+        Engineering, 70(2), 436-445.
+     .. [fbmsnetcode] Liu, K., Yang, M., Yu, Z., Wang, G., & Wu, W. (2022).
+        FBMSNet: A filter-bank multi-scale convolutional neural network for
+        EEG-based motor imagery decoding.
+        https://github.com/Want2Vanish/FBMSNet
+     """
+
+     def __init__(
+         self,
+         # Braindecode parameters
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         n_times=None,
+         input_window_seconds=None,
+         sfreq=None,
+         # model parameters
+         n_bands: int = 9,
+         n_filters_spat: int = 36,
+         temporal_layer: str = "LogVarLayer",
+         n_dim: int = 3,
+         stride_factor: int = 4,
+         dilatability: int = 8,
+         activation: type[nn.Module] = nn.SiLU,
+         kernels_weights: Sequence[int] = (15, 31, 63, 125),
+         cnn_max_norm: float = 2,
+         linear_max_norm: float = 0.5,
+         verbose: bool = False,
+         filter_parameters: dict[Any, Any] | None = None,
+     ):
+         super().__init__(
+             n_chans=n_chans,
+             n_outputs=n_outputs,
+             chs_info=chs_info,
+             n_times=n_times,
+             input_window_seconds=input_window_seconds,
+             sfreq=sfreq,
+         )
+         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+
+         # Parameters
+         self.n_bands = n_bands
+         self.n_filters_spat = n_filters_spat
+         self.n_dim = n_dim
+         self.stride_factor = stride_factor
+         self.activation = activation
+         self.dilatability = dilatability
+         self.kernels_weights = kernels_weights
+         self.filter_parameters = filter_parameters or {}
+         self.out_channels_spatial = self.n_filters_spat * self.dilatability
+
+         # Checks
+         if temporal_layer not in _valid_layers:
+             raise NotImplementedError(
+                 f"Temporal layer '{temporal_layer}' is not implemented."
+             )
+
+         if self.n_times % self.stride_factor != 0:
+             warn(
+                 f"Time dimension ({self.n_times}) is not divisible by"
+                 f" stride_factor ({self.stride_factor}). Input will be padded.",
+                 UserWarning,
+             )
+
+         # Layers
+         # Following the paper nomenclature
+         self.spectral_filtering = FilterBankLayer(
+             n_chans=self.n_chans,
+             sfreq=self.sfreq,
+             band_filters=self.n_bands,
+             verbose=verbose,
+             **self.filter_parameters,
+         )
+         # As the bands are created internally,
+         # we take their number from the filterbank
+         self.n_bands = self.spectral_filtering.n_bands
+
+         # MixedConv2d layer
+         self.mix_conv = nn.Sequential(
+             _MixedConv2d(
+                 in_channels=self.n_bands,
+                 out_channels=self.n_filters_spat,
+                 stride=1,
+                 dilation=1,
+                 depthwise=False,
+                 kernels_weights=kernels_weights,
+             ),
+             nn.BatchNorm2d(self.n_filters_spat),
+         )
+
+         # Spatial Convolution Block (SCB)
+         self.spatial_conv = nn.Sequential(
+             Conv2dWithConstraint(
+                 in_channels=self.n_filters_spat,
+                 out_channels=self.out_channels_spatial,
+                 kernel_size=(self.n_chans, 1),
+                 groups=self.n_filters_spat,
+                 max_norm=cnn_max_norm,
+                 padding=0,
+             ),
+             nn.BatchNorm2d(self.out_channels_spatial),
+             self.activation(),
+         )
+
+         # Padding layer
+         if self.n_times % self.stride_factor != 0:
+             self.padding_size = stride_factor - (self.n_times % stride_factor)
+             self.n_times_padded = self.n_times + self.padding_size
+             self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
+         else:
+             self.padding_layer = nn.Identity()
+             self.n_times_padded = self.n_times
+
+         # Temporal Aggregation Layer
+         self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore
+
+         self.flatten_layer = Rearrange("batch ... -> batch (...)")
+
+         # Final fully connected layer
+         self.final_layer = LinearWithConstraint(
+             in_features=self.out_channels_spatial * self.stride_factor,
+             out_features=self.n_outputs,
+             max_norm=linear_max_norm,
+         )
+
+     def forward(self, x):
+         """
+         Forward pass of the FBMSNet model.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+             Input tensor with shape (batch_size, n_chans, n_times).
+
+         Returns
+         -------
+         torch.Tensor
+             Output tensor with shape (batch_size, n_outputs).
+         """
+         batch, _, _ = x.shape
+
+         # shape: (batch, n_chans, n_times)
+         x = self.spectral_filtering(x)
+         # shape: (batch, n_bands, n_chans, n_times)
+
+         # Mixed convolution
+         x = self.mix_conv(x)
+         # shape: (batch, n_filters_spat, n_chans, n_times)
+
+         # Spatial convolution block
+         x = self.spatial_conv(x)
+         # shape: (batch, out_channels_spatial, 1, n_times)
+
+         # Pad the input so the time dimension is divisible by the stride factor
+         x = self.padding_layer(x)
+         # shape: (batch, out_channels_spatial, 1, n_times_padded)
+
+         # Reshape for the temporal layer
+         x = x.view(batch, self.out_channels_spatial, self.stride_factor, -1)
+         # shape: (batch, out_channels_spatial, stride_factor, n_times_padded / stride_factor)
+
+         # Temporal aggregation
+         x = self.temporal_layer(x)
+         # shape: (batch, out_channels_spatial, stride_factor, 1)
+
+         # Flatten and classify
+         x = self.flatten_layer(x)
+         # shape: (batch, out_channels_spatial * stride_factor)
+
+         x = self.final_layer(x)
+         # shape: (batch, n_outputs)
+         return x
+
+
+ class _MixedConv2d(nn.Module):
+     r"""Mixed grouped convolution for multiscale feature extraction."""
+
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernels_weights=(15, 31, 63, 125),
+         stride=1,
+         dilation=1,
+         depthwise=False,
+     ):
+         super().__init__()
+
+         num_groups = len(kernels_weights)
+         in_splits = self._split_channels(in_channels, num_groups)
+         out_splits = self._split_channels(out_channels, num_groups)
+         self.splits = in_splits
+
+         self.convs = nn.ModuleList()
+         # Create a convolutional layer for each kernel size
+         for k, in_ch, out_ch in zip(kernels_weights, in_splits, out_splits):
+             conv_groups = out_ch if depthwise else 1
+             conv = nn.Conv2d(
+                 in_channels=in_ch,
+                 out_channels=out_ch,
+                 kernel_size=(1, k),
+                 stride=stride,
+                 padding="same",
+                 dilation=dilation,
+                 groups=conv_groups,
+                 bias=False,
+             )
+             self.convs.append(conv)
+
+     @staticmethod
+     def _split_channels(num_chan, num_groups):
+         """
+         Split the total number of channels into a specified number of groups
+         as evenly as possible.
+
+         Parameters
+         ----------
+         num_chan : int
+             The total number of channels to split.
+         num_groups : int
+             The number of groups to split the channels into.
+
+         Returns
+         -------
+         list of int
+             A list containing the number of channels in each group.
+             The first group may have more channels if the division is not even.
+         """
+         split = [num_chan // num_groups for _ in range(num_groups)]
+         split[0] += num_chan - sum(split)
+         return split
+
+     def forward(self, x):
+         # Split the input tensor `x` along the channel dimension (dim=1) into
+         # groups. The size of each group is defined by `self.splits`, computed
+         # from the number of input channels and the number of kernel sizes.
+         x_split = torch.split(x, self.splits, 1)
+
+         # Apply the corresponding convolutional layer from `self.convs` to each
+         # group. The result is a list of output tensors, one per group.
+         x_out = [conv(x_split[i]) for i, conv in enumerate(self.convs)]
+
+         # Concatenate the outputs from all groups along the channel dimension
+         # (dim=1) to form a single output tensor.
+         x = torch.cat(x_out, 1)
+
+         # Return the concatenated tensor as the output of the mixed convolution.
+         return x
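
Similarly, a minimal usage sketch for FBMSNet (not part of the released file; illustrative values, assuming FBMSNet is re-exported from braindecode.models):

    import torch
    from braindecode.models import FBMSNet

    # 22 EEG channels, 4 classes, 1000 samples at 250 Hz; 1000 is divisible by
    # the default stride_factor of 4, so no padding warning is raised.
    model = FBMSNet(n_chans=22, n_outputs=4, n_times=1000, sfreq=250)
    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    y = model(x)                  # expected shape: (8, 4)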