braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
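Most of the churn in this release comes from splitting reusable building blocks out of the model code: 1.0.0 adds braindecode.functional (entries 23-25) and braindecode.modules (entries 67-78) next to the expanded braindecode.models and the new braindecode.augmentation package. Below is a minimal import sketch for the new layout, based only on the file paths listed above and on the attention.py diff shown after this list; the names each __init__.py actually re-exports are not visible in this diff.

    # Sketch only: module paths come from the 1.0.0 file list, class names from
    # the attention.py diff below; shorter re-export paths are not verified here.
    from braindecode.modules.attention import CBAM, SqueezeAndExcitation
    from braindecode.functional import _get_gaussian_kernel1d  # helper imported by attention.py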
braindecode/modules/attention.py (added in 1.0.0)
@@ -0,0 +1,757 @@
+ """
+ Attention modules used in the AttentionBaseNet from Martin Wimpff (2023).
+
+ Here, we implement some popular attention modules that can be used in the
+ AttentionBaseNet class.
+
+ """
+
+ # Authors: Martin Wimpff <martin.wimpff@iss.uni-stuttgart.de>
+ #          Bruno Aristimunha <b.aristimunha@gmail.com>
+ #
+ # License: BSD (3-clause)
+
+ import math
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from einops.layers.torch import Rearrange
+ from torch import Tensor, nn
+
+ from braindecode.functional import _get_gaussian_kernel1d
+
+
+ class SqueezeAndExcitation(nn.Module):
+     """Squeeze-and-Excitation Networks from [Hu2018]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels.
+     reduction_rate : int
+         reduction ratio of the fully-connected layers.
+     bias : bool, default=False
+         if True, a learnable bias is used in the convolutions.
+
+     References
+     ----------
+     .. [Hu2018] Hu, J., Albanie, S., Sun, G., Wu, E., 2018.
+        Squeeze-and-Excitation Networks. CVPR 2018.
+     """
+
+     def __init__(self, in_channels: int, reduction_rate: int, bias: bool = False):
+         super(SqueezeAndExcitation, self).__init__()
+         sq_channels = int(in_channels // reduction_rate)
+         self.gap = nn.AdaptiveAvgPool2d(1)
+         self.fc1 = nn.Conv2d(
+             in_channels=in_channels, out_channels=sq_channels, kernel_size=1, bias=bias
+         )
+         self.nonlinearity = nn.ReLU()
+         self.fc2 = nn.Conv2d(
+             in_channels=sq_channels,
+             out_channels=in_channels,
+             kernel_size=1,
+             bias=bias,
+         )
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         """
+         Apply the Squeeze-and-Excitation block to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         scale * x : torch.Tensor
+         """
+         scale = self.gap(x)
+         scale = self.fc1(scale)
+         scale = self.nonlinearity(scale)
+         scale = self.fc2(scale)
+         scale = self.sigmoid(scale)
+         return scale * x
+
+
+ class GSoP(nn.Module):
+     """
+     Global Second-order Pooling Convolutional Networks from [Gao2018]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     reduction_rate : int
+         reduction ratio of the fully-connected layers
+     bias : bool, default=True
+         if True, a learnable bias is used in the convolutions.
+
+     References
+     ----------
+     .. [Gao2018] Gao, Z., Jiangtao, X., Wang, Q., Li, P., 2018.
+        Global Second-order Pooling Convolutional Networks. CVPR 2018.
+     """
+
+     def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
+         super(GSoP, self).__init__()
+         sq_channels = int(in_channels // reduction_rate)
+         self.pw_conv1 = nn.Conv2d(in_channels, sq_channels, 1, bias=bias)
+         self.bn = nn.BatchNorm2d(sq_channels)
+         self.rw_conv = nn.Conv2d(
+             sq_channels,
+             sq_channels * 4,
+             (sq_channels, 1),
+             groups=sq_channels,
+             bias=bias,
+         )
+         self.pw_conv2 = nn.Conv2d(sq_channels * 4, in_channels, 1, bias=bias)
+
+     def forward(self, x):
+         """
+         Apply the Global Second-order Pooling Convolutional Networks block.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         scale = self.pw_conv1(x).squeeze(-2)  # b x c x t
+         scale_zero_mean = scale - scale.mean(-1, keepdim=True)
+         t = scale_zero_mean.shape[-1]
+         cov = torch.bmm(scale_zero_mean, scale_zero_mean.transpose(1, 2)) / (t - 1)
+         cov = cov.unsqueeze(-1)  # b x c x c x 1
+         cov = self.bn(cov)
+         scale = self.rw_conv(cov)  # b x c x 1 x 1
+         scale = self.pw_conv2(scale)
+         return scale * x
+
+
+ class FCA(nn.Module):
+     """
+     Frequency Channel Attention Networks from [Qin2021]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         Number of input feature channels
+     seq_len : int, default=62
+         Sequence length along the temporal dimension
+     reduction_rate : int, default=4
+         Reduction ratio of the fully-connected layers.
+
+     References
+     ----------
+     .. [Qin2021] Qin, Z., Zhang, P., Wu, F., Li, X., 2021.
+        FcaNet: Frequency Channel Attention Networks. ICCV 2021.
+     """
+
+     def __init__(
+         self, in_channels, seq_len: int = 62, reduction_rate: int = 4, freq_idx: int = 0
+     ):
+         super(FCA, self).__init__()
+         mapper_y = [freq_idx]
+         assert in_channels % len(mapper_y) == 0
+
+         self.weight = nn.Parameter(
+             self.get_dct_filter(seq_len, mapper_y, in_channels), requires_grad=False
+         )
+         self.fc = nn.Sequential(
+             nn.Linear(in_channels, in_channels // reduction_rate, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(in_channels // reduction_rate, in_channels, bias=False),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         """
+         Apply the Frequency Channel Attention Networks block to the input.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         scale = x.squeeze(-2) * self.weight
+         scale = torch.sum(scale, dim=-1)
+         scale = rearrange(self.fc(scale), "b c -> b c 1 1")
+         return x * scale.expand_as(x)
+
+     @staticmethod
+     def get_dct_filter(seq_len: int, mapper_y: list, in_channels: int):
+         """
+         Util function to get the DCT filter.
+
+         Parameters
+         ----------
+         seq_len : int
+             Size of the sequence
+         mapper_y : list
+             List of frequencies
+         in_channels : int
+             Number of input channels.
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         dct_filter = torch.zeros(in_channels, seq_len)
+
+         c_part = in_channels // len(mapper_y)
+
+         for i, v_y in enumerate(mapper_y):
+             for t_y in range(seq_len):
+                 filter = math.cos(math.pi * v_y * (t_y + 0.5) / seq_len) / math.sqrt(
+                     seq_len
+                 )
+                 filter = filter * math.sqrt(2) if v_y != 0 else filter
+                 dct_filter[i * c_part : (i + 1) * c_part, t_y] = filter
+         return dct_filter
+
+
+ class EncNet(nn.Module):
+     """
+     Context Encoding for Semantic Segmentation from [Zhang2018]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     n_codewords : int
+         number of codewords
+
+     References
+     ----------
+     .. [Zhang2018] Zhang, H. et al., 2018.
+        Context Encoding for Semantic Segmentation. CVPR 2018.
+     """
+
+     def __init__(self, in_channels: int, n_codewords: int):
+         super(EncNet, self).__init__()
+         self.n_codewords = n_codewords
+         self.codewords = nn.Parameter(torch.empty(n_codewords, in_channels))
+         self.smoothing = nn.Parameter(torch.empty(n_codewords))
+         std = 1 / ((n_codewords * in_channels) ** (1 / 2))
+         nn.init.uniform_(self.codewords.data, -std, std)
+         nn.init.uniform_(self.smoothing, -1, 0)
+         self.bn = nn.BatchNorm1d(n_codewords)
+         self.fc = nn.Linear(in_channels, in_channels)
+
+     def forward(self, x):
+         """
+         Apply attention from the Context Encoding for Semantic Segmentation.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         b, c, _, seq = x.shape
+         # b x c x 1 x t -> b x t x k x c
+         x_ = rearrange(x, pattern="b c 1 seq -> b seq 1 c")
+         x_ = x_.expand(b, seq, self.n_codewords, c)
+         cw_ = self.codewords.unsqueeze(0).unsqueeze(0)  # 1 x 1 x k x c
+         a = self.smoothing.unsqueeze(0).unsqueeze(0) * (x_ - cw_).pow(2).sum(3)
+         a = torch.softmax(a, dim=2)  # b x t x k
+
+         # aggregate
+         e = (a.unsqueeze(3) * (x_ - cw_)).sum(1)  # b x k x c
+         e_norm = torch.relu(self.bn(e)).mean(1)  # b x c
+
+         scale = torch.sigmoid(self.fc(e_norm))
+         return x * scale.unsqueeze(2).unsqueeze(3)
+
+
+ class ECA(nn.Module):
+     """
+     Efficient Channel Attention [Wang2021]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     kernel_size : int
+         kernel size of the convolutional layer, determines the degree of channel
+         interaction, must be odd.
+
+     References
+     ----------
+     .. [Wang2021] Wang, Q. et al., 2021. ECA-Net: Efficient Channel Attention
+        for Deep Convolutional Neural Networks. CVPR 2021.
+     """
+
+     def __init__(self, in_channels: int, kernel_size: int):
+         super(ECA, self).__init__()
+         self.gap = nn.AdaptiveAvgPool2d(1)
+         assert kernel_size % 2 == 1, "kernel size must be odd for same padding"
+         self.conv = nn.Conv1d(
+             1, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False
+         )
+
+     def forward(self, x):
+         """
+         Apply the Efficient Channel Attention block to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         scale = self.gap(x)
+         scale = rearrange(scale, "b c 1 1 -> b 1 c")
+         scale = self.conv(scale)
+         scale = torch.sigmoid(rearrange(scale, "b 1 c -> b c 1 1"))
+         return x * scale
+
+
+ class GatherExcite(nn.Module):
+     """
+     Gather-Excite Networks from [Hu2018b]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     seq_len : int, default=62
+         sequence length along the temporal dimension
+     extra_params : bool, default=False
+         whether to use a convolutional layer as the gather module
+     use_mlp : bool, default=False
+         whether to use an excite block with fully-connected layers
+     reduction_rate : int, default=4
+         reduction ratio of the excite block (if used)
+
+     References
+     ----------
+     .. [Hu2018b] Hu, J., Albanie, S., Sun, G., Vedaldi, A., 2018.
+        Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks.
+        NeurIPS 2018.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         seq_len: int = 62,
+         extra_params: bool = False,
+         use_mlp: bool = False,
+         reduction_rate: int = 4,
+     ):
+         super(GatherExcite, self).__init__()
+         if extra_params:
+             self.gather = nn.Sequential(
+                 nn.Conv2d(
+                     in_channels,
+                     in_channels,
+                     (1, seq_len),
+                     groups=in_channels,
+                     bias=False,
+                 ),
+                 nn.BatchNorm2d(in_channels),
+             )
+         else:
+             self.gather = nn.AdaptiveAvgPool2d(1)
+
+         if use_mlp:
+             self.mlp = nn.Sequential(
+                 nn.Conv2d(
+                     in_channels, int(in_channels // reduction_rate), 1, bias=False
+                 ),
+                 nn.ReLU(),
+                 nn.Conv2d(
+                     int(in_channels // reduction_rate), in_channels, 1, bias=False
+                 ),
+             )
+         else:
+             self.mlp = nn.Identity()
+
+     def forward(self, x):
+         """
+         Apply the Gather-Excite Networks block to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         scale = self.gather(x)
+         scale = torch.sigmoid(self.mlp(scale))
+         return scale * x
+
+
+ class GCT(nn.Module):
+     """
+     Gated Channel Transformation from [Yang2020]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+
+     References
+     ----------
+     .. [Yang2020] Yang, Z., Linchao, Z., Wu, Y., Yang, Y., 2020.
+        Gated Channel Transformation for Visual Recognition. CVPR 2020.
+     """
+
+     def __init__(self, in_channels: int):
+         super(GCT, self).__init__()
+         self.alpha = nn.Parameter(torch.ones(1, in_channels, 1, 1))
+         self.beta = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+         self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+
+     def forward(self, x, eps: float = 1e-5):
+         """
+         Apply the Gated Channel Transformation block to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+         eps : float, default=1e-5
+
+         Returns
+         -------
+         torch.Tensor
+             the original tensor x multiplied by the gate.
+         """
+         embedding = (x.pow(2).sum((2, 3), keepdim=True) + eps).pow(0.5) * self.alpha
+         norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + eps).pow(0.5)
+         gate = 1.0 + torch.tanh(embedding * norm + self.beta)
+         return x * gate
+
+
+ class SRM(nn.Module):
+     """
+     Attention module from [Lee2019]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     use_mlp : bool, default=False
+         whether to use fully-connected layers instead of a convolutional layer
+     reduction_rate : int, default=4
+         reduction ratio of the fully-connected layers (if used)
+
+     References
+     ----------
+     .. [Lee2019] Lee, H., Kim, H., Nam, H., 2019. SRM: A Style-based
+        Recalibration Module for Convolutional Neural Networks. ICCV 2019.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         use_mlp: bool = False,
+         reduction_rate: int = 4,
+         bias: bool = False,
+     ):
+         super(SRM, self).__init__()
+         self.gap = nn.AdaptiveAvgPool2d(1)
+         if use_mlp:
+             self.style_integration = nn.Sequential(
+                 Rearrange("b c n_metrics -> b (c n_metrics)"),
+                 nn.Linear(
+                     in_channels * 2, in_channels * 2 // reduction_rate, bias=bias
+                 ),
+                 nn.ReLU(),
+                 nn.Linear(in_channels * 2 // reduction_rate, in_channels, bias=bias),
+                 Rearrange("b c -> b c 1"),
+             )
+         else:
+             self.style_integration = nn.Conv1d(
+                 in_channels, in_channels, 2, groups=in_channels, bias=bias
+             )
+         self.bn = nn.BatchNorm1d(in_channels)
+
+     def forward(self, x):
+         """
+         Apply the Style-based Recalibration Module to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         mu = self.gap(x).squeeze(-1)  # b x c x 1
+         std = x.std(dim=(-2, -1), keepdim=True).squeeze(-1)  # b x c x 1
+         t = torch.cat([mu, std], dim=2)  # b x c x 2
+         z = self.style_integration(t)  # b x c x 1
+         z = self.bn(z)
+         scale = torch.sigmoid(z).unsqueeze(-1)
+         return scale * x
+
+
+ class CBAM(nn.Module):
+     """
+     Convolutional Block Attention Module from [Woo2018]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     reduction_rate : int
+         reduction ratio of the fully-connected layers
+     kernel_size : int
+         kernel size of the convolutional layer
+
+     References
+     ----------
+     .. [Woo2018] Woo, S., Park, J., Lee, J., Kweon, I., 2018.
+        CBAM: Convolutional Block Attention Module. ECCV 2018.
+     """
+
+     def __init__(self, in_channels: int, reduction_rate: int, kernel_size: int):
+         super(CBAM, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.max_pool = nn.AdaptiveMaxPool2d(1)
+         self.fc = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=False),
+             nn.ReLU(),
+             nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=False),
+         )
+         assert kernel_size % 2 == 1, "kernel size must be odd for same padding"
+         self.conv = nn.Conv2d(2, 1, (1, kernel_size), padding=(0, kernel_size // 2))
+
+     def forward(self, x):
+         """
+         Apply the Convolutional Block Attention Module to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         channel_attention = torch.sigmoid(
+             self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x))
+         )
+         x = x * channel_attention
+         spat_input = torch.cat(
+             [torch.mean(x, dim=1, keepdim=True), torch.max(x, dim=1, keepdim=True)[0]],
+             dim=1,
+         )
+         spatial_attention = torch.sigmoid(self.conv(spat_input))
+         return x * spatial_attention
+
+
+ class CAT(nn.Module):
+     """
+     Attention Mechanism from [Wu2023]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     reduction_rate : int
+         reduction ratio of the fully-connected layers
+     kernel_size : int
+         kernel size of the convolutional layer
+     bias : bool, default=False
+         if True, a learnable bias is used in the convolutions.
+
+     References
+     ----------
+     .. [Wu2023] Wu, Z. et al., 2023.
+        CAT: Learning to Collaborate Channel and Spatial Attention from
+        Multi-Information Fusion. IET Computer Vision 2023.
+     """
+
+     def __init__(
+         self, in_channels: int, reduction_rate: int, kernel_size: int, bias=False
+     ):
+         super(CAT, self).__init__()
+         self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
+         self.gauss_filter.weight = nn.Parameter(
+             _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
+         )
+         self.mlp = nn.Sequential(
+             nn.Conv2d(in_channels, in_channels // reduction_rate, 1, bias=bias),
+             nn.ReLU(),
+             nn.Conv2d(in_channels // reduction_rate, in_channels, 1, bias=bias),
+         )
+         self.conv = nn.Conv2d(
+             in_channels,
+             in_channels,
+             kernel_size=(1, kernel_size),
+             padding=(0, kernel_size // 2),
+             bias=bias,
+         )
+
+         self.c_alpha = nn.Parameter(torch.zeros(1))
+         self.c_beta = nn.Parameter(torch.zeros(1))
+         self.c_gamma = nn.Parameter(torch.zeros(1))
+         self.s_alpha = nn.Parameter(torch.zeros(1))
+         self.s_beta = nn.Parameter(torch.zeros(1))
+         self.s_gamma = nn.Parameter(torch.zeros(1))
+         self.c_w = nn.Parameter(torch.zeros(1))
+         self.s_w = nn.Parameter(torch.zeros(1))
+
+     def forward(self, x):
+         """
+         Apply the CAT block to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         b, c, h, w = x.shape
+         x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)
+
+         c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
+         c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
+         pi = torch.softmax(x, dim=-1)
+         c_gep = -1 * (pi * torch.log(pi)).sum(dim=(-2, -1), keepdim=True)
+         c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
+         c_gep_max = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True)
+         c_gep = self.mlp((c_gep - c_gep_min) / (c_gep_max - c_gep_min))
+         channel_score = torch.sigmoid(
+             c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
+         )
+         channel_score = channel_score.expand(b, c, h, w)
+
+         s_gap = x.mean(dim=1, keepdim=True)
+         s_gmp = torch.amax(x_blurred, dim=(-2, -1), keepdim=True)
+         pi = torch.softmax(x, dim=1)
+         s_gep = -1 * (pi * torch.log(pi)).sum(dim=1, keepdim=True)
+         s_gep_min = torch.amin(s_gep, dim=(-2, -1), keepdim=True)
+         s_gep_max = torch.amax(s_gep, dim=(-2, -1), keepdim=True)
+         s_gep = (s_gep - s_gep_min) / (s_gep_max - s_gep_min)
+         spatial_score = (
+             -s_gap * self.s_alpha + s_gmp * self.s_beta + s_gep * self.s_gamma
+         )
+         spatial_score = torch.sigmoid(self.conv(spatial_score)).expand(b, c, h, w)
+
+         c_w = torch.exp(self.c_w) / (torch.exp(self.c_w) + torch.exp(self.s_w))
+         s_w = torch.exp(self.s_w) / (torch.exp(self.c_w) + torch.exp(self.s_w))
+
+         scale = channel_score * c_w + spatial_score * s_w
+         return scale * x
+
+
+ class CATLite(nn.Module):
+     """
+     Modification of CAT without the convolutional layer from [Wu2023]_.
+
+     Parameters
+     ----------
+     in_channels : int
+         number of input feature channels
+     reduction_rate : int
+         reduction ratio of the fully-connected layers
+     bias : bool, default=True
+         if True, a learnable bias is used in the convolutions.
+
+     References
+     ----------
+     .. [Wu2023] Wu, Z. et al., 2023. CAT: Learning to Collaborate Channel and
+        Spatial Attention from Multi-Information Fusion. IET Computer Vision 2023.
+     """
+
+     def __init__(self, in_channels: int, reduction_rate: int, bias: bool = True):
+         super(CATLite, self).__init__()
+         self.gauss_filter = nn.Conv2d(1, 1, (1, 5), padding=(0, 2), bias=False)
+         self.gauss_filter.weight = nn.Parameter(
+             _get_gaussian_kernel1d(5, 1.0)[None, None, None, :], requires_grad=False
+         )
+         self.mlp = nn.Sequential(
+             nn.Conv2d(in_channels, int(in_channels // reduction_rate), 1, bias=bias),
+             nn.ReLU(),
+             nn.Conv2d(int(in_channels // reduction_rate), in_channels, 1, bias=bias),
+         )
+
+         self.c_alpha = nn.Parameter(torch.zeros(1))
+         self.c_beta = nn.Parameter(torch.zeros(1))
+         self.c_gamma = nn.Parameter(torch.zeros(1))
+
+     def forward(self, x):
+         """
+         Apply the CATLite block to the input tensor.
+
+         Parameters
+         ----------
+         x : torch.Tensor
+
+         Returns
+         -------
+         torch.Tensor
+         """
+         b, c, h, w = x.shape
+         x_blurred = self.gauss_filter(x.transpose(1, 2)).transpose(1, 2)
+
+         c_gap = self.mlp(x.mean(dim=(-2, -1), keepdim=True))
+         c_gmp = self.mlp(torch.amax(x_blurred, dim=(-2, -1), keepdim=True))
+         pi = torch.softmax(x, dim=-1)
+         c_gep = -1 * (pi * torch.log(pi)).sum(dim=(-2, -1), keepdim=True)
+         c_gep_min = torch.amin(c_gep, dim=(-3, -2, -1), keepdim=True)
+         c_gep_max = torch.amax(c_gep, dim=(-3, -2, -1), keepdim=True)
+         c_gep = self.mlp((c_gep - c_gep_min) / (c_gep_max - c_gep_min))
+         channel_score = torch.sigmoid(
+             c_gap * self.c_alpha + c_gmp * self.c_beta + c_gep * self.c_gamma
+         )
+         channel_score = channel_score.expand(b, c, h, w)
+
+         return channel_score * x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, emb_size, num_heads, dropout):
+         super().__init__()
+         self.emb_size = emb_size
+         self.num_heads = num_heads
+         self.keys = nn.Linear(emb_size, emb_size)
+         self.queries = nn.Linear(emb_size, emb_size)
+         self.values = nn.Linear(emb_size, emb_size)
+         self.att_drop = nn.Dropout(dropout)
+         self.projection = nn.Linear(emb_size, emb_size)
+
+         self.rearrange_stack = Rearrange(
+             "b n (h d) -> b h n d",
+             h=num_heads,
+         )
+         self.rearrange_unstack = Rearrange(
+             "b h n d -> b n (h d)",
+         )
+
+     def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+         queries = self.rearrange_stack(self.queries(x))
+         keys = self.rearrange_stack(self.keys(x))
+         values = self.rearrange_stack(self.values(x))
+         energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
+         if mask is not None:
+             fill_value = float("-inf")
+             energy = energy.masked_fill(~mask, fill_value)
+
+         scaling = self.emb_size ** (1 / 2)
+         att = F.softmax(energy / scaling, dim=-1)
+         att = self.att_drop(att)
+         out = torch.einsum("bhal, bhlv -> bhav", att, values)
+         out = self.rearrange_unstack(out)
+         out = self.projection(out)
+         return out
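
A minimal usage sketch of the attention blocks defined above, assuming they are imported directly from braindecode.modules.attention as defined in this file and that inputs follow the b x c x 1 x t layout the inline comments assume; the parameter values are illustrative only:

    import torch

    from braindecode.modules.attention import CBAM, ECA, SqueezeAndExcitation

    # Dummy feature map: batch of 4, 16 feature channels, spatial size 1 x 125,
    # i.e. the "b x c x 1 x t" layout referenced in the comments above.
    x = torch.randn(4, 16, 1, 125)

    blocks = [
        SqueezeAndExcitation(in_channels=16, reduction_rate=4),
        ECA(in_channels=16, kernel_size=3),
        CBAM(in_channels=16, reduction_rate=4, kernel_size=7),
    ]

    for block in blocks:
        out = block(x)
        # Attention blocks rescale the input; they do not change its shape.
        assert out.shape == x.shape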