braindecode 1.3.0.dev177069446__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. braindecode/__init__.py +9 -0
  2. braindecode/augmentation/__init__.py +52 -0
  3. braindecode/augmentation/base.py +225 -0
  4. braindecode/augmentation/functional.py +1300 -0
  5. braindecode/augmentation/transforms.py +1356 -0
  6. braindecode/classifier.py +258 -0
  7. braindecode/datasets/__init__.py +44 -0
  8. braindecode/datasets/base.py +823 -0
  9. braindecode/datasets/bbci.py +693 -0
  10. braindecode/datasets/bcicomp.py +193 -0
  11. braindecode/datasets/bids/__init__.py +54 -0
  12. braindecode/datasets/bids/datasets.py +239 -0
  13. braindecode/datasets/bids/format.py +717 -0
  14. braindecode/datasets/bids/hub.py +987 -0
  15. braindecode/datasets/bids/hub_format.py +717 -0
  16. braindecode/datasets/bids/hub_io.py +197 -0
  17. braindecode/datasets/bids/hub_validation.py +114 -0
  18. braindecode/datasets/bids/iterable.py +220 -0
  19. braindecode/datasets/chb_mit.py +163 -0
  20. braindecode/datasets/mne.py +170 -0
  21. braindecode/datasets/moabb.py +219 -0
  22. braindecode/datasets/nmt.py +313 -0
  23. braindecode/datasets/registry.py +120 -0
  24. braindecode/datasets/siena.py +162 -0
  25. braindecode/datasets/sleep_physio_challe_18.py +411 -0
  26. braindecode/datasets/sleep_physionet.py +125 -0
  27. braindecode/datasets/tuh.py +591 -0
  28. braindecode/datasets/utils.py +67 -0
  29. braindecode/datasets/xy.py +96 -0
  30. braindecode/datautil/__init__.py +62 -0
  31. braindecode/datautil/channel_utils.py +114 -0
  32. braindecode/datautil/hub_formats.py +180 -0
  33. braindecode/datautil/serialization.py +359 -0
  34. braindecode/datautil/util.py +154 -0
  35. braindecode/eegneuralnet.py +372 -0
  36. braindecode/functional/__init__.py +22 -0
  37. braindecode/functional/functions.py +251 -0
  38. braindecode/functional/initialization.py +47 -0
  39. braindecode/models/__init__.py +117 -0
  40. braindecode/models/atcnet.py +830 -0
  41. braindecode/models/attentionbasenet.py +727 -0
  42. braindecode/models/attn_sleep.py +549 -0
  43. braindecode/models/base.py +574 -0
  44. braindecode/models/bendr.py +493 -0
  45. braindecode/models/biot.py +537 -0
  46. braindecode/models/brainmodule.py +845 -0
  47. braindecode/models/config.py +233 -0
  48. braindecode/models/contrawr.py +319 -0
  49. braindecode/models/ctnet.py +541 -0
  50. braindecode/models/deep4.py +376 -0
  51. braindecode/models/deepsleepnet.py +417 -0
  52. braindecode/models/eegconformer.py +475 -0
  53. braindecode/models/eeginception_erp.py +379 -0
  54. braindecode/models/eeginception_mi.py +379 -0
  55. braindecode/models/eegitnet.py +302 -0
  56. braindecode/models/eegminer.py +256 -0
  57. braindecode/models/eegnet.py +359 -0
  58. braindecode/models/eegnex.py +354 -0
  59. braindecode/models/eegsimpleconv.py +201 -0
  60. braindecode/models/eegsym.py +917 -0
  61. braindecode/models/eegtcnet.py +337 -0
  62. braindecode/models/fbcnet.py +225 -0
  63. braindecode/models/fblightconvnet.py +315 -0
  64. braindecode/models/fbmsnet.py +338 -0
  65. braindecode/models/hybrid.py +126 -0
  66. braindecode/models/ifnet.py +443 -0
  67. braindecode/models/labram.py +1316 -0
  68. braindecode/models/luna.py +891 -0
  69. braindecode/models/medformer.py +760 -0
  70. braindecode/models/msvtnet.py +377 -0
  71. braindecode/models/patchedtransformer.py +640 -0
  72. braindecode/models/reve.py +843 -0
  73. braindecode/models/sccnet.py +280 -0
  74. braindecode/models/shallow_fbcsp.py +212 -0
  75. braindecode/models/signal_jepa.py +1122 -0
  76. braindecode/models/sinc_shallow.py +339 -0
  77. braindecode/models/sleep_stager_blanco_2020.py +169 -0
  78. braindecode/models/sleep_stager_chambon_2018.py +159 -0
  79. braindecode/models/sparcnet.py +426 -0
  80. braindecode/models/sstdpn.py +869 -0
  81. braindecode/models/summary.csv +47 -0
  82. braindecode/models/syncnet.py +234 -0
  83. braindecode/models/tcn.py +275 -0
  84. braindecode/models/tidnet.py +397 -0
  85. braindecode/models/tsinception.py +295 -0
  86. braindecode/models/usleep.py +439 -0
  87. braindecode/models/util.py +369 -0
  88. braindecode/modules/__init__.py +92 -0
  89. braindecode/modules/activation.py +86 -0
  90. braindecode/modules/attention.py +883 -0
  91. braindecode/modules/blocks.py +160 -0
  92. braindecode/modules/convolution.py +330 -0
  93. braindecode/modules/filter.py +654 -0
  94. braindecode/modules/layers.py +216 -0
  95. braindecode/modules/linear.py +70 -0
  96. braindecode/modules/parametrization.py +38 -0
  97. braindecode/modules/stats.py +87 -0
  98. braindecode/modules/util.py +85 -0
  99. braindecode/modules/wrapper.py +90 -0
  100. braindecode/preprocessing/__init__.py +271 -0
  101. braindecode/preprocessing/eegprep_preprocess.py +1317 -0
  102. braindecode/preprocessing/mne_preprocess.py +240 -0
  103. braindecode/preprocessing/preprocess.py +579 -0
  104. braindecode/preprocessing/util.py +177 -0
  105. braindecode/preprocessing/windowers.py +1037 -0
  106. braindecode/regressor.py +234 -0
  107. braindecode/samplers/__init__.py +18 -0
  108. braindecode/samplers/base.py +399 -0
  109. braindecode/samplers/ssl.py +263 -0
  110. braindecode/training/__init__.py +23 -0
  111. braindecode/training/callbacks.py +23 -0
  112. braindecode/training/losses.py +105 -0
  113. braindecode/training/scoring.py +477 -0
  114. braindecode/util.py +419 -0
  115. braindecode/version.py +1 -0
  116. braindecode/visualization/__init__.py +8 -0
  117. braindecode/visualization/confusion_matrices.py +289 -0
  118. braindecode/visualization/gradients.py +62 -0
  119. braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
  120. braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
  121. braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
  122. braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
  123. braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
  124. braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
braindecode/modules/attention.py
@@ -0,0 +1,883 @@

"""
Attention modules used in the AttentionBaseNet from Martin Wimpff (2023).

Here, we implement some popular attention modules that can be used in the
AttentionBaseNet class.
"""

# Authors: Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

import math
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from braindecode.functional import _get_gaussian_kernel1d

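# The channel- and spatial-attention blocks below (everything except
# MultiHeadAttention) share a convention: inputs are 4D tensors of shape
# (batch, channels, 1, time), as in the docstring examples, and each block
# returns the input rescaled by a learned gate (a sigmoid score for most
# blocks; GCT uses a 1 + tanh gate).
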
class SqueezeAndExcitation(nn.Module):
    """Squeeze-and-Excitation Networks from [Hu2018]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    reduction_rate : int
        Reduction ratio of the fully-connected layers.
    bias : bool, default=False
        If True, a learnable bias is used in the convolutions.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import SqueezeAndExcitation
    >>> module = SqueezeAndExcitation(in_channels=16, reduction_rate=4)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
       Squeeze-and-Excitation Networks. CVPR 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
        super().__init__()
        sq_channels = int(in_channels // reduction_rate)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels=in_channels, out_channels=sq_channels, kernel_size=1, bias=bias
        )
        self.nonlinearity = nn.ReLU()
        self.fc2 = nn.Conv2d(
            in_channels=sq_channels,  # expand from the squeezed width back up
            out_channels=in_channels,
            kernel_size=1,
            bias=bias,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the Squeeze-and-Excitation block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
            The input rescaled by the per-channel gate, ``scale * x``.
        """
        scale = self.gap(x)
        scale = self.fc1(scale)
        scale = self.nonlinearity(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)
        return scale * x

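# Shape walk-through for SqueezeAndExcitation(in_channels=16, reduction_rate=4)
# on a (2, 16, 1, 64) input:
#   gap -> (2, 16, 1, 1)   squeeze: global average pooling
#   fc1 -> (2, 4, 1, 1)    reduce channels by reduction_rate
#   fc2 -> (2, 16, 1, 1)   expand back, then sigmoid gate
#   out -> (2, 16, 1, 64)  input rescaled per channel
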
class GSoP(nn.Module):
    """Global Second-order Pooling Convolutional Networks from [Gao2018]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    reduction_rate : int
        Reduction ratio of the fully-connected layers.
    bias : bool, default=True
        If True, a learnable bias is used in the convolutions.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GSoP
    >>> module = GSoP(in_channels=16, reduction_rate=4)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Gao2018] Gao, Z., Xie, J., Wang, Q., Li, P., 2018.
       Global Second-order Pooling Convolutional Networks. CVPR 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
        super().__init__()
        sq_channels = int(in_channels // reduction_rate)
        self.pw_conv1 = nn.Conv2d(in_channels, sq_channels, 1, bias=bias)
        self.bn = nn.BatchNorm2d(sq_channels)
        self.rw_conv = nn.Conv2d(
            sq_channels,
            sq_channels * 4,
            (sq_channels, 1),
            groups=sq_channels,
            bias=bias,
        )
        self.pw_conv2 = nn.Conv2d(sq_channels * 4, in_channels, 1, bias=bias)

    def forward(self, x):
        """Apply the Global Second-order Pooling block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        scale = self.pw_conv1(x).squeeze(-2)  # b x c' x t (c' = squeezed channels)
        scale_zero_mean = scale - scale.mean(-1, keepdim=True)
        t = scale_zero_mean.shape[-1]
        cov = torch.bmm(scale_zero_mean, scale_zero_mean.transpose(1, 2)) / (t - 1)
        cov = cov.unsqueeze(-1)  # b x c' x c' x 1
        cov = self.bn(cov)
        scale = self.rw_conv(cov)  # b x 4c' x 1 x 1
        scale = self.pw_conv2(scale)
        return scale * x

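# GSoP's gate is driven by second-order statistics: after the pointwise
# reduction, cov[i, j] = sum_t z_i(t) * z_j(t) / (t - 1) is the sample
# covariance between reduced channels i and j (z zero-meaned along time),
# so channels are recalibrated by how they co-vary, not just by their means.
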
class FCA(nn.Module):
    """Frequency Channel Attention Networks from [Qin2021]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    seq_len : int, default=62
        Sequence length along the temporal dimension.
    reduction_rate : int, default=4
        Reduction ratio of the fully-connected layers.
    freq_idx : int, default=0
        Index of the DCT frequency used to build the filter.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import FCA
    >>> module = FCA(in_channels=16, seq_len=64, reduction_rate=4, freq_idx=0)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
       FcaNet: Frequency Channel Attention Networks. ICCV 2021.
    """

    def __init__(
        self,
        in_channels: int,
        seq_len: int = 62,
        reduction_rate: int = 4,
        freq_idx: int = 0,
    ):
        super().__init__()
        mapper_y = [freq_idx]
        if in_channels % len(mapper_y) != 0:
            raise ValueError("in_channels must be divisible by number of DCT filters")

        self.weight = nn.Parameter(
            self.get_dct_filter(seq_len, mapper_y, in_channels), requires_grad=False
        )
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_rate, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_rate, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Apply the Frequency Channel Attention block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        scale = x.squeeze(-2) * self.weight
        scale = torch.sum(scale, dim=-1)
        scale = rearrange(self.fc(scale), "b c -> b c 1 1")
        return x * scale.expand_as(x)

    @staticmethod
    def get_dct_filter(seq_len: int, mapper_y: list, in_channels: int):
        """Build the DCT filter bank.

        Parameters
        ----------
        seq_len : int
            Length of the sequence.
        mapper_y : list of int
            Frequency indices.
        in_channels : int
            Number of input channels.

        Returns
        -------
        torch.Tensor
        """
        dct_filter = torch.zeros(in_channels, seq_len)

        c_part = in_channels // len(mapper_y)

        for i, v_y in enumerate(mapper_y):
            for t_y in range(seq_len):
                coeff = math.cos(math.pi * v_y * (t_y + 0.5) / seq_len) / math.sqrt(
                    seq_len
                )
                coeff = coeff * math.sqrt(2) if v_y != 0 else coeff
                dct_filter[i * c_part : (i + 1) * c_part, t_y] = coeff
        return dct_filter

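# get_dct_filter evaluates the type-II DCT basis at frequency index v_y:
# w(t) = cos(pi * v_y * (t + 0.5) / T) / sqrt(T), scaled by sqrt(2) for
# v_y != 0. With the default freq_idx=0 the weighting is constant, so the
# frequency pooling reduces to global average pooling up to a scale factor.
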
class EncNet(nn.Module):
    """Context Encoding for Semantic Segmentation from [Zhang2018]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    n_codewords : int
        Number of codewords.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import EncNet
    >>> module = EncNet(in_channels=16, n_codewords=8)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Zhang2018] Zhang, H. et al., 2018.
       Context Encoding for Semantic Segmentation. CVPR 2018.
    """

    def __init__(self, in_channels: int, n_codewords: int):
        super().__init__()
        self.n_codewords = n_codewords
        self.codewords = nn.Parameter(torch.empty(n_codewords, in_channels))
        self.smoothing = nn.Parameter(torch.empty(n_codewords))
        std = 1 / ((n_codewords * in_channels) ** (1 / 2))
        nn.init.uniform_(self.codewords.data, -std, std)
        nn.init.uniform_(self.smoothing, -1, 0)
        self.bn = nn.BatchNorm1d(n_codewords)
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """Apply attention from Context Encoding for Semantic Segmentation.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        b, c, _, seq = x.shape
        # b x c x 1 x t -> b x t x k x c
        x_ = rearrange(x, pattern="b c 1 seq -> b seq 1 c")
        x_ = x_.expand(b, seq, self.n_codewords, c)
        cw_ = self.codewords.unsqueeze(0).unsqueeze(0)  # 1 x 1 x k x c
        a = self.smoothing.unsqueeze(0).unsqueeze(0) * (x_ - cw_).pow(2).sum(3)
        a = torch.softmax(a, dim=2)  # b x t x k

        # aggregate
        e = (a.unsqueeze(3) * (x_ - cw_)).sum(1)  # b x k x c
        e_norm = torch.relu(self.bn(e)).mean(1)  # b x c

        scale = torch.sigmoid(self.fc(e_norm))
        return x * scale.unsqueeze(2).unsqueeze(3)

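# EncNet soft-assigns each time step to K learned codewords: `a` holds the
# softmax assignment weights (scaled by the learned per-codeword `smoothing`
# factors, initialised negative so that closer codewords receive larger
# weights), and `e` aggregates the residuals (x - codeword) under those
# weights before the channel gate is computed.
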
class ECA(nn.Module):
    """Efficient Channel Attention from [Wang2021]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    kernel_size : int
        Kernel size of the convolutional layer; it determines the degree of
        channel interaction and must be odd.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import ECA
    >>> module = ECA(in_channels=16, kernel_size=3)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
       for Deep Convolutional Neural Networks. CVPR 2021.
    """

    def __init__(self, in_channels: int, kernel_size: int):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        if kernel_size % 2 != 1:
            raise ValueError("kernel size must be odd for same padding")
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False
        )

    def forward(self, x):
        """Apply the Efficient Channel Attention block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        scale = self.gap(x)
        scale = rearrange(scale, "b c 1 1 -> b 1 c")
        scale = self.conv(scale)
        scale = torch.sigmoid(rearrange(scale, "b 1 c -> b c 1 1"))
        return x * scale

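# ECA avoids the SE bottleneck entirely: a single 1D convolution slides over
# the pooled channel descriptor, so each channel's gate depends only on its
# kernel_size nearest neighbours in channel order and no dimensionality
# reduction takes place (in_channels is not used to size any layer).
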
class GatherExcite(nn.Module):
    """Gather-Excite Networks from [Hu2018b]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    seq_len : int, default=62
        Sequence length along the temporal dimension.
    extra_params : bool, default=False
        Whether to use a convolutional layer as the gather module.
    use_mlp : bool, default=False
        Whether to use an excite block with fully-connected layers.
    reduction_rate : int, default=4
        Reduction ratio of the excite block (if used).

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GatherExcite
    >>> module = GatherExcite(in_channels=16, seq_len=64, extra_params=False, use_mlp=True)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
       Gather-Excite: Exploiting Feature Context in Convolutional Neural
       Networks. NeurIPS 2018.
    """

    def __init__(
        self,
        in_channels: int,
        seq_len: int = 62,
        extra_params: bool = False,
        use_mlp: bool = False,
        reduction_rate: int = 4,
    ):
        super().__init__()
        if extra_params:
            self.gather = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    (1, seq_len),
                    groups=in_channels,
                    bias=False,
                ),
                nn.BatchNorm2d(in_channels),
            )
        else:
            self.gather = nn.AdaptiveAvgPool2d(1)

        if use_mlp:
            self.mlp = nn.Sequential(
                nn.Conv2d(
                    in_channels, int(in_channels // reduction_rate), 1, bias=False
                ),
                nn.ReLU(),
                nn.Conv2d(
                    int(in_channels // reduction_rate), in_channels, 1, bias=False
                ),
            )
        else:
            self.mlp = nn.Identity()

    def forward(self, x):
        """Apply the Gather-Excite block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        scale = self.gather(x)
        scale = torch.sigmoid(self.mlp(scale))
        return scale * x

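# Gather-Excite generalises SE: with extra_params=True the "gather" step is
# a learned depthwise (1 x seq_len) convolution rather than parameter-free
# average pooling, and use_mlp=True inserts an SE-style bottleneck before
# the sigmoid "excite" step.
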
class GCT(nn.Module):
    """Gated Channel Transformation from [Yang2020]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GCT
    >>> module = GCT(in_channels=16)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Yang2020] Yang, Z., Zhu, L., Wu, Y., Yang, Y., 2020.
       Gated Channel Transformation for Visual Recognition. CVPR 2020.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, in_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, x, eps: float = 1e-5):
        """Apply the Gated Channel Transformation block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
        eps : float, default=1e-5

        Returns
        -------
        torch.Tensor
            The original tensor x multiplied by the gate.
        """
        embedding = (x.pow(2).sum((2, 3), keepdim=True) + eps).pow(0.5) * self.alpha
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + eps).pow(0.5)
        gate = 1.0 + torch.tanh(embedding * norm + self.beta)
        return x * gate

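# GCT computes, per channel, an L2 embedding s_c = alpha_c * ||x_c||_2,
# normalises it across channels, and gates with 1 + tanh(gamma * s_norm +
# beta). Since gamma and beta start at zero, tanh(0) = 0 and the block is
# an exact identity mapping at initialisation.
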
class SRM(nn.Module):
    """Style-based Recalibration Module from [Lee2019]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    use_mlp : bool, default=False
        Whether to use fully-connected layers instead of a convolutional
        layer for the style integration.
    reduction_rate : int, default=4
        Reduction ratio of the fully-connected layers (if used).
    bias : bool, default=False
        If True, a learnable bias is used in the style integration layer.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import SRM
    >>> module = SRM(in_channels=16, use_mlp=False)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
       Recalibration Module for Convolutional Neural Networks. ICCV 2019.
    """

    def __init__(
        self,
        in_channels: int,
        use_mlp: bool = False,
        reduction_rate: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        if use_mlp:
            self.style_integration = nn.Sequential(
                Rearrange("b c n_metrics -> b (c n_metrics)"),
                nn.Linear(
                    in_channels * 2, in_channels * 2 // reduction_rate, bias=bias
                ),
                nn.ReLU(),
                nn.Linear(in_channels * 2 // reduction_rate, in_channels, bias=bias),
                Rearrange("b c -> b c 1"),
            )
        else:
            self.style_integration = nn.Conv1d(
                in_channels, in_channels, 2, groups=in_channels, bias=bias
            )
        self.bn = nn.BatchNorm1d(in_channels)

    def forward(self, x):
        """Apply the Style-based Recalibration Module to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        mu = self.gap(x).squeeze(-1)  # b x c x 1
        std = x.std(dim=(-2, -1), keepdim=True).squeeze(-1)  # b x c x 1
        t = torch.cat([mu, std], dim=2)  # b x c x 2
        z = self.style_integration(t)  # b x c x 1
        z = self.bn(z)
        scale = torch.sigmoid(z).unsqueeze(-1)
        return scale * x

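# SRM's "style" descriptor per channel is the pair (mean, std) over the
# spatial/temporal extent; style_integration mixes the two statistics within
# each channel (a grouped width-2 Conv1d, or a cross-channel MLP when
# use_mlp=True) before batch norm and the sigmoid gate.
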
class CBAM(nn.Module):
    """Convolutional Block Attention Module from [Woo2018]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    reduction_rate : int
        Reduction ratio of the fully-connected layers.
    kernel_size : int
        Kernel size of the convolutional layer; must be odd.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CBAM
    >>> module = CBAM(in_channels=16, reduction_rate=4, kernel_size=3)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
       CBAM: Convolutional Block Attention Module. ECCV 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, kernel_size: int):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=False),
        )
        if kernel_size % 2 != 1:
            raise ValueError("kernel size must be odd for same padding")
        self.conv = nn.Conv2d(2, 1, (1, kernel_size), padding=(0, kernel_size // 2))

    def forward(self, x):
        """Apply the Convolutional Block Attention Module to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        channel_attention = torch.sigmoid(
            self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        )
        x = x * channel_attention
        spat_input = torch.cat(
            [torch.mean(x, dim=1, keepdim=True), torch.max(x, dim=1, keepdim=True)[0]],
            dim=1,
        )
        spatial_attention = torch.sigmoid(self.conv(spat_input))
        return x * spatial_attention

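# CBAM gates in two sequential stages: first a channel gate built from a
# shared bottleneck MLP applied to both average- and max-pooled descriptors,
# then a spatial gate from a (1 x kernel_size) convolution over the
# channel-wise mean and max maps of the already channel-gated input.
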
class CAT(nn.Module):
    """Attention mechanism from [Wu2023]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    reduction_rate : int
        Reduction ratio of the fully-connected layers.
    kernel_size : int
        Kernel size of the convolutional layer.
    bias : bool, default=False
        If True, a learnable bias is used in the convolutions.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CAT
    >>> module = CAT(in_channels=16, reduction_rate=4, kernel_size=3)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Wu2023] Wu, Z. et al., 2023.
       CAT: Learning to Collaborate Channel and Spatial Attention from
       Multi-Information Fusion. IET Computer Vision 2023.
    """

    def __init__(
        self,
        in_channels: int,
        reduction_rate: int,
        kernel_size: int,
        bias: bool = False,
    ):
        super().__init__()
        self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
        self.gauss_filter.weight = nn.Parameter(
            _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
        )
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=bias),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=bias),
        )
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=(1, kernel_size),
            padding=(0, kernel_size // 2),
            bias=bias,
        )

        self.c_alpha = nn.Parameter(torch.zeros(1))
        self.c_beta = nn.Parameter(torch.zeros(1))
        self.c_gamma = nn.Parameter(torch.zeros(1))
        self.s_alpha = nn.Parameter(torch.zeros(1))
        self.s_beta = nn.Parameter(torch.zeros(1))
        self.s_gamma = nn.Parameter(torch.zeros(1))
        self.c_w = nn.Parameter(torch.zeros(1))
        self.s_w = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply the CAT block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        b, c, h, w = x.shape
        x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)

        # channel branch: average, blurred-max and entropy pooling
        c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
        c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
        pi = torch.softmax(x, dim=-1)
        c_gep = -1 * (pi * torch.log(pi)).sum(dim=(-2, -1), keepdim=True)
        c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep_max = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep = self.mlp((c_gep - c_gep_min) / (c_gep_max - c_gep_min))
        channel_score = torch.sigmoid(
            c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
        )
        channel_score = channel_score.expand(b, c, h, w)

        # spatial branch: the analogous statistics taken across channels
        s_gap = x.mean(dim=1, keepdim=True)
        s_gmp = torch.amax(x_blurred, dim=(-2, -1), keepdim=True)
        pi = torch.softmax(x, dim=1)
        s_gep = -1 * (pi * torch.log(pi)).sum(dim=1, keepdim=True)
        s_gep_min = torch.amin(s_gep, dim=(-2, -1), keepdim=True)
        s_gep_max = torch.amax(s_gep, dim=(-2, -1), keepdim=True)
        s_gep = (s_gep - s_gep_min) / (s_gep_max - s_gep_min)
        spatial_score = (
            -s_gap * self.s_alpha + s_gmp * self.s_beta + s_gep * self.s_gamma
        )
        spatial_score = torch.sigmoid(self.conv(spatial_score)).expand(b, c, h, w)

        # fuse the two scores with softmax-normalised collaboration weights
        c_w = torch.exp(self.c_w) / (torch.exp(self.c_w) + torch.exp(self.s_w))
        s_w = torch.exp(self.s_w) / (torch.exp(self.c_w) + torch.exp(self.s_w))

        scale = channel_score * c_w + spatial_score * s_w
        return scale * x

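# CAT pools three complementary statistics per branch: global average,
# max over a Gaussian-blurred input (GMP), and min-max-normalised softmax
# entropy (GEP). The channel and spatial scores are then fused with weights
# softmax([c_w, s_w]), written out explicitly with exp() above, so the
# network learns how much each branch should contribute.
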
class CATLite(nn.Module):
    """Modification of CAT without the convolutional layer from [Wu2023]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels.
    reduction_rate : int
        Reduction ratio of the fully-connected layers.
    bias : bool, default=True
        If True, a learnable bias is used in the convolutions.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CATLite
    >>> module = CATLite(in_channels=16, reduction_rate=4)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Wu2023] Wu, Z. et al., 2023. CAT: Learning to Collaborate Channel and
       Spatial Attention from Multi-Information Fusion. IET Computer Vision
       2023.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
        super().__init__()
        self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
        self.gauss_filter.weight = nn.Parameter(
            _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
        )
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels // reduction_rate), 1, bias=bias),
            nn.ReLU(),
            nn.Conv2d(int(in_channels // reduction_rate), in_channels, 1, bias=bias),
        )

        self.c_alpha = nn.Parameter(torch.zeros(1))
        self.c_beta = nn.Parameter(torch.zeros(1))
        self.c_gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply the CATLite block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        """
        b, c, h, w = x.shape
        x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)

        c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
        c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
        pi = torch.softmax(x, dim=-1)
        c_gep = -1 * (pi * torch.log(pi)).sum(dim=(-2, -1), keepdim=True)
        c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep_max = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep = self.mlp((c_gep - c_gep_min) / (c_gep_max - c_gep_min))
        channel_score = torch.sigmoid(
            c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
        )
        channel_score = channel_score.expand(b, c, h, w)

        return channel_score * x

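# CATLite keeps only CAT's channel branch: no spatial score and no learned
# fusion weights, so the gate is sigmoid(alpha * GAP + beta * GMP +
# gamma * GEP) applied per channel.
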
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention block.

    Parameters
    ----------
    emb_size : int
        Embedding dimension of the input tokens.
    num_heads : int
        Number of attention heads; must divide emb_size.
    dropout : float
        Dropout probability applied to the attention weights.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import MultiHeadAttention
    >>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
    >>> inputs = torch.randn(2, 10, 32)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 10, 32])
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

        self.rearrange_stack = Rearrange(
            "b n (h d) -> b h n d",
            h=num_heads,
        )
        self.rearrange_unstack = Rearrange(
            "b h n d -> b n (h d)",
        )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Apply multi-head self-attention to the input tensor."""
        queries = self.rearrange_stack(self.queries(x))
        keys = self.rearrange_stack(self.keys(x))
        values = self.rearrange_stack(self.values(x))
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
        if mask is not None:
            fill_value = float("-inf")
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum("bhal, bhlv -> bhav", att, values)
        out = self.rearrange_unstack(out)
        out = self.projection(out)
        return out
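# The block computes scaled dot-product attention,
# softmax(Q K^T / sqrt(emb_size)) V, with masked positions filled with -inf
# before the softmax. Note that the scaling uses the full embedding width
# emb_size rather than the per-head dimension emb_size / num_heads.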