braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
braindecode/modules/attention.py (new file)

@@ -0,0 +1,757 @@
"""
Attention modules used in the AttentionBaseNet from Martin Wimpff (2023).

Here, we implement some popular attention modules that can be used in the
AttentionBaseNet class.

"""

# Authors: Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
#          Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD (3-clause)

import math
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import Tensor, nn

from braindecode.functional import _get_gaussian_kernel1d

class SqueezeAndExcitation(nn.Module):
    """Squeeze-and-Excitation Networks from [Hu2018]_.

    Parameters
    ----------
    in_channels : int,
        number of input feature channels.
    reduction_rate : int,
        reduction ratio of the fully-connected layers.
    bias: bool, default=False
        if True, adds a learnable bias will be used in the convolution.

    References
    ----------
    .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
       Squeeze-and-Excitation Networks. CVPR 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
        super(SqueezeAndExcitation, self).__init__()
        sq_channels = int(in_channels // reduction_rate)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(
            in_channels=in_channels, out_channels=sq_channels, kernel_size=1, bias=bias
        )
        self.nonlinearity = nn.ReLU()
        self.fc2 = nn.Conv2d(
            in_channels=reduction_rate,
            out_channels=in_channels,
            kernel_size=1,
            bias=bias,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Apply the Squeeze-and-Excitation block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        scale*x: Pytorch.Tensor
        """
        scale = self.gap(x)
        scale = self.fc1(scale)
        scale = self.nonlinearity(scale)
        scale = self.fc2(scale)
        scale = self.sigmoid(scale)
        return scale * x

class GSoP(nn.Module):
    """
    Global Second-order Pooling Convolutional Networks from [Gao2018]_.

    Parameters
    ----------
    in_channels : int,
        number of input feature channels
    reduction_rate : int,
        reduction ratio of the fully-connected layers
    bias: bool, default=False
        if True, adds a learnable bias will be used in the convolution.

    References
    ----------
    .. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
       Global Second-order Pooling Convolutional Networks. CVPR 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
        super(GSoP, self).__init__()
        sq_channels = int(in_channels // reduction_rate)
        self.pw_conv1 = nn.Conv2d(in_channels, sq_channels, 1, bias=bias)
        self.bn = nn.BatchNorm2d(sq_channels)
        self.rw_conv = nn.Conv2d(
            sq_channels,
            sq_channels * 4,
            (sq_channels, 1),
            groups=sq_channels,
            bias=bias,
        )
        self.pw_conv2 = nn.Conv2d(sq_channels * 4, in_channels, 1, bias=bias)

    def forward(self, x):
        """
        Apply the Global Second-order Pooling Convolutional Networks block.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        scale = self.pw_conv1(x).squeeze(-2)  # b x c x t
        scale_zero_mean = scale - scale.mean(-1, keepdim=True)
        t = scale_zero_mean.shape[-1]
        cov = torch.bmm(scale_zero_mean, scale_zero_mean.transpose(1, 2)) / (t - 1)
        cov = cov.unsqueeze(-1)  # b x c x c x 1
        cov = self.bn(cov)
        scale = self.rw_conv(cov)  # b x c x 1 x 1
        scale = self.pw_conv2(scale)
        return scale * x

class FCA(nn.Module):
    """
    Frequency Channel Attention Networks from [Qin2021]_.

    Parameters
    ----------
    in_channels : int
        Number of input feature channels
    seq_len : int
        Sequence length along temporal dimension, default=62
    reduction_rate : int, default=4
        Reduction ratio of the fully-connected layers.

    References
    ----------
    .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
       FcaNet: Frequency Channel Attention Networks. ICCV 2021.
    """

    def __init__(
        self, in_channels, seq_len: int = 62, reduction_rate: int = 4, freq_idx: int = 0
    ):
        super(FCA, self).__init__()
        mapper_y = [freq_idx]
        assert in_channels % len(mapper_y) == 0

        self.weight = nn.Parameter(
            self.get_dct_filter(seq_len, mapper_y, in_channels), requires_grad=False
        )
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction_rate, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction_rate, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """
        Apply the Frequency Channel Attention Networks block to the input.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        scale = x.squeeze(-2) * self.weight
        scale = torch.sum(scale, dim=-1)
        scale = rearrange(self.fc(scale), "b c -> b c 1 1")
        return x * scale.expand_as(x)

    @staticmethod
    def get_dct_filter(seq_len: int, mapper_y: list, in_channels: int):
        """
        Util function to get the DCT filter.

        Parameters
        ----------
        seq_len: int
            Size of the sequence
        mapper_y:
            List of frequencies
        in_channels:
            Number of input channels.

        Returns
        -------
        torch.Tensor
        """
        dct_filter = torch.zeros(in_channels, seq_len)

        c_part = in_channels // len(mapper_y)

        for i, v_y in enumerate(mapper_y):
            for t_y in range(seq_len):
                filter = math.cos(math.pi * v_y * (t_y + 0.5) / seq_len) / math.sqrt(
                    seq_len
                )
                filter = filter * math.sqrt(2) if v_y != 0 else filter
                dct_filter[i * c_part : (i + 1) * c_part, t_y] = filter
        return dct_filter

class EncNet(nn.Module):
    """
    Context Encoding for Semantic Segmentation from [Zhang2018]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    n_codewords : int
        number of codewords

    References
    ----------
    .. [Zhang2018] Zhang, H. et al. 2018.
       Context Encoding for Semantic Segmentation. CVPR 2018.
    """

    def __init__(self, in_channels: int, n_codewords: int):
        super(EncNet, self).__init__()
        self.n_codewords = n_codewords
        self.codewords = nn.Parameter(torch.empty(n_codewords, in_channels))
        self.smoothing = nn.Parameter(torch.empty(n_codewords))
        std = 1 / ((n_codewords * in_channels) ** (1 / 2))
        nn.init.uniform_(self.codewords.data, -std, std)
        nn.init.uniform_(self.smoothing, -1, 0)
        self.bn = nn.BatchNorm1d(n_codewords)
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        """
        Apply attention from the Context Encoding for Semantic Segmentation.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        b, c, _, seq = x.shape
        # b x c x 1 x t -> b x t x k x c
        x_ = rearrange(x, pattern="b c 1 seq -> b seq 1 c")
        x_ = x_.expand(b, seq, self.n_codewords, c)
        cw_ = self.codewords.unsqueeze(0).unsqueeze(0)  # 1 x 1 x k x c
        a = self.smoothing.unsqueeze(0).unsqueeze(0) * (x_ - cw_).pow(2).sum(3)
        a = torch.softmax(a, dim=2)  # b x t x k

        # aggregate
        e = (a.unsqueeze(3) * (x_ - cw_)).sum(1)  # b x k x c
        e_norm = torch.relu(self.bn(e)).mean(1)  # b x c

        scale = torch.sigmoid(self.fc(e_norm))
        return x * scale.unsqueeze(2).unsqueeze(3)

class ECA(nn.Module):
    """
    Efficient Channel Attention [Wang2021]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    kernel_size : int
        kernel size of convolutional layer, determines degree of channel
        interaction, must be odd.

    References
    ----------
    .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
       for Deep Convolutional Neural Networks. CVPR 2021.
    """

    def __init__(self, in_channels: int, kernel_size: int):
        super(ECA, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        assert kernel_size % 2 == 1, "kernel size must be odd for same padding"
        self.conv = nn.Conv1d(
            1, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False
        )

    def forward(self, x):
        """
        Apply the Efficient Channel Attention block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        scale = self.gap(x)
        scale = rearrange(scale, "b c 1 1 -> b 1 c")
        scale = self.conv(scale)
        scale = torch.sigmoid(rearrange(scale, "b 1 c -> b c 1 1"))
        return x * scale

class GatherExcite(nn.Module):
    """
    Gather-Excite Networks from [Hu2018b]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    seq_len : int, default=62
        sequence length along temporal dimension
    extra_params : bool, default=False
        whether to use a convolutional layer as a gather module
    use_mlp : bool, default=False
        whether to use an excite block with fully-connected layers
    reduction_rate : int, default=4
        reduction ratio of the excite block (if used)

    References
    ----------
    .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
       Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks.
       NeurIPS 2018.
    """

    def __init__(
        self,
        in_channels: int,
        seq_len: int = 62,
        extra_params: bool = False,
        use_mlp: bool = False,
        reduction_rate: int = 4,
    ):
        super(GatherExcite, self).__init__()
        if extra_params:
            self.gather = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    in_channels,
                    (1, seq_len),
                    groups=in_channels,
                    bias=False,
                ),
                nn.BatchNorm2d(in_channels),
            )
        else:
            self.gather = nn.AdaptiveAvgPool2d(1)

        if use_mlp:
            self.mlp = nn.Sequential(
                nn.Conv2d(
                    in_channels, int(in_channels // reduction_rate), 1, bias=False
                ),
                nn.ReLU(),
                nn.Conv2d(
                    int(in_channels // reduction_rate), in_channels, 1, bias=False
                ),
            )
        else:
            self.mlp = nn.Identity()

    def forward(self, x):
        """
        Apply the Gather-Excite Networks block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        scale = self.gather(x)
        scale = torch.sigmoid(self.mlp(scale))
        return scale * x

class GCT(nn.Module):
    """
    Gated Channel Transformation from [Yang2020]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels

    References
    ----------
    .. [Yang2020] Yang, Z. Linchao, Z., Wu, Y., Yang, Y., 2020.
       Gated Channel Transformation for Visual Recognition. CVPR 2020.
    """

    def __init__(self, in_channels: int):
        super(GCT, self).__init__()
        self.alpha = nn.Parameter(torch.ones(1, in_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, x, eps: float = 1e-5):
        """
        Apply the Gated Channel Transformation block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor
        eps: float, default=1e-5

        Returns
        -------
        Pytorch.Tensor
            the original tensor x multiplied by the gate.
        """
        embedding = (x.pow(2).sum((2, 3), keepdim=True) + eps).pow(0.5) * self.alpha
        norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + eps).pow(0.5)
        gate = 1.0 + torch.tanh(embedding * norm + self.beta)
        return x * gate

class SRM(nn.Module):
    """
    Attention module from [Lee2019]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    use_mlp : bool, default=False
        whether to use fully-connected layers instead of a convolutional layer,
    reduction_rate : int, default=4
        reduction ratio of the fully-connected layers (if used),

    References
    ----------
    .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
       Recalibration Module for Convolutional Neural Networks. ICCV 2019.
    """

    def __init__(
        self,
        in_channels: int,
        use_mlp: bool = False,
        reduction_rate: int = 4,
        bias: bool = False,
    ):
        super(SRM, self).__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        if use_mlp:
            self.style_integration = nn.Sequential(
                Rearrange("b c n_metrics -> b (c n_metrics)"),
                nn.Linear(
                    in_channels * 2, in_channels * 2 // reduction_rate, bias=bias
                ),
                nn.ReLU(),
                nn.Linear(in_channels * 2 // reduction_rate, in_channels, bias=bias),
                Rearrange("b c -> b c 1"),
            )
        else:
            self.style_integration = nn.Conv1d(
                in_channels, in_channels, 2, groups=in_channels, bias=bias
            )
        self.bn = nn.BatchNorm1d(in_channels)

    def forward(self, x):
        """
        Apply the Style-based Recalibration Module to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        mu = self.gap(x).squeeze(-1)  # b x c x 1
        std = x.std(dim=(-2, -1), keepdim=True).squeeze(-1)  # b x c x 1
        t = torch.cat([mu, std], dim=2)  # b x c x 2
        z = self.style_integration(t)  # b x c x 1
        z = self.bn(z)
        scale = nn.functional.sigmoid(z).unsqueeze(-1)
        return scale * x

class CBAM(nn.Module):
    """
    Convolutional Block Attention Module from [Woo2018]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    reduction_rate : int
        reduction ratio of the fully-connected layers
    kernel_size : int
        kernel size of the convolutional layer

    References
    ----------
    .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
       CBAM: Convolutional Block Attention Module. ECCV 2018.
    """

    def __init__(self, in_channels: int, reduction_rate: int, kernel_size: int):
        super(CBAM, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=False),
        )
        assert kernel_size % 2 == 1, "kernel size must be odd for same padding"
        self.conv = nn.Conv2d(2, 1, (1, kernel_size), padding=(0, kernel_size // 2))

    def forward(self, x):
        """
        Apply the Convolutional Block Attention Module to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        channel_attention = torch.sigmoid(
            self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
        )
        x = x * channel_attention
        spat_input = torch.cat(
            [torch.mean(x, dim=1, keepdim=True), torch.max(x, dim=1, keepdim=True)[0]],
            dim=1,
        )
        spatial_attention = torch.sigmoid(self.conv(spat_input))
        return x * spatial_attention

class CAT(nn.Module):
    """
    Attention Mechanism from [Wu2023]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    reduction_rate : int
        reduction ratio of the fully-connected layers
    kernel_size : int
        kernel size of the convolutional layer
    bias : bool, default=False
        if True, adds a learnable bias will be used in the convolution,

    References
    ----------
    .. [Wu2023] Wu, Z. et al., 2023
       CAT: Learning to Collaborate Channel and Spatial Attention from
       Multi-Information Fusion. IET Computer Vision 2023.
    """

    def __init__(
        self, in_channels: int, reduction_rate: int, kernel_size: int, bias=False
    ):
        super(CAT, self).__init__()
        self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
        self.gauss_filter.weight = nn.Parameter(
            _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
        )
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=bias),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=bias),
        )
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=(1, kernel_size),
            padding=(0, kernel_size // 2),
            bias=bias,
        )

        self.c_alpha = nn.Parameter(torch.zeros(1))
        self.c_beta = nn.Parameter(torch.zeros(1))
        self.c_gamma = nn.Parameter(torch.zeros(1))
        self.s_alpha = nn.Parameter(torch.zeros(1))
        self.s_beta = nn.Parameter(torch.zeros(1))
        self.s_gamma = nn.Parameter(torch.zeros(1))
        self.c_w = nn.Parameter(torch.zeros(1))
        self.s_w = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
        Apply the CAT block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        b, c, h, w = x.shape
        x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)

        c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
        c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
        pi = torch.softmax(x, dim=-1)
        c_gep = -1 * (pi * torch.log(pi)).sum(dim=(-2, -1), keepdim=True)
        c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep_max = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep = self.mlp((c_gep - c_gep_min) / (c_gep_max - c_gep_min))
        channel_score = torch.sigmoid(
            c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
        )
        channel_score = channel_score.expand(b, c, h, w)

        s_gap = x.mean(dim=1, keepdim=True)
        s_gmp = torch.amax(x_blurred, dim=(-2, -1), keepdim=True)
        pi = torch.softmax(x, dim=1)
        s_gep = -1 * (pi * torch.log(pi)).sum(dim=1, keepdim=True)
        s_gep_min = torch.amin(s_gep, dim=(-2, -1), keepdim=True)
        s_gep_max = torch.amax(s_gep, dim=(-2, -1), keepdim=True)
        s_gep = (s_gep - s_gep_min) / (s_gep_max - s_gep_min)
        spatial_score = (
            -s_gap * self.s_alpha + s_gmp * self.s_beta + s_gep * self.s_gamma
        )
        spatial_score = torch.sigmoid(self.conv(spatial_score)).expand(b, c, h, w)

        c_w = torch.exp(self.c_w) / (torch.exp(self.c_w) + torch.exp(self.s_w))
        s_w = torch.exp(self.s_w) / (torch.exp(self.c_w) + torch.exp(self.s_w))

        scale = channel_score * c_w + spatial_score * s_w
        return scale * x

class CATLite(nn.Module):
    """
    Modification of CAT without the convolutional layer from [Wu2023]_.

    Parameters
    ----------
    in_channels : int
        number of input feature channels
    reduction_rate : int
        reduction ratio of the fully-connected layers
    bias : bool, default=True
        if True, adds a learnable bias will be used in the convolution,

    References
    ----------
    .. [Wu2023] Wu, Z. et al., 2023 CAT: Learning to Collaborate Channel and
       Spatial Attention from Multi-Information Fusion. IET Computer Vision 2023.
    """

    def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
        super(CATLite, self).__init__()
        self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
        self.gauss_filter.weight = nn.Parameter(
            _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
        )
        self.mlp = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels // reduction_rate), 1, bias=bias),
            nn.ReLU(),
            nn.Conv2d(int(in_channels // reduction_rate), in_channels, 1, bias=bias),
        )

        self.c_alpha = nn.Parameter(torch.zeros(1))
        self.c_beta = nn.Parameter(torch.zeros(1))
        self.c_gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
        Apply the CATLite block to the input tensor.

        Parameters
        ----------
        x: Pytorch.Tensor

        Returns
        -------
        Pytorch.Tensor
        """
        b, c, h, w = x.shape
        x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)

        c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
        c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
        pi = torch.softmax(x, dim=-1)
        c_gep = -1 * (pi * torch.log(pi)).sum(dim=(-2, -1), keepdim=True)
        c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep_max = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True)
        c_gep = self.mlp((c_gep - c_gep_min) / (c_gep_max - c_gep_min))
        channel_score = torch.sigmoid(
            c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
        )
        channel_score = channel_score.expand(b, c, h, w)

        return channel_score * x

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

        self.rearrange_stack = Rearrange(
            "b n (h d) -> b h n d",
            h=num_heads,
        )
        self.rearrange_unstack = Rearrange(
            "b h n d -> b n (h d)",
        )

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        queries = self.rearrange_stack(self.queries(x))
        keys = self.rearrange_stack(self.keys(x))
        values = self.rearrange_stack(self.values(x))
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
        if mask is not None:
            fill_value = float("-inf")
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum("bhal, bhlv -> bhav ", att, values)
        out = self.rearrange_unstack(out)
        out = self.projection(out)
        return out
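
The attention blocks added in this file can be exercised standalone. The following is a minimal sketch, not part of the diff, assuming braindecode 1.1.0 is installed; the tensor sizes are illustrative and follow the b x c x 1 x t layout referenced in the forward methods above.

# Illustrative smoke test of three of the new attention blocks (not part of the diff).
import torch

from braindecode.modules.attention import CBAM, ECA, GCT

x = torch.randn(8, 16, 1, 62)  # batch x channels x 1 x time

blocks = [
    ECA(in_channels=16, kernel_size=3),
    CBAM(in_channels=16, reduction_rate=4, kernel_size=7),
    GCT(in_channels=16),
]
for block in blocks:
    out = block(x)
    assert out.shape == x.shape  # each block rescales x, so the shape is preserved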