braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/attentionbasenet.py +550 -0
@@ -0,0 +1,550 @@
+ from __future__ import annotations
+
+ import math
+
+ from einops.layers.torch import Rearrange
+ from torch import nn
+
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import Ensure4d
+ from braindecode.modules.attention import (
+     CAT,
+     CBAM,
+     ECA,
+     FCA,
+     GCT,
+     SRM,
+     CATLite,
+     EncNet,
+     GatherExcite,
+     GSoP,
+     SqueezeAndExcitation,
+ )
+
+
+ class AttentionBaseNet(EEGModuleMixin, nn.Module):
+     """AttentionBaseNet from Wimpff M et al. (2023) [Martin2023]_.
+
+     .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
+        :align: center
+        :alt: Attention Base Net
+
+     Neural network from the paper: EEG motor imagery decoding:
+     a framework for comparative analysis with channel attention
+     mechanisms.
+
+     The paper and the original code, with more details about the
+     methodological choices, are available at [Martin2023]_ and [MartinCode]_.
+
+     The AttentionBaseNet architecture is composed of four modules:
+
+     - an input block that performs a temporal convolution and a spatial
+       convolution;
+     - a channel expansion that modifies the number of channels;
+     - an attention block that performs channel attention with several
+       options;
+     - a classification head.
+
+     .. versionadded:: 0.9
+
+     Parameters
+     ----------
+     n_temporal_filters : int, optional
+         Number of temporal convolutional filters in the first layer. This defines
+         the number of output channels after the temporal convolution.
+         Default is 40.
+     temp_filter_length_inp : int, default=25
+         The length of the temporal filters in the input block.
+     spatial_expansion : int, optional
+         Multiplicative factor to expand the spatial dimensions. Used to increase
+         the capacity of the model by expanding spatial features. Default is 1.
+     pool_length_inp : int, optional
+         Length of the pooling window in the input layer. Determines how much
+         temporal information is aggregated during pooling. Default is 75.
+     pool_stride_inp : int, optional
+         Stride of the pooling operation in the input layer. Controls the
+         downsampling factor in the temporal dimension. Default is 15.
+     drop_prob_inp : float, optional
+         Dropout rate applied after the input layer. This is the probability of
+         zeroing out elements during training to prevent overfitting.
+         Default is 0.5.
+     ch_dim : int, optional
+         Number of channels in the subsequent convolutional layers. This controls
+         the depth of the network after the initial layer. Default is 16.
+     temp_filter_length : int, default=15
+         The length of the temporal filters in the attention block's
+         convolutional layers.
+     attention_mode : str, optional
+         The type of attention mechanism to apply. If ``None``, no attention is
+         applied.
+
+         - "se" for Squeeze-and-Excitation network
+         - "gsop" for Global Second-Order Pooling
+         - "fca" for Frequency Channel Attention Network
+         - "encnet" for context encoding module
+         - "eca" for Efficient Channel Attention for deep convolutional neural networks
+         - "ge" for Gather-Excite
+         - "gct" for Gated Channel Transformation
+         - "srm" for Style-based Recalibration Module
+         - "cbam" for Convolutional Block Attention Module
+         - "cat" for Learning to collaborate channel and temporal attention
+           from multi-information fusion
+         - "catlite" for Learning to collaborate channel attention
+           from multi-information fusion (lite version, cat w/o temporal attention)
+     pool_length : int, default=8
+         The length of the window for the average pooling operation.
+     pool_stride : int, default=8
+         The stride of the average pooling operation.
+     drop_prob_attn : float, default=0.5
+         The dropout rate applied in the attention block. Values should be
+         between 0 and 1.
+     reduction_rate : int, default=4
+         The reduction rate used in the attention mechanism to reduce dimensionality
+         and computational complexity.
+     use_mlp : bool, default=False
+         Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
+         the attention mechanism for further processing.
+     freq_idx : int, default=0
+         DCT index used in the fca attention mechanism.
+     n_codewords : int, default=4
+         The number of codewords (clusters) used in attention mechanisms that employ
+         quantization or clustering strategies.
+     kernel_size : int, default=9
+         The kernel size used in certain types of attention mechanisms for convolution
+         operations.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+     extra_params : bool, default=False
+         Flag to indicate whether additional, custom parameters should be passed to
+         the attention mechanism.
+
+     References
+     ----------
+     .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
+        EEG motor imagery decoding: A framework for comparative analysis with
+        channel attention mechanisms. arXiv preprint arXiv:2310.11198.
+     .. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
+        GitHub https://github.com/martinwimpff/channel-attention (accessed 2024-03-28)
+     """
+
+     def __init__(
+         self,
+         n_times=None,
+         n_chans=None,
+         n_outputs=None,
+         chs_info=None,
+         sfreq=None,
+         input_window_seconds=None,
+         # Module parameters
+         n_temporal_filters: int = 40,
+         temp_filter_length_inp: int = 25,
+         spatial_expansion: int = 1,
+         pool_length_inp: int = 75,
+         pool_stride_inp: int = 15,
+         drop_prob_inp: float = 0.5,
+         ch_dim: int = 16,
+         temp_filter_length: int = 15,
+         pool_length: int = 8,
+         pool_stride: int = 8,
+         drop_prob_attn: float = 0.5,
+         attention_mode: str | None = None,
+         reduction_rate: int = 4,
+         use_mlp: bool = False,
+         freq_idx: int = 0,
+         n_codewords: int = 4,
+         kernel_size: int = 9,
+         activation: nn.Module = nn.ELU,
+         extra_params: bool = False,
+     ):
+         super().__init__(
+             n_outputs=n_outputs,
+             n_chans=n_chans,
+             chs_info=chs_info,
+             n_times=n_times,
+             sfreq=sfreq,
+             input_window_seconds=input_window_seconds,
+         )
+         del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds
+
+         self.input_block = _FeatureExtractor(
+             n_chans=self.n_chans,
+             n_temporal_filters=n_temporal_filters,
+             temporal_filter_length=temp_filter_length_inp,
+             spatial_expansion=spatial_expansion,
+             pool_length=pool_length_inp,
+             pool_stride=pool_stride_inp,
+             drop_prob=drop_prob_inp,
+             activation=activation,
+         )
+
+         self.channel_expansion = nn.Sequential(
+             nn.Conv2d(
+                 n_temporal_filters * spatial_expansion, ch_dim, (1, 1), bias=False
+             ),
+             nn.BatchNorm2d(ch_dim),
+             activation(),
+         )
+
+         seq_lengths = self._calculate_sequence_lengths(
+             self.n_times,
+             [temp_filter_length_inp, temp_filter_length],
+             [pool_length_inp, pool_length],
+             [pool_stride_inp, pool_stride],
+         )
+
+         self.channel_attention_block = _ChannelAttentionBlock(
+             attention_mode=attention_mode,
+             in_channels=ch_dim,
+             temp_filter_length=temp_filter_length,
+             pool_length=pool_length,
+             pool_stride=pool_stride,
+             drop_prob=drop_prob_attn,
+             reduction_rate=reduction_rate,
+             use_mlp=use_mlp,
+             seq_len=seq_lengths[0],
+             freq_idx=freq_idx,
+             n_codewords=n_codewords,
+             kernel_size=kernel_size,
+             extra_params=extra_params,
+             activation=activation,
+         )
+
+         self.final_layer = nn.Sequential(
+             nn.Flatten(), nn.Linear(seq_lengths[-1] * ch_dim, self.n_outputs)
+         )
+
+     def forward(self, x):
+         x = self.input_block(x)
+         x = self.channel_expansion(x)
+         x = self.channel_attention_block(x)
+         x = self.final_layer(x)
+         return x
+
+     @staticmethod
+     def _calculate_sequence_lengths(
+         input_window_samples: int,
+         kernel_lengths: list,
+         pool_lengths: list,
+         pool_strides: list,
+     ):
+         seq_lengths = []
+         out = input_window_samples
+         for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
+             out = math.floor(out + 2 * (k // 2) - k + 1)
+             out = math.floor((out - pl) / ps + 1)
+             seq_lengths.append(int(out))
+         return seq_lengths
+
+
+ class _FeatureExtractor(nn.Module):
+     """
+     A module that extracts features from the data with temporal and spatial
+     transformations.
+
+     This module sequentially processes the input through a series of layers:
+     rearrangement, temporal convolution, batch normalization, spatial convolution,
+     another batch normalization, an ELU non-linearity, average pooling, and dropout.
+
+     Parameters
+     ----------
+     n_chans : int
+         The number of channels in the input data.
+     n_temporal_filters : int, optional
+         The number of filters to use in the temporal convolution layer. Default is 40.
+     temporal_filter_length : int, optional
+         The size of each filter in the temporal convolution layer. Default is 25.
+     spatial_expansion : int, optional
+         The expansion factor of the spatial convolution layer, determining the number
+         of output channels relative to the number of temporal filters. Default is 1.
+     pool_length : int, optional
+         The size of the window for the average pooling operation. Default is 75.
+     pool_stride : int, optional
+         The stride of the average pooling operation. Default is 15.
+     drop_prob : float, optional
+         The dropout rate for regularization. Default is 0.5.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+     """
+
+     def __init__(
+         self,
+         n_chans: int,
+         n_temporal_filters: int = 40,
+         temporal_filter_length: int = 25,
+         spatial_expansion: int = 1,
+         pool_length: int = 75,
+         pool_stride: int = 15,
+         drop_prob: float = 0.5,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+
+         self.ensure4d = Ensure4d()
+         self.rearrange_input = Rearrange("b c t 1 -> b 1 c t")
+         self.temporal_conv = nn.Conv2d(
+             1,
+             n_temporal_filters,
+             kernel_size=(1, temporal_filter_length),
+             padding=(0, temporal_filter_length // 2),
+             bias=False,
+         )
+         self.intermediate_bn = nn.BatchNorm2d(n_temporal_filters)
+         self.spatial_conv = nn.Conv2d(
+             n_temporal_filters,
+             n_temporal_filters * spatial_expansion,
+             kernel_size=(n_chans, 1),
+             groups=n_temporal_filters,
+             bias=False,
+         )
+         self.bn = nn.BatchNorm2d(n_temporal_filters * spatial_expansion)
+         self.nonlinearity = activation()
+         self.pool = nn.AvgPool2d((1, pool_length), stride=(1, pool_stride))
+         self.dropout = nn.Dropout(drop_prob)
+
+     def forward(self, x):
+         x = self.ensure4d(x)
+         x = self.rearrange_input(x)
+         x = self.temporal_conv(x)
+         x = self.intermediate_bn(x)
+         x = self.spatial_conv(x)
+         x = self.bn(x)
+         x = self.nonlinearity(x)
+         x = self.pool(x)
+         x = self.dropout(x)
+         return x
+
+
+ class _ChannelAttentionBlock(nn.Module):
+     """
+     A neural network module implementing channel-wise attention mechanisms to enhance
+     feature representations by selectively emphasizing important channels and suppressing
+     less useful ones. This block integrates convolutional layers, pooling, dropout, and
+     an optional attention mechanism that can be customized based on the given mode.
+
+     Parameters
+     ----------
+     attention_mode : str, optional
+         The type of attention mechanism to apply. If ``None``, no attention is
+         applied.
+
+         - "se" for Squeeze-and-Excitation network
+         - "gsop" for Global Second-Order Pooling
+         - "fca" for Frequency Channel Attention Network
+         - "encnet" for context encoding module
+         - "eca" for Efficient Channel Attention for deep convolutional neural networks
+         - "ge" for Gather-Excite
+         - "gct" for Gated Channel Transformation
+         - "srm" for Style-based Recalibration Module
+         - "cbam" for Convolutional Block Attention Module
+         - "cat" for Learning to collaborate channel and temporal attention
+           from multi-information fusion
+         - "catlite" for Learning to collaborate channel attention
+           from multi-information fusion (lite version, cat w/o temporal attention)
+     in_channels : int, default=16
+         The number of input channels to the block.
+     temp_filter_length : int, default=15
+         The length of the temporal filters in the convolutional layers.
+     pool_length : int, default=8
+         The length of the window for the average pooling operation.
+     pool_stride : int, default=8
+         The stride of the average pooling operation.
+     drop_prob : float, default=0.5
+         The dropout rate for regularization. Values should be between 0 and 1.
+     reduction_rate : int, default=4
+         The reduction rate used in the attention mechanism to reduce dimensionality
+         and computational complexity.
+     use_mlp : bool, default=False
+         Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
+         the attention mechanism for further processing.
+     seq_len : int, default=62
+         The sequence length, used in certain types of attention mechanisms to process
+         temporal dimensions.
+     freq_idx : int, default=0
+         DCT index used in the fca attention mechanism.
+     n_codewords : int, default=4
+         The number of codewords (clusters) used in attention mechanisms that employ
+         quantization or clustering strategies.
+     kernel_size : int, default=9
+         The kernel size used in certain types of attention mechanisms for convolution
+         operations.
+     extra_params : bool, default=False
+         Flag to indicate whether additional, custom parameters should be passed to
+         the attention mechanism.
+     activation : nn.Module, default=nn.ELU
+         Activation function class to apply. Should be a PyTorch activation
+         module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
+
+     Attributes
+     ----------
+     conv : torch.nn.Sequential
+         Sequential model of convolutional layers, batch normalization, and ELU
+         activation, designed to process input features.
+     pool : torch.nn.AvgPool2d
+         Average pooling layer to reduce the dimensionality of the feature maps.
+     dropout : torch.nn.Dropout
+         Dropout layer for regularization.
+     attention_block : torch.nn.Module or None
+         The attention mechanism applied to the output of the convolutional layers,
+         if ``attention_mode`` is not None. Otherwise, it is set to ``None``.
+
+     Examples
+     --------
+     >>> import torch
+     >>> channel_attention_block = _ChannelAttentionBlock(attention_mode='cbam', in_channels=16, reduction_rate=4, kernel_size=7)
+     >>> x = torch.randn(1, 16, 64, 64)  # example input tensor
+     >>> output = channel_attention_block(x)
+
+     The output tensor can then be further processed or used as input to
+     another block.
+     """
+
+     def __init__(
+         self,
+         attention_mode: str | None = None,
+         in_channels: int = 16,
+         temp_filter_length: int = 15,
+         pool_length: int = 8,
+         pool_stride: int = 8,
+         drop_prob: float = 0.5,
+         reduction_rate: int = 4,
+         use_mlp: bool = False,
+         seq_len: int = 62,
+         freq_idx: int = 0,
+         n_codewords: int = 4,
+         kernel_size: int = 9,
+         extra_params: bool = False,
+         activation: nn.Module = nn.ELU,
+     ):
+         super().__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels,
+                 in_channels,
+                 (1, temp_filter_length),
+                 padding=(0, temp_filter_length // 2),
+                 bias=False,
+                 groups=in_channels,
+             ),
+             nn.Conv2d(in_channels, in_channels, (1, 1), bias=False),
+             nn.BatchNorm2d(in_channels),
+             activation(),
+         )
+
+         self.pool = nn.AvgPool2d((1, pool_length), stride=(1, pool_stride))
+         self.dropout = nn.Dropout(drop_prob)
+
+         if attention_mode is not None:
+             self.attention_block = get_attention_block(
+                 attention_mode,
+                 ch_dim=in_channels,
+                 reduction_rate=reduction_rate,
+                 use_mlp=use_mlp,
+                 seq_len=seq_len,
+                 freq_idx=freq_idx,
+                 n_codewords=n_codewords,
+                 kernel_size=kernel_size,
+                 extra_params=extra_params,
+             )
+         else:
+             self.attention_block = None
+
+     def forward(self, x):
+         out = self.conv(x)
+         if self.attention_block is not None:
+             out = self.attention_block(out)
+         out = self.pool(out)
+         out = self.dropout(out)
+         return out
+
+
+ def get_attention_block(
+     attention_mode: str,
+     ch_dim: int = 16,
+     reduction_rate: int = 4,
+     use_mlp: bool = False,
+     seq_len: int | None = None,
+     freq_idx: int = 0,
+     n_codewords: int = 4,
+     kernel_size: int = 9,
+     extra_params: bool = False,
+ ):
+     """
+     Utility function that builds the attention block matching the given
+     attention mode.
+
+     Parameters
+     ----------
+     attention_mode : str
+         The type of attention mechanism to apply.
+     ch_dim : int
+         The number of input channels to the block.
+     reduction_rate : int
+         The reduction rate used in the attention mechanism to reduce
+         dimensionality and computational complexity. Used by all mechanisms
+         except encnet and eca.
+     use_mlp : bool
+         Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used
+         within the attention mechanism for further processing. Used in the ge
+         and srm attention mechanisms.
+     seq_len : int
+         The sequence length, used in certain types of attention mechanisms to
+         process temporal dimensions. Used in the ge and fca attention mechanisms.
+     freq_idx : int
+         DCT index used in the fca attention mechanism.
+     n_codewords : int
+         The number of codewords (clusters) used in attention mechanisms that
+         employ quantization or clustering strategies (encnet).
+     kernel_size : int
+         The kernel size used in certain types of attention mechanisms for
+         convolution operations; used in the cbam, eca, and cat attention
+         mechanisms.
+     extra_params : bool
+         Parameter to pass additional parameters to the GatherExcite mechanism.
+
+     Returns
+     -------
+     nn.Module
+         The attention block matching the attention mode.
+     """
+     if attention_mode == "se":
+         return SqueezeAndExcitation(in_channels=ch_dim, reduction_rate=reduction_rate)
+     # improving the squeeze module
+     elif attention_mode == "gsop":
+         return GSoP(in_channels=ch_dim, reduction_rate=reduction_rate)
+     elif attention_mode == "fca":
+         assert seq_len is not None
+         return FCA(
+             in_channels=ch_dim,
+             seq_len=seq_len,
+             reduction_rate=reduction_rate,
+             freq_idx=freq_idx,
+         )
+     elif attention_mode == "encnet":
+         return EncNet(in_channels=ch_dim, n_codewords=n_codewords)
+     # improving the excitation module
+     elif attention_mode == "eca":
+         return ECA(in_channels=ch_dim, kernel_size=kernel_size)
+     # improving the squeeze and the excitation module
+     elif attention_mode == "ge":
+         assert seq_len is not None
+         return GatherExcite(
+             in_channels=ch_dim,
+             seq_len=seq_len,
+             extra_params=extra_params,
+             use_mlp=use_mlp,
+             reduction_rate=reduction_rate,
+         )
+     elif attention_mode == "gct":
+         return GCT(in_channels=ch_dim)
+     elif attention_mode == "srm":
+         return SRM(in_channels=ch_dim, use_mlp=use_mlp, reduction_rate=reduction_rate)
+     # temporal and channel attention
+     elif attention_mode == "cbam":
+         return CBAM(
+             in_channels=ch_dim, reduction_rate=reduction_rate, kernel_size=kernel_size
+         )
+     elif attention_mode == "cat":
+         return CAT(
+             in_channels=ch_dim, reduction_rate=reduction_rate, kernel_size=kernel_size
+         )
+     elif attention_mode == "catlite":
+         return CATLite(ch_dim, reduction_rate=reduction_rate)
+     else:
+         raise NotImplementedError(
+             f"Attention mode '{attention_mode}' is not implemented."
+         )
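
A quick usage sketch for reviewers of this diff (not part of the package): it assumes PyTorch is installed and that AttentionBaseNet is re-exported from braindecode.models, which the updated braindecode/models/__init__.py above makes plausible; if not, import it from braindecode.models.attentionbasenet. The data dimensions are hypothetical; shapes and parameter names come from the code above.

    import torch
    from braindecode.models import AttentionBaseNet  # assumed re-export; see note above

    # hypothetical dimensions: 22-channel EEG, 1000 samples per window, 4 classes
    model = AttentionBaseNet(
        n_chans=22,
        n_times=1000,
        n_outputs=4,
        attention_mode="se",  # or None, "cbam", "eca", ... per the docstring
    )

    x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times); Ensure4d/Rearrange handle the rest
    logits = model(x)  # expected shape: (8, 4)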
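The _calculate_sequence_lengths helper also explains the otherwise opaque seq_len=62 default in _ChannelAttentionBlock: an odd kernel with k // 2 padding preserves the temporal length, so only the two pooling stages shrink it. Reproducing the arithmetic for the defaults above with n_times=1000:

    import math

    out = 1000  # n_times
    # input block: padded conv (k=25) keeps the length ...
    out = math.floor(out + 2 * (25 // 2) - 25 + 1)  # -> 1000
    # ... then AvgPool(75, stride=15) shrinks it
    out = math.floor((out - 75) / 15 + 1)  # -> 62, i.e. seq_lengths[0], the default seq_len
    # attention block: padded conv (k=15), then AvgPool(8, stride=8)
    out = math.floor(out + 2 * (15 // 2) - 15 + 1)  # -> 62
    out = math.floor((out - 8) / 8 + 1)  # -> 7, i.e. seq_lengths[-1]

    # final layer: nn.Linear(seq_lengths[-1] * ch_dim, n_outputs) = nn.Linear(7 * 16, 4)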
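Because get_attention_block is a plain module-level factory, the eleven attention variants can also be probed in isolation. A sketch under the assumption, implied by their use inside _ChannelAttentionBlock, that each block preserves the (batch, ch_dim, 1, seq_len) shape of its input:

    import torch
    from braindecode.models.attentionbasenet import get_attention_block

    block = get_attention_block("eca", ch_dim=16, kernel_size=9)
    feats = torch.randn(2, 16, 1, 62)  # (batch, ch_dim, 1, seq_len), as produced inside the model
    out = block(feats)
    print(out.shape)  # expected: torch.Size([2, 16, 1, 62])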