braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
2
|
+
#
|
|
3
|
+
# License: BSD (3-clause)
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from einops.layers.torch import Rearrange
|
|
10
|
+
|
|
11
|
+
from braindecode.models.base import EEGModuleMixin
|
|
12
|
+
from braindecode.modules import Chomp1d, MaxNormLinear
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EEGTCNet(EEGModuleMixin, nn.Module):
    r"""EEGTCNet model from Ingolfsson et al (2020) [ingolfsson2020]_.

    :bdg-success:`Convolution` :bdg-secondary:`Recurrent`

    .. figure:: https://braindecode.org/dev/_static/model/eegtcnet.jpg
        :align: center
        :alt: EEGTCNet Architecture

    Combining EEGNet and TCN blocks.

    The input is first processed by an EEGNet-style block (temporal
    convolution, depthwise spatial convolution and separable convolution,
    each followed by batch normalization). It then passes through a temporal
    convolutional network (TCN) made of dilated residual convolution blocks.
    The TCN features at the last time step are classified by a max-norm
    constrained linear layer.

    Parameters
    ----------
    activation : type[nn.Module], optional
        Activation function class, instantiated inside the EEGNet and TCN
        blocks. Default is ``nn.ELU``.
    depth_multiplier : int, optional
        Depth multiplier for the depthwise convolution. Default is 2.
    filter_1 : int, optional
        Number of temporal filters in the first convolutional layer. Default is 8.
    kern_length : int, optional
        Length of the temporal kernel in the first convolutional layer. Default is 64.
    drop_prob : float, optional
        Dropout probability, used in both the EEGNet and the TCN blocks.
        Default is 0.5.
    depth : int, optional
        Number of residual blocks in the TCN. Default is 2.
    kernel_size : int, optional
        Size of the temporal convolutional kernel in the TCN. Default is 4.
    filters : int, optional
        Number of filters in the TCN convolutional layers. Default is 12.
    max_norm_const : float, optional
        Maximum L2-norm constraint imposed on weights of the last
        fully-connected layer. Default is 0.25.

    References
    ----------
    .. [ingolfsson2020] Ingolfsson, T. M., Hersche, M., Wang, X., Kobayashi, N.,
        Cavigelli, L., & Benini, L. (2020). EEG-TCNet: An accurate temporal
        convolutional network for embedded motor-imagery brain–machine interfaces.
        https://doi.org/10.48550/arXiv.2006.00622
    """

    def __init__(
        self,
        # Signal related parameters
        n_chans=None,
        n_outputs=None,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        # Model parameters
        activation: type[nn.Module] = nn.ELU,
        depth_multiplier: int = 2,
        filter_1: int = 8,
        kern_length: int = 64,
        drop_prob: float = 0.5,
        depth: int = 2,
        kernel_size: int = 4,
        filters: int = 12,
        max_norm_const: float = 0.25,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.activation = activation
        self.drop_prob = drop_prob
        self.depth_multiplier = depth_multiplier
        self.filter_1 = filter_1
        self.kern_length = kern_length
        self.depth = depth
        self.kernel_size = kernel_size
        self.filters = filters
        self.max_norm_const = max_norm_const
        # Number of feature maps produced by the depthwise convolution.
        self.filter_2 = self.filter_1 * self.depth_multiplier

        # (batch, n_chans, n_times) -> (batch, 1, n_times, n_chans):
        # _EEGNetTC expects time on the third axis and channels on the fourth.
        self.arrange_dim_input = Rearrange(
            "batch nchans ntimes -> batch 1 ntimes nchans"
        )
        # EEGNet_TC Block
        self.eegnet_tc = _EEGNetTC(
            n_chans=self.n_chans,
            filter_1=self.filter_1,
            kern_length=self.kern_length,
            depth_multiplier=self.depth_multiplier,
            drop_prob=self.drop_prob,
            activation=self.activation,
        )
        # (batch, F2, reduced_time, 1) -> (batch, reduced_time, F2) for the TCN.
        self.arrange_dim_eegnet = Rearrange(
            "batch filter2 rtimes 1 -> batch rtimes filter2"
        )

        # TCN Block
        self.tcn_block = _TCNBlock(
            input_dimension=self.filter_2,
            depth=self.depth,
            kernel_size=self.kernel_size,
            filters=self.filters,
            drop_prob=self.drop_prob,
            activation=self.activation,
        )

        # Classification Block
        self.final_layer = MaxNormLinear(
            in_features=self.filters,
            out_features=self.n_outputs,
            max_norm_val=self.max_norm_const,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the EEGTCNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (batch_size, n_outputs).
        """
        # x shape: (batch_size, n_chans, n_times)
        x = self.arrange_dim_input(x)  # (batch_size, 1, n_times, n_chans)
        x = self.eegnet_tc(x)  # (batch_size, F2, reduced_time, 1)

        x = self.arrange_dim_eegnet(x)  # (batch_size, reduced_time, F2)
        x = self.tcn_block(x)  # (batch_size, reduced_time, filters)

        # Only the features of the last time step are used for classification.
        x = x[:, -1, :]  # (batch_size, filters)

        x = self.final_layer(x)  # (batch_size, n_outputs)

        return x
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class _EEGNetTC(nn.Module):
|
|
160
|
+
r"""EEGNet Temporal Convolutional Network (TCN) block.
|
|
161
|
+
|
|
162
|
+
The main difference from our :class:`EEGNet` (braindecode) implementation is the
|
|
163
|
+
kernel and dimensional order. Because of this, we decided to keep this
|
|
164
|
+
implementation in a future issue; we will re-evaluate if it is necessary
|
|
165
|
+
to maintain this separate implementation.
|
|
166
|
+
|
|
167
|
+
Parameters
|
|
168
|
+
----------
|
|
169
|
+
n_chans : int
|
|
170
|
+
Number of EEG channels.
|
|
171
|
+
filter_1 : int
|
|
172
|
+
Number of temporal filters in the first convolutional layer.
|
|
173
|
+
kern_length : int
|
|
174
|
+
Length of the temporal kernel in the first convolutional layer.
|
|
175
|
+
depth_multiplier : int
|
|
176
|
+
Depth multiplier for the depthwise convolution.
|
|
177
|
+
drop_prob : float
|
|
178
|
+
Dropout rate.
|
|
179
|
+
activation : nn.Module
|
|
180
|
+
Activation function.
|
|
181
|
+
"""
|
|
182
|
+
|
|
183
|
+
def __init__(
|
|
184
|
+
self,
|
|
185
|
+
n_chans: int,
|
|
186
|
+
filter_1: int = 8,
|
|
187
|
+
kern_length: int = 64,
|
|
188
|
+
depth_multiplier: int = 2,
|
|
189
|
+
drop_prob: float = 0.5,
|
|
190
|
+
activation: type[nn.Module] = nn.ELU,
|
|
191
|
+
):
|
|
192
|
+
super().__init__()
|
|
193
|
+
self.activation = activation()
|
|
194
|
+
self.drop_prob = drop_prob
|
|
195
|
+
self.n_chans = n_chans
|
|
196
|
+
self.filter_1 = filter_1
|
|
197
|
+
self.filter_2 = self.filter_1 * depth_multiplier
|
|
198
|
+
|
|
199
|
+
# First Conv2D Layer
|
|
200
|
+
self.conv1 = nn.Conv2d(
|
|
201
|
+
in_channels=1,
|
|
202
|
+
out_channels=self.filter_1,
|
|
203
|
+
kernel_size=(kern_length, 1),
|
|
204
|
+
padding=(kern_length // 2, 0),
|
|
205
|
+
bias=False,
|
|
206
|
+
)
|
|
207
|
+
self.bn1 = nn.BatchNorm2d(self.filter_1)
|
|
208
|
+
|
|
209
|
+
# Depthwise Convolution
|
|
210
|
+
self.depthwise_conv = nn.Conv2d(
|
|
211
|
+
in_channels=self.filter_1,
|
|
212
|
+
out_channels=self.filter_2,
|
|
213
|
+
kernel_size=(1, n_chans),
|
|
214
|
+
groups=self.filter_1,
|
|
215
|
+
bias=False,
|
|
216
|
+
)
|
|
217
|
+
self.bn2 = nn.BatchNorm2d(self.filter_2)
|
|
218
|
+
self.pool1 = nn.AvgPool2d(kernel_size=(8, 1))
|
|
219
|
+
self.drop1 = nn.Dropout(p=drop_prob)
|
|
220
|
+
|
|
221
|
+
# Separable Convolution (Depthwise + Pointwise)
|
|
222
|
+
self.separable_conv_depthwise = nn.Conv2d(
|
|
223
|
+
in_channels=self.filter_2,
|
|
224
|
+
out_channels=self.filter_2,
|
|
225
|
+
kernel_size=(self.filter_2, 1),
|
|
226
|
+
groups=self.filter_2,
|
|
227
|
+
padding=(self.filter_2 // 2, 0),
|
|
228
|
+
bias=False,
|
|
229
|
+
)
|
|
230
|
+
self.separable_conv_pointwise = nn.Conv2d(
|
|
231
|
+
in_channels=self.filter_2,
|
|
232
|
+
out_channels=self.filter_2,
|
|
233
|
+
kernel_size=(1, 1),
|
|
234
|
+
bias=False,
|
|
235
|
+
)
|
|
236
|
+
self.bn3 = nn.BatchNorm2d(self.filter_2)
|
|
237
|
+
self.pool2 = nn.AvgPool2d(kernel_size=(self.filter_1, 1))
|
|
238
|
+
self.drop2 = nn.Dropout(p=drop_prob)
|
|
239
|
+
|
|
240
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
241
|
+
# x shape: (batch_size, 1, n_times, n_chans)
|
|
242
|
+
x = self.conv1(x)
|
|
243
|
+
x = self.bn1(x)
|
|
244
|
+
x = self.activation(x)
|
|
245
|
+
|
|
246
|
+
x = self.depthwise_conv(x)
|
|
247
|
+
x = self.bn2(x)
|
|
248
|
+
x = self.activation(x)
|
|
249
|
+
x = self.pool1(x)
|
|
250
|
+
x = self.drop1(x)
|
|
251
|
+
|
|
252
|
+
x = self.separable_conv_depthwise(x)
|
|
253
|
+
x = self.separable_conv_pointwise(x)
|
|
254
|
+
x = self.bn3(x)
|
|
255
|
+
x = self.activation(x)
|
|
256
|
+
x = self.pool2(x)
|
|
257
|
+
x = self.drop2(x)
|
|
258
|
+
|
|
259
|
+
return x # Shape: (batch_size, F2, reduced_time, 1)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class _TCNBlock(nn.Module):
|
|
263
|
+
r"""
|
|
264
|
+
Many differences from our Temporal Block (braindecode) implementation.
|
|
265
|
+
Because of this, we decided to keep this implementation in a future issue;
|
|
266
|
+
we will re-evaluate if it is necessary to maintain this separate
|
|
267
|
+
implementation.
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
"""
|
|
271
|
+
|
|
272
|
+
def __init__(
|
|
273
|
+
self,
|
|
274
|
+
input_dimension: int,
|
|
275
|
+
depth: int,
|
|
276
|
+
kernel_size: int,
|
|
277
|
+
filters: int,
|
|
278
|
+
drop_prob: float,
|
|
279
|
+
activation: type[nn.Module] = nn.ELU,
|
|
280
|
+
):
|
|
281
|
+
super().__init__()
|
|
282
|
+
self.activation = activation()
|
|
283
|
+
self.drop_prob = drop_prob
|
|
284
|
+
self.depth = depth
|
|
285
|
+
self.filters = filters
|
|
286
|
+
self.kernel_size = kernel_size
|
|
287
|
+
|
|
288
|
+
self.layers = nn.ModuleList()
|
|
289
|
+
self.downsample = (
|
|
290
|
+
nn.Conv1d(input_dimension, filters, kernel_size=1, bias=False)
|
|
291
|
+
if input_dimension != filters
|
|
292
|
+
else None
|
|
293
|
+
)
|
|
294
|
+
|
|
295
|
+
for i in range(depth):
|
|
296
|
+
dilation = 2**i
|
|
297
|
+
padding = (kernel_size - 1) * dilation
|
|
298
|
+
conv_block = nn.Sequential(
|
|
299
|
+
nn.Conv1d(
|
|
300
|
+
in_channels=input_dimension if i == 0 else filters,
|
|
301
|
+
out_channels=filters,
|
|
302
|
+
kernel_size=kernel_size,
|
|
303
|
+
dilation=dilation,
|
|
304
|
+
padding=padding,
|
|
305
|
+
bias=False,
|
|
306
|
+
),
|
|
307
|
+
Chomp1d(padding),
|
|
308
|
+
self.activation,
|
|
309
|
+
nn.Dropout(self.drop_prob),
|
|
310
|
+
nn.Conv1d(
|
|
311
|
+
in_channels=filters,
|
|
312
|
+
out_channels=filters,
|
|
313
|
+
kernel_size=kernel_size,
|
|
314
|
+
dilation=dilation,
|
|
315
|
+
padding=padding,
|
|
316
|
+
bias=False,
|
|
317
|
+
),
|
|
318
|
+
Chomp1d(padding),
|
|
319
|
+
self.activation,
|
|
320
|
+
nn.Dropout(self.drop_prob),
|
|
321
|
+
)
|
|
322
|
+
self.layers.append(conv_block)
|
|
323
|
+
|
|
324
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
325
|
+
# x shape: (batch_size, time_steps, input_dimension)
|
|
326
|
+
x = x.permute(0, 2, 1) # (batch_size, input_dimension, time_steps)
|
|
327
|
+
|
|
328
|
+
res = x if self.downsample is None else self.downsample(x)
|
|
329
|
+
for layer in self.layers:
|
|
330
|
+
out = layer(x)
|
|
331
|
+
out = out + res
|
|
332
|
+
out = self.activation(out)
|
|
333
|
+
res = out # Update residual
|
|
334
|
+
x = out # Update input for next layer
|
|
335
|
+
|
|
336
|
+
out = out.permute(0, 2, 1) # (batch_size, time_steps, filters)
|
|
337
|
+
return out
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from einops.layers.torch import Rearrange
|
|
7
|
+
from mne.utils import warn
|
|
8
|
+
from torch import nn
|
|
9
|
+
|
|
10
|
+
from braindecode.models.base import EEGModuleMixin
|
|
11
|
+
from braindecode.modules import (
|
|
12
|
+
Conv2dWithConstraint,
|
|
13
|
+
FilterBankLayer,
|
|
14
|
+
LinearWithConstraint,
|
|
15
|
+
LogVarLayer,
|
|
16
|
+
MaxLayer,
|
|
17
|
+
MeanLayer,
|
|
18
|
+
StdLayer,
|
|
19
|
+
VarLayer,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Accepted values of FBCNet's ``temporal_layer`` argument, mapped to the
# aggregator class that is instantiated for it.
_valid_layers = dict(
    VarLayer=VarLayer,
    StdLayer=StdLayer,
    LogVarLayer=LogVarLayer,
    MeanLayer=MeanLayer,
    MaxLayer=MaxLayer,
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class FBCNet(EEGModuleMixin, nn.Module):
    r"""FBCNet from Mane, R et al (2021) [fbcnet2021]_.

    :bdg-success:`Convolution` :bdg-primary:`Filterbank`

    .. figure:: https://raw.githubusercontent.com/ravikiran-mane/FBCNet/refs/heads/master/FBCNet-V2.png
        :align: center
        :alt: FBCNet Architecture

    The FBCNet model applies spatial convolution and variance calculation along
    the time axis, inspired by the Filter Bank Common Spatial Pattern (FBCSP)
    algorithm.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been checked
    by the original authors; it has only been reimplemented from the paper
    description and source code [fbcnetcode2021]_. There is a difference in the
    activation function: the paper uses ELU, but the original code uses SiLU.
    We followed the code.

    Parameters
    ----------
    n_bands : int or None or list[tuple[int, int]], default=9
        Frequency bands of the filter bank: either a number of bands or an
        explicit list of bands given as 2-tuples. Passed to
        :class:`FilterBankLayer` as ``band_filters``; the effective number of
        bands is then read back from the filter bank.
    n_filters_spat : int, default=32
        Number of spatial filters per frequency band in the spatial
        convolution block.
    temporal_layer : str, default='LogVarLayer'
        Type of temporal aggregator layer. Options: 'VarLayer', 'StdLayer',
        'LogVarLayer', 'MeanLayer', 'MaxLayer'.
    n_dim : int, default=3
        Axis of the
        ``(batch, n_filters_spat * n_bands, stride_factor, n_times_padded // stride_factor)``
        tensor along which the temporal aggregator reduces, passed to it as
        ``dim``. With the default, each of the ``stride_factor`` time windows
        is summarized separately.
        NOTE(review): the size of the final linear layer assumes this
        default; other values likely cause a shape mismatch.
    stride_factor : int, default=4
        Number of equal-length time windows the signal is split into before
        temporal aggregation. If ``n_times`` is not a multiple of it, the
        signal is zero-padded at the end.
    activation : type[nn.Module], default=nn.SiLU
        Activation function class to apply in Spatial Convolution Block.
    linear_max_norm : float, default=0.5
        Maximum norm for the final linear layer.
    cnn_max_norm : float, default=2.0
        Maximum norm for the spatial convolution layer.
    filter_parameters : dict, default=None
        Dictionary of extra keyword arguments for the FilterBankLayer.
        If None, a default Chebyshev Type II filter with transition bandwidth of
        2 Hz and stop-band ripple of 30 dB will be used.

    Raises
    ------
    NotImplementedError
        If ``temporal_layer`` is not one of the supported names.

    References
    ----------
    .. [fbcnet2021] Mane, R., Chew, E., Chua, K., Ang, K. K., Robinson, N.,
        Vinod, A. P., ... & Guan, C. (2021). FBCNet: A multi-view convolutional
        neural network for brain-computer interface. preprint arXiv:2104.01233.
    .. [fbcnetcode2021] Link to source-code:
        https://github.com/ravikiran-mane/FBCNet
    """

    def __init__(
        self,
        # Braindecode parameters
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        n_times=None,
        input_window_seconds=None,
        sfreq=None,
        # models parameters
        n_bands=9,
        n_filters_spat: int = 32,
        temporal_layer: str = "LogVarLayer",
        n_dim: int = 3,
        stride_factor: int = 4,
        activation: type[nn.Module] = nn.SiLU,
        linear_max_norm: float = 0.5,
        cnn_max_norm: float = 2.0,
        filter_parameters: dict[Any, Any] | None = None,
    ):
        super().__init__(
            n_chans=n_chans,
            n_outputs=n_outputs,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        # Parameters
        self.n_bands = n_bands
        self.n_filters_spat = n_filters_spat
        self.n_dim = n_dim
        self.stride_factor = stride_factor
        self.activation = activation
        self.filter_parameters = filter_parameters or {}

        # Checkers
        if temporal_layer not in _valid_layers:
            raise NotImplementedError(
                f"Temporal layer '{temporal_layer}' is not implemented."
            )

        # Computed once; reused below to build the padding layer.
        remainder = self.n_times % self.stride_factor
        if remainder != 0:
            warn(
                f"Time dimension ({self.n_times}) is not divisible by"
                f" stride_factor ({self.stride_factor}). Input will be padded.",
                UserWarning,
            )

        # Layers
        # Following paper nomenclature
        self.spectral_filtering = FilterBankLayer(
            n_chans=self.n_chans,
            sfreq=self.sfreq,
            band_filters=self.n_bands,
            verbose=False,
            **self.filter_parameters,
        )
        # As we have an internal process to create the bands,
        # we get the values from the filterbank
        self.n_bands = self.spectral_filtering.n_bands

        # Spatial Convolution Block (SCB): one group per band, each producing
        # n_filters_spat maps; the (n_chans, 1) kernel collapses the channel axis.
        self.spatial_conv = nn.Sequential(
            Conv2dWithConstraint(
                in_channels=self.n_bands,
                out_channels=self.n_filters_spat * self.n_bands,
                kernel_size=(self.n_chans, 1),
                groups=self.n_bands,
                max_norm=cnn_max_norm,
                padding=0,
            ),
            nn.BatchNorm2d(self.n_filters_spat * self.n_bands),
            self.activation(),
        )

        # Padding layer: zero-pad the end of the time axis up to a multiple of
        # stride_factor so it can be split into equal windows in forward().
        if remainder != 0:
            self.padding_size = self.stride_factor - remainder
            self.n_times_padded = self.n_times + self.padding_size
            self.padding_layer = nn.ConstantPad1d((0, self.padding_size), 0.0)
        else:
            self.padding_size = 0
            self.padding_layer = nn.Identity()
            self.n_times_padded = self.n_times

        # Temporal aggregator
        self.temporal_layer = _valid_layers[temporal_layer](dim=self.n_dim)  # type: ignore

        # Flatten layer
        self.flatten_layer = Rearrange("batch ... -> batch (...)")

        # Final fully connected layer
        self.final_layer = LinearWithConstraint(
            in_features=self.n_filters_spat * self.n_bands * self.stride_factor,
            out_features=self.n_outputs,
            max_norm=linear_max_norm,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the FBCNet model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape (batch_size, n_chans, n_times).

        Returns
        -------
        torch.Tensor
            Output tensor with shape (batch_size, n_outputs).
        """
        # input: (batch_size, n_chans, n_times)
        x = self.spectral_filtering(x)

        # shape: (batch_size, n_bands, n_chans, n_times)
        x = self.spatial_conv(x)
        batch_size, channels, _, _ = x.shape

        # shape: (batch_size, n_filters_spat * n_bands, 1, n_times)
        x = self.padding_layer(x)

        # shape: (batch_size, n_filters_spat * n_bands, 1, n_times_padded)
        x = x.view(
            batch_size,
            channels,
            self.stride_factor,
            self.n_times_padded // self.stride_factor,
        )
        # shape: (batch_size, n_filters_spat * n_bands, stride, n_times_padded // stride)
        x = self.temporal_layer(x)  # type: ignore[operator]

        # shape: (batch_size, n_filters_spat * n_bands, stride, 1)
        x = self.flatten_layer(x)

        # shape: (batch_size, n_filters_spat * n_bands * stride)
        x = self.final_layer(x)
        # shape: (batch_size, n_outputs)
        return x
|