braindecode 1.3.0.dev177069446__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +9 -0
- braindecode/augmentation/__init__.py +52 -0
- braindecode/augmentation/base.py +225 -0
- braindecode/augmentation/functional.py +1300 -0
- braindecode/augmentation/transforms.py +1356 -0
- braindecode/classifier.py +258 -0
- braindecode/datasets/__init__.py +44 -0
- braindecode/datasets/base.py +823 -0
- braindecode/datasets/bbci.py +693 -0
- braindecode/datasets/bcicomp.py +193 -0
- braindecode/datasets/bids/__init__.py +54 -0
- braindecode/datasets/bids/datasets.py +239 -0
- braindecode/datasets/bids/format.py +717 -0
- braindecode/datasets/bids/hub.py +987 -0
- braindecode/datasets/bids/hub_format.py +717 -0
- braindecode/datasets/bids/hub_io.py +197 -0
- braindecode/datasets/bids/hub_validation.py +114 -0
- braindecode/datasets/bids/iterable.py +220 -0
- braindecode/datasets/chb_mit.py +163 -0
- braindecode/datasets/mne.py +170 -0
- braindecode/datasets/moabb.py +219 -0
- braindecode/datasets/nmt.py +313 -0
- braindecode/datasets/registry.py +120 -0
- braindecode/datasets/siena.py +162 -0
- braindecode/datasets/sleep_physio_challe_18.py +411 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +591 -0
- braindecode/datasets/utils.py +67 -0
- braindecode/datasets/xy.py +96 -0
- braindecode/datautil/__init__.py +62 -0
- braindecode/datautil/channel_utils.py +114 -0
- braindecode/datautil/hub_formats.py +180 -0
- braindecode/datautil/serialization.py +359 -0
- braindecode/datautil/util.py +154 -0
- braindecode/eegneuralnet.py +372 -0
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +117 -0
- braindecode/models/atcnet.py +830 -0
- braindecode/models/attentionbasenet.py +727 -0
- braindecode/models/attn_sleep.py +549 -0
- braindecode/models/base.py +574 -0
- braindecode/models/bendr.py +493 -0
- braindecode/models/biot.py +537 -0
- braindecode/models/brainmodule.py +845 -0
- braindecode/models/config.py +233 -0
- braindecode/models/contrawr.py +319 -0
- braindecode/models/ctnet.py +541 -0
- braindecode/models/deep4.py +376 -0
- braindecode/models/deepsleepnet.py +417 -0
- braindecode/models/eegconformer.py +475 -0
- braindecode/models/eeginception_erp.py +379 -0
- braindecode/models/eeginception_mi.py +379 -0
- braindecode/models/eegitnet.py +302 -0
- braindecode/models/eegminer.py +256 -0
- braindecode/models/eegnet.py +359 -0
- braindecode/models/eegnex.py +354 -0
- braindecode/models/eegsimpleconv.py +201 -0
- braindecode/models/eegsym.py +917 -0
- braindecode/models/eegtcnet.py +337 -0
- braindecode/models/fbcnet.py +225 -0
- braindecode/models/fblightconvnet.py +315 -0
- braindecode/models/fbmsnet.py +338 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +443 -0
- braindecode/models/labram.py +1316 -0
- braindecode/models/luna.py +891 -0
- braindecode/models/medformer.py +760 -0
- braindecode/models/msvtnet.py +377 -0
- braindecode/models/patchedtransformer.py +640 -0
- braindecode/models/reve.py +843 -0
- braindecode/models/sccnet.py +280 -0
- braindecode/models/shallow_fbcsp.py +212 -0
- braindecode/models/signal_jepa.py +1122 -0
- braindecode/models/sinc_shallow.py +339 -0
- braindecode/models/sleep_stager_blanco_2020.py +169 -0
- braindecode/models/sleep_stager_chambon_2018.py +159 -0
- braindecode/models/sparcnet.py +426 -0
- braindecode/models/sstdpn.py +869 -0
- braindecode/models/summary.csv +47 -0
- braindecode/models/syncnet.py +234 -0
- braindecode/models/tcn.py +275 -0
- braindecode/models/tidnet.py +397 -0
- braindecode/models/tsinception.py +295 -0
- braindecode/models/usleep.py +439 -0
- braindecode/models/util.py +369 -0
- braindecode/modules/__init__.py +92 -0
- braindecode/modules/activation.py +86 -0
- braindecode/modules/attention.py +883 -0
- braindecode/modules/blocks.py +160 -0
- braindecode/modules/convolution.py +330 -0
- braindecode/modules/filter.py +654 -0
- braindecode/modules/layers.py +216 -0
- braindecode/modules/linear.py +70 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +87 -0
- braindecode/modules/util.py +85 -0
- braindecode/modules/wrapper.py +90 -0
- braindecode/preprocessing/__init__.py +271 -0
- braindecode/preprocessing/eegprep_preprocess.py +1317 -0
- braindecode/preprocessing/mne_preprocess.py +240 -0
- braindecode/preprocessing/preprocess.py +579 -0
- braindecode/preprocessing/util.py +177 -0
- braindecode/preprocessing/windowers.py +1037 -0
- braindecode/regressor.py +234 -0
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +399 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +477 -0
- braindecode/util.py +419 -0
- braindecode/version.py +1 -0
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +62 -0
- braindecode-1.3.0.dev177069446.dist-info/METADATA +230 -0
- braindecode-1.3.0.dev177069446.dist-info/RECORD +124 -0
- braindecode-1.3.0.dev177069446.dist-info/WHEEL +5 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/LICENSE.txt +31 -0
- braindecode-1.3.0.dev177069446.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-1.3.0.dev177069446.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,883 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Attention modules used in the AttentionBaseNet from Martin Wimpff (2023).
|
|
3
|
+
|
|
4
|
+
Here, we implement some popular attention modules that can be used in the
|
|
5
|
+
AttentionBaseNet class.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
# Authors: Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
|
|
10
|
+
# Bruno Aristimunha <b.aristimunha@gmail.com>
|
|
11
|
+
#
|
|
12
|
+
# License: BSD (3-clause)
|
|
13
|
+
|
|
14
|
+
import math
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
import torch.nn.functional as F
|
|
19
|
+
from einops import rearrange
|
|
20
|
+
from einops.layers.torch import Rearrange
|
|
21
|
+
from torch import Tensor, nn
|
|
22
|
+
|
|
23
|
+
from braindecode.functional import _get_gaussian_kernel1d
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SqueezeAndExcitation(nn.Module):
    """Squeeze-and-Excitation Networks from [Hu2018]_.

    The input is globally average-pooled, passed through a two-layer
    bottleneck of 1x1 convolutions (ReLU in between), squashed with a
    sigmoid and used to rescale the input channel-wise.

    Parameters
    ----------
    in_channels : int,
        number of input feature channels.
    reduction_rate : int,
        reduction ratio of the fully-connected layers. The bottleneck width
        is ``in_channels // reduction_rate``.
    bias: bool, default=False
        if True, a learnable bias is used in both 1x1 convolutions.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import SqueezeAndExcitation
    >>> module = SqueezeAndExcitation(in_channels=16, reduction_rate=4)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
       Squeeze-and-Excitation Networks. CVPR 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
        super().__init__()
        sq_channels = int(in_channels // reduction_rate)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels=in_channels, out_channels=sq_channels, kernel_size=1, bias=bias
        )
        self.nonlinearity = nn.ReLU()
        # fc2 consumes the output of fc1, so its input width must be the
        # bottleneck width (sq_channels), not the reduction ratio itself.
        # The two only coincide when in_channels == reduction_rate ** 2.
        self.fc2 = nn.Conv2d(
            in_channels=sq_channels,
            out_channels=in_channels,
            kernel_size=1,
            bias=bias,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Apply the Squeeze-and-Excitation block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input whose second dimension has ``in_channels`` entries
            (the docstring example uses ``(batch, in_channels, 1, seq_len)``).

        Returns
        -------
        scale*x: Pytorch.Tensor
            The input multiplied by the per-channel sigmoid gate; same shape
            as ``x``.
        """
        scale = self.gap(x)  # squeeze: global average over the last two dims
        scale = self.fc1(scale)
        scale = self.nonlinearity(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)
        return scale * x
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class GSoP(nn.Module):
    """
    Global Second-order Pooling Convolutional Networks from [Gao2018]_.

    The input is reduced with a 1x1 convolution, the channel covariance of
    the reduced features is computed over the last dimension, and a
    row-wise (grouped) convolution followed by a 1x1 convolution turns the
    covariance into one scaling factor per input channel.

    Parameters
    ----------
    in_channels : int,
        number of input feature channels
    reduction_rate : int,
        reduction ratio; the reduced width is ``in_channels // reduction_rate``
    bias: bool, default=True
        if True, a learnable bias is used in the three convolutions.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GSoP
    >>> module = GSoP(in_channels=16, reduction_rate=4)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
       Global Second-order Pooling Convolutional Networks. CVPR 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
        super().__init__()
        reduced = int(in_channels // reduction_rate)
        expanded = reduced * 4
        self.pw_conv1 = nn.Conv2d(in_channels, reduced, 1, bias=bias)
        self.bn = nn.BatchNorm2d(reduced)
        # Row-wise convolution: each group sees one row of the covariance.
        self.rw_conv = nn.Conv2d(
            reduced, expanded, (reduced, 1), groups=reduced, bias=bias
        )
        self.pw_conv2 = nn.Conv2d(expanded, in_channels, 1, bias=bias)

    def forward(self, x):
        """
        Apply the Global Second-order Pooling Convolutional Networks block.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input; the second-to-last dimension is squeezed away before
            the batched matrix product, so it has to be of size 1.

        Returns
        -------
        Pytorch.Tensor
            ``x`` multiplied by the per-channel scaling factors.
        """
        reduced = self.pw_conv1(x).squeeze(-2)  # b x c' x t
        centered = reduced - reduced.mean(-1, keepdim=True)
        n_samples = centered.shape[-1]
        # Unbiased covariance between the reduced channels.
        cov = torch.bmm(centered, centered.transpose(1, 2)) / (n_samples - 1)
        cov = self.bn(cov.unsqueeze(-1))  # b x c' x c' x 1
        scale = self.rw_conv(cov)  # b x 4c' x 1 x 1
        scale = self.pw_conv2(scale)  # b x c x 1 x 1
        return scale * x
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class FCA(nn.Module):
    """
    Frequency Channel Attention Networks from [Qin2021]_.

    Each channel is projected onto a fixed (non-trainable) DCT basis
    function along the last dimension; the resulting per-channel
    descriptor is passed through a two-layer bottleneck with a sigmoid to
    obtain the channel gate.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels
    seq_len : int
        Sequence length along temporal dimension, default=62
    reduction_rate : int, default=4
        Reduction ratio of the fully-connected layers.
    freq_idx : int, default=0
        Index of the DCT frequency used to build the fixed filter.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import FCA
    >>> module = FCA(in_channels=16, seq_len=64, reduction_rate=4, freq_idx=0)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
       FcaNet: Frequency Channel Attention Networks. ICCV 2021.
    """

    def __init__(
        self, in_channels, seq_len: int = 62, reduction_rate: int = 4, freq_idx: int = 0
    ):
        super().__init__()
        frequencies = [freq_idx]
        if in_channels % len(frequencies) != 0:
            raise ValueError("in_channels must be divisible by number of DCT filters")

        # Fixed DCT weights: stored as a parameter but never trained.
        dct_weight = self.get_dct_filter(seq_len, frequencies, in_channels)
        self.weight = nn.Parameter(dct_weight, requires_grad=False)

        hidden = in_channels // reduction_rate
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Apply the Frequency Channel Attention Networks block to the input.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input; after squeezing the second-to-last dimension it must
            broadcast against the ``(in_channels, seq_len)`` DCT weights.

        Returns
        -------
        Pytorch.Tensor
            ``x`` multiplied by the per-channel gate.
        """
        weighted = x.squeeze(-2) * self.weight
        descriptor = weighted.sum(dim=-1)  # one DCT coefficient per channel
        gate = rearrange(self.fc(descriptor), "b c -> b c 1 1")
        return x * gate.expand_as(x)

    @staticmethod
    def get_dct_filter(seq_len: int, mapper_y: list, in_channels: int):
        """
        Util function to get the DCT filter.

        The channels are split into ``len(mapper_y)`` consecutive groups and
        every channel of group ``i`` receives the DCT basis function of
        frequency ``mapper_y[i]``.

        Parameters
        ----------
        seq_len: int
            Size of the sequence
        mapper_y:
            List of frequencies
        in_channels:
            Number of input channels.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(in_channels, seq_len)``.
        """
        dct_filter = torch.zeros(in_channels, seq_len)
        group_size = in_channels // len(mapper_y)

        for group, freq in enumerate(mapper_y):
            basis = [
                math.cos(math.pi * freq * (step + 0.5) / seq_len) / math.sqrt(seq_len)
                for step in range(seq_len)
            ]
            if freq != 0:
                # Non-constant basis functions carry an extra sqrt(2) factor.
                basis = [value * math.sqrt(2) for value in basis]
            first = group * group_size
            dct_filter[first : first + group_size] = torch.tensor(basis)
        return dct_filter
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
class EncNet(nn.Module):
    """
    Context Encoding for Semantic Segmentation from [Zhang2018]_.

    Every time step is softly assigned to a set of learnable codewords;
    the assignment-weighted residuals are aggregated over time, normalized
    and turned into a per-channel sigmoid gate.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    n_codewords : int
        number of codewords

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import EncNet
    >>> module = EncNet(in_channels=16, n_codewords=8)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Zhang2018] Zhang, H. et al. 2018.
       Context Encoding for Semantic Segmentation. CVPR 2018.
    """

    def __init__(self, in_channels: int, n_codewords: int):
        super().__init__()
        self.n_codewords = n_codewords
        self.codewords = nn.Parameter(torch.empty(n_codewords, in_channels))
        self.smoothing = nn.Parameter(torch.empty(n_codewords))
        # Codewords start uniformly in [-bound, bound]; smoothing factors
        # start in [-1, 0].
        bound = 1 / ((n_codewords * in_channels) ** 0.5)
        nn.init.uniform_(self.codewords.data, -bound, bound)
        nn.init.uniform_(self.smoothing, -1, 0)
        self.bn = nn.BatchNorm1d(n_codewords)
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """
        Apply attention from the Context Encoding for Semantic Segmentation.

        Parameters
        ----------
        x: Pytorch.Tensor
            Input of shape ``(batch, in_channels, 1, seq_len)``; the
            rearrange pattern requires the third dimension to be 1.

        Returns
        -------
        Pytorch.Tensor
            ``x`` multiplied by the per-channel gate.
        """
        batch, n_chans, _, n_times = x.shape
        # b x c x 1 x t -> b x t x k x c
        feats = rearrange(x, pattern="b c 1 seq -> b seq 1 c")
        feats = feats.expand(batch, n_times, self.n_codewords, n_chans)
        codebook = self.codewords[None, None]  # 1 x 1 x k x c

        # Soft assignment of each time step to the codewords.
        sq_dist = (feats - codebook).pow(2).sum(3)
        assign = torch.softmax(self.smoothing[None, None] * sq_dist, dim=2)  # b x t x k

        # Aggregate the weighted residuals over time.
        encoded = (assign.unsqueeze(3) * (feats - codebook)).sum(1)  # b x k x c
        pooled = torch.relu(self.bn(encoded)).mean(1)  # b x c

        gate = torch.sigmoid(self.fc(pooled))
        return x * gate.unsqueeze(2).unsqueeze(3)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class ECA(nn.Module):
    """
    Efficient Channel Attention [Wang2021]_.

    The globally average-pooled channel descriptor is filtered with a
    single 1-D convolution across the channel axis and squashed with a
    sigmoid to obtain the channel gate.

    Parameters
    ----------
    in_channels : int
        number of input feature channels (not needed to build the layers,
        which operate on any number of channels)
    kernel_size : int
        kernel size of convolutional layer, determines degree of channel
        interaction, must be odd.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import ECA
    >>> module = ECA(in_channels=16, kernel_size=3)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
       for Deep Convolutional Neural Networks. CVPR 2021.
    """

    def __init__(self, in_channels: int, kernel_size: int):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        kernel_is_odd = kernel_size % 2 == 1
        if not kernel_is_odd:
            raise ValueError("kernel size must be odd for same padding")
        # "Same" padding so the number of channels is preserved.
        self.conv = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=False,
        )

    def forward(self, x):
        """
        Apply the Efficient Channel Attention block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input with channels in the second dimension.

        Returns
        -------
        Pytorch.Tensor
            ``x`` multiplied by the per-channel gate.
        """
        descriptor = rearrange(self.gap(x), "b c 1 1 -> b 1 c")
        mixed = self.conv(descriptor)  # convolve along the channel axis
        gate = torch.sigmoid(rearrange(mixed, "b 1 c -> b c 1 1"))
        return x * gate
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
class GatherExcite(nn.Module):
    """
    Gather-Excite Networks from [Hu2018b]_.

    A gather step (global average pooling, or a depthwise convolution
    spanning ``seq_len`` samples followed by batch normalization)
    summarizes each channel; an optional bottleneck of 1x1 convolutions
    processes the summary, and a sigmoid turns it into the gate.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    seq_len : int, default=62
        sequence length along temporal dimension
    extra_params : bool, default=False
        whether to use a convolutional layer as a gather module
    use_mlp : bool, default=False
        whether to use an excite block with fully-connected layers
    reduction_rate : int, default=4
        reduction ratio of the excite block (if used)

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GatherExcite
    >>> module = GatherExcite(in_channels=16, seq_len=64, extra_params=False, use_mlp=True)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
       Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks.
       NeurIPS 2018.
    """

    def __init__(
        self,
        in_channels: int,
        seq_len: int = 62,
        extra_params: bool = False,
        use_mlp: bool = False,
        reduction_rate: int = 4,
    ):
        super().__init__()

        # Gather: parametrized (depthwise conv + BN) or parameter-free (GAP).
        if not extra_params:
            self.gather = nn.AdaptiveAvgPool2d(1)
        else:
            depthwise = nn.Conv2d(
                in_channels,
                in_channels,
                (1, seq_len),
                groups=in_channels,
                bias=False,
            )
            self.gather = nn.Sequential(depthwise, nn.BatchNorm2d(in_channels))

        # Excite: optional bottleneck, otherwise the gathered values are
        # passed to the sigmoid unchanged.
        if not use_mlp:
            self.mlp = nn.Identity()
        else:
            hidden = int(in_channels // reduction_rate)
            self.mlp = nn.Sequential(
                nn.Conv2d(in_channels, hidden, 1, bias=False),
                nn.ReLU(),
                nn.Conv2d(hidden, in_channels, 1, bias=False),
            )

    def forward(self, x):
        """
        Apply the Gather-Excite Networks block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input with channels in the second dimension.

        Returns
        -------
        Pytorch.Tensor
            ``x`` multiplied by the sigmoid gate.
        """
        gathered = self.gather(x)
        gate = torch.sigmoid(self.mlp(gathered))
        return gate * x
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
class GCT(nn.Module):
    """
    Gated Channel Transformation from [Yang2020]_.

    A per-channel L2 embedding is scaled by ``alpha``, normalized across
    channels, scaled by ``gamma`` and shifted by ``beta``; the gate is
    ``1 + tanh(...)``. With the initial values (``gamma = beta = 0``) the
    gate equals one and the module is the identity.

    Parameters
    ----------
    in_channels : int
        number of input feature channels

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import GCT
    >>> module = GCT(in_channels=16)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Yang2020] Yang, Z. Linchao, Z., Wu, Y., Yang, Y., 2020.
       Gated Channel Transformation for Visual Recognition. CVPR 2020.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        per_channel = (1, in_channels, 1, 1)
        self.alpha = nn.Parameter(torch.ones(*per_channel))
        self.beta = nn.Parameter(torch.zeros(*per_channel))
        self.gamma = nn.Parameter(torch.zeros(*per_channel))

    def forward(self, x, eps: float = 1e-5):
        """
        Apply the Gated Channel Transformation block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input with channels in the second dimension.
        eps: float, default=1e-5
            Constant added inside both square roots for numerical stability.

        Returns
        -------
        Pytorch.Tensor
            the original tensor x multiplied by the gate.
        """
        # L2 norm of every channel over the last two dims, scaled by alpha.
        energy = x.pow(2).sum((2, 3), keepdim=True)
        embedding = (energy + eps) ** 0.5 * self.alpha
        # Normalize by the root mean square of the embeddings across channels.
        rms = (embedding.pow(2).mean(dim=1, keepdim=True) + eps) ** 0.5
        norm = self.gamma / rms
        gate = 1.0 + torch.tanh(embedding * norm + self.beta)
        return x * gate
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
class SRM(nn.Module):
    """
    Attention module from [Lee2019]_.

    The per-channel mean and standard deviation ("style") are combined by
    either a depthwise 1-D convolution or a small MLP, batch-normalized
    and passed through a sigmoid to obtain the channel gate.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    use_mlp : bool, default=False
        whether to use fully-connected layers instead of a convolutional layer,
    reduction_rate : int, default=4
        reduction ratio of the fully-connected layers (if used),
    bias : bool, default=False
        if True, a learnable bias is used in the style integration layers.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import SRM
    >>> module = SRM(in_channels=16, use_mlp=False)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
       Recalibration Module for Convolutional Neural Networks. ICCV 2019.
    """

    def __init__(
        self,
        in_channels: int,
        use_mlp: bool = False,
        reduction_rate: int = 4,
        bias: bool = False,
    ):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        if not use_mlp:
            # Depthwise conv of width 2: one weight for the mean and one for
            # the standard deviation of every channel.
            self.style_integration = nn.Conv1d(
                in_channels, in_channels, 2, groups=in_channels, bias=bias
            )
        else:
            n_styles = in_channels * 2
            hidden = n_styles // reduction_rate
            self.style_integration = nn.Sequential(
                Rearrange("b c n_metrics -> b (c n_metrics)"),
                nn.Linear(n_styles, hidden, bias=bias),
                nn.ReLU(),
                nn.Linear(hidden, in_channels, bias=bias),
                Rearrange("b c -> b c 1"),
            )
        self.bn = nn.BatchNorm1d(in_channels)

    def forward(self, x):
        """
        Apply the Style-based Recalibration Module to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input with channels in the second dimension.

        Returns
        -------
        Pytorch.Tensor
            ``x`` multiplied by the per-channel gate.
        """
        mean = self.gap(x).squeeze(-1)  # b x c x 1
        spread = x.std(dim=(-2, -1), keepdim=True).squeeze(-1)  # b x c x 1
        style = torch.cat([mean, spread], dim=2)  # b x c x 2
        integrated = self.bn(self.style_integration(style))  # b x c x 1
        gate = torch.sigmoid(integrated).unsqueeze(-1)
        return gate * x
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
class CBAM(nn.Module):
    """
    Convolutional Block Attention Module from [Woo2018]_.

    Channel attention (shared bottleneck applied to the average- and
    max-pooled descriptors) is followed by spatial attention (convolution
    over the channel-wise mean and max maps).

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    reduction_rate : int
        reduction ratio of the fully-connected layers
    kernel_size : int
        kernel size of the convolutional layer, must be odd

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CBAM
    >>> module = CBAM(in_channels=16, reduction_rate=4, kernel_size=3)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
       CBAM: Convolutional Block Attention Module. ECCV 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, kernel_size: int):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        hidden = in_channels // reduction_rate
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 1, bias=False),
        )
        if kernel_size % 2 != 1:
            raise ValueError("kernel size must be odd for same padding")
        half = kernel_size // 2
        # Two input maps (mean, max) -> one attention map along the last dim.
        self.conv = nn.Conv2d(2, 1, (1, kernel_size), padding=(0, half))

    def forward(self, x):
        """
        Apply the Convolutional Block Attention Module to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
            4-D input with channels in the second dimension.

        Returns
        -------
        Pytorch.Tensor
            The input after channel and then spatial recalibration.
        """
        # Channel attention: the same bottleneck scores both pooled descriptors.
        avg_score = self.fc(self.avg_pool(x))
        max_score = self.fc(self.max_pool(x))
        x = x * torch.sigmoid(avg_score + max_score)

        # Spatial attention computed on the channel-recalibrated features.
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        stacked = torch.cat([mean_map, max_map], dim=1)
        return x * torch.sigmoid(self.conv(stacked))
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
class CAT(nn.Module):
|
|
654
|
+
"""
|
|
655
|
+
Attention Mechanism from [Wu2023]_.
|
|
656
|
+
|
|
657
|
+
Parameters
|
|
658
|
+
----------
|
|
659
|
+
in_channels : int
|
|
660
|
+
number of input feature channels
|
|
661
|
+
reduction_rate : int
|
|
662
|
+
reduction ratio of the fully-connected layers
|
|
663
|
+
kernel_size : int
|
|
664
|
+
kernel size of the convolutional layer
|
|
665
|
+
bias : bool, default=False
|
|
666
|
+
if True, adds a learnable bias will be used in the convolution,
|
|
667
|
+
|
|
668
|
+
Examples
|
|
669
|
+
--------
|
|
670
|
+
>>> import torch
|
|
671
|
+
>>> from braindecode.modules import CAT
|
|
672
|
+
>>> module = CAT(in_channels=16, reduction_rate=4, kernel_size=3)
|
|
673
|
+
>>> inputs = torch.randn(2, 16, 1, 64)
|
|
674
|
+
>>> outputs = module(inputs)
|
|
675
|
+
>>> outputs.shape
|
|
676
|
+
torch.Size([2, 16, 1, 64])
|
|
677
|
+
|
|
678
|
+
References
|
|
679
|
+
----------
|
|
680
|
+
.. [Wu2023] Wu, Z. et al., 2023
|
|
681
|
+
CAT: Learning to Collaborate Channel and Spatial Attention from
|
|
682
|
+
Multi-Information Fusion. IET Computer Vision 2023.
|
|
683
|
+
"""
|
|
684
|
+
|
|
685
|
+
def __init__(
|
|
686
|
+
self, in_channels: int, reduction_rate: int, kernel_size: int, bias=False
|
|
687
|
+
):
|
|
688
|
+
super(CAT, self).__init__()
|
|
689
|
+
self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
|
|
690
|
+
self.gauss_filter.weight = nn.Parameter(
|
|
691
|
+
_get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
|
|
692
|
+
)
|
|
693
|
+
self.mlp = nn.Sequential(
|
|
694
|
+
nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=bias),
|
|
695
|
+
nn.ReLU(),
|
|
696
|
+
nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=bias),
|
|
697
|
+
)
|
|
698
|
+
self.conv = nn.Conv2d(
|
|
699
|
+
in_channels,
|
|
700
|
+
in_channels,
|
|
701
|
+
kernel_size=(1, kernel_size),
|
|
702
|
+
padding=(0, kernel_size // 2),
|
|
703
|
+
bias=bias,
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
self.c_alpha = nn.Parameter(torch.zeros(1))
|
|
707
|
+
self.c_beta = nn.Parameter(torch.zeros(1))
|
|
708
|
+
self.c_gamma = nn.Parameter(torch.zeros(1))
|
|
709
|
+
self.s_alpha = nn.Parameter(torch.zeros(1))
|
|
710
|
+
self.s_beta = nn.Parameter(torch.zeros(1))
|
|
711
|
+
self.s_gamma = nn.Parameter(torch.zeros(1))
|
|
712
|
+
self.c_w = nn.Parameter(torch.zeros(1))
|
|
713
|
+
self.s_w = nn.Parameter(torch.zeros(1))
|
|
714
|
+
|
|
715
|
+
def forward(self, x):
    """
    Apply the CAT block to the input tensor.

    Parameters
    ----------
    x : torch.Tensor
        4D input unpacked as ``(b, c, h, w)``. The blur filter is a
        single-input-channel convolution applied after swapping axes 1
        and 2, so ``h`` has to be 1.

    Returns
    -------
    torch.Tensor
        ``x`` multiplied element-wise by the fused channel/spatial
        attention map; same shape as ``x``.
    """
    b, c, h, w = x.shape
    x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)

    # ---- Channel branch: average, (blurred) max and entropy pooling ----
    c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
    c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
    # log_softmax stays finite where softmax underflows to 0, which avoids
    # the ``0 * log(0) = NaN`` of the naive entropy formula.
    log_pi = torch.log_softmax(x, dim=-1)
    c_gep = -(log_pi.exp() * log_pi).sum(dim=(-2, -1), keepdim=True)
    # Min-max normalise the entropies across channels for each sample.
    c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
    c_span = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True) - c_gep_min
    # A zero span means every entropy is equal; dividing by 1 then yields 0
    # instead of the NaN produced by 0 / 0.
    c_span = torch.where(c_span > 0, c_span, torch.ones_like(c_span))
    c_gep = self.mlp((c_gep - c_gep_min) / c_span)
    channel_score = torch.sigmoid(
        c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
    )
    channel_score = channel_score.expand(b, c, h, w)

    # ---- Spatial branch ----
    s_gap = x.mean(dim=1, keepdim=True)
    s_gmp = torch.amax(x_blurred, dim=(-2, -1), keepdim=True)
    log_pi = torch.log_softmax(x, dim=1)
    s_gep = -(log_pi.exp() * log_pi).sum(dim=1, keepdim=True)
    # Min-max normalise the entropy map over the (h, w) positions.
    s_gep_min = torch.amin(s_gep, dim=(-2, -1), keepdim=True)
    s_span = torch.amax(s_gep, dim=(-2, -1), keepdim=True) - s_gep_min
    s_span = torch.where(s_span > 0, s_span, torch.ones_like(s_span))
    s_gep = (s_gep - s_gep_min) / s_span
    spatial_score = (
        -s_gap * self.s_alpha + s_gmp * self.s_beta + s_gep * self.s_gamma
    )
    spatial_score = torch.sigmoid(self.conv(spatial_score)).expand(b, c, h, w)

    # ---- Fuse both branches with softmax-normalised weights ----
    # softmax is mathematically equal to exp(a) / (exp(a) + exp(b)) but does
    # not overflow to inf / inf for large weights.
    fusion = torch.softmax(torch.stack((self.c_w, self.s_w)), dim=0)
    c_w, s_w = fusion[0], fusion[1]

    scale = channel_score * c_w + spatial_score * s_w
    return scale * x
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
class CATLite(nn.Module):
    """
    Modification of CAT without the convolutional layer from [Wu2023]_.

    Only the channel-attention branch is kept: average, blurred-max and
    entropy descriptors are passed through a shared MLP, mixed with three
    learnable scalars and squashed by a sigmoid to rescale the input.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    reduction_rate : int
        reduction ratio of the fully-connected layers
    bias : bool, default=True
        if True, a learnable bias is used in the 1x1 convolutions of the MLP.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import CATLite
    >>> module = CATLite(in_channels=16, reduction_rate=4)
    >>> inputs = torch.randn(2, 16, 1, 64)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 16, 1, 64])

    References
    ----------
    .. [Wu2023] Wu, Z. et al., 2023 CAT: Learning to Collaborate Channel and
        Spatial Attention from Multi-Information Fusion. IET Computer Vision 2023.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
        super(CATLite, self).__init__()
        # Fixed 5-tap smoothing filter along the last axis; the kernel is
        # frozen (requires_grad=False) so it is never trained.
        self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
        self.gauss_filter.weight = nn.Parameter(
            _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
        )
        # Bottleneck MLP (1x1 convolutions) shared by all pooled descriptors.
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels // reduction_rate), 1, bias=bias),
            nn.ReLU(),
            nn.Conv2d(int(in_channels // reduction_rate), in_channels, 1, bias=bias),
        )

        # Mixing coefficients of the three descriptors, initialised to zero.
        self.c_alpha = nn.Parameter(torch.zeros(1))
        self.c_beta = nn.Parameter(torch.zeros(1))
        self.c_gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
        Apply the CATLite block to the input tensor.

        Parameters
        ----------
        x : torch.Tensor
            4D input unpacked as ``(b, c, h, w)``. The blur filter is a
            single-input-channel convolution applied after swapping axes 1
            and 2, so ``h`` has to be 1.

        Returns
        -------
        torch.Tensor
            ``x`` multiplied by the per-channel attention score; same shape
            as ``x``.
        """
        b, c, h, w = x.shape
        x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)

        c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
        c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
        # log_softmax stays finite where softmax underflows to 0, which
        # avoids the ``0 * log(0) = NaN`` of the naive entropy formula.
        log_pi = torch.log_softmax(x, dim=-1)
        c_gep = -(log_pi.exp() * log_pi).sum(dim=(-2, -1), keepdim=True)
        # Min-max normalise the entropies across channels for each sample.
        c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_span = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True) - c_gep_min
        # A zero span means every entropy is equal; dividing by 1 then
        # yields 0 instead of the NaN produced by 0 / 0.
        c_span = torch.where(c_span > 0, c_span, torch.ones_like(c_span))
        c_gep = self.mlp((c_gep - c_gep_min) / c_span)
        channel_score = torch.sigmoid(
            c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
        )
        channel_score = channel_score.expand(b, c, h, w)

        return channel_score * x
|
|
834
|
+
|
|
835
|
+
|
|
836
|
+
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention block.

    Parameters
    ----------
    emb_size : int
        Size of the last (embedding) axis of the input; it must be
        divisible by ``num_heads``.
    num_heads : int
        Number of attention heads the embedding is split into.
    dropout : float
        Dropout probability applied to the attention weights.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive or does not divide ``emb_size``.

    Examples
    --------
    >>> import torch
    >>> from braindecode.modules import MultiHeadAttention
    >>> module = MultiHeadAttention(emb_size=32, num_heads=4, dropout=0.1)
    >>> inputs = torch.randn(2, 10, 32)
    >>> outputs = module(inputs)
    >>> outputs.shape
    torch.Size([2, 10, 32])
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        # Fail early with a clear message instead of letting the head
        # split fail inside the rearrange layer during the forward pass.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}.")
        if emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by "
                f"num_heads ({num_heads})."
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

        # (batch, tokens, heads * dim) <-> (batch, heads, tokens, dim)
        self.rearrange_stack = Rearrange(
            "b n (h d) -> b h n d",
            h=num_heads,
        )
        self.rearrange_unstack = Rearrange(
            "b h n d -> b n (h d)",
        )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Compute self-attention over the token axis of ``x``.

        Parameters
        ----------
        x : Tensor
            Input of shape ``(batch, tokens, emb_size)``.
        mask : Tensor, optional
            Boolean tensor broadcastable to ``(batch, heads, tokens, tokens)``.
            Positions where it is ``False`` are filled with ``-inf`` before
            the softmax, i.e. they receive zero attention.

        Returns
        -------
        Tensor
            Tensor with the same shape as ``x``.
        """
        queries = self.rearrange_stack(self.queries(x))
        keys = self.rearrange_stack(self.keys(x))
        values = self.rearrange_stack(self.values(x))
        # Pairwise query/key dot products for every head.
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
        if mask is not None:
            fill_value = float("-inf")
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE: the scale uses the full embedding size rather than the
        # per-head size; presumably kept for compatibility with existing
        # models, so it is left unchanged here.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        # Weighted sum of the values, then merge the heads back together.
        out = torch.einsum("bhal, bhlv -> bhav ", att, values)
        out = self.rearrange_unstack(out)
        out = self.projection(out)
        return out
|