braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
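Note the package reorganization visible in this list: the shared building blocks from braindecode/models/modules.py and braindecode/models/functions.py (both removed) now live in the new braindecode/modules and braindecode/functional packages. A minimal import-migration sketch (assuming a layer such as Ensure4d, which the new file below pulls from braindecode.modules, was previously imported from braindecode.models.modules):

# braindecode 0.8.1 (old layout) -- hypothetical call site
# from braindecode.models.modules import Ensure4d

# braindecode 1.1.0 (new layout), as used in attentionbasenet.py below
from braindecode.modules import Ensure4d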
braindecode/models/attentionbasenet.py (new file)
@@ -0,0 +1,599 @@
from __future__ import annotations

import math

from einops.layers.torch import Rearrange
from mne.utils import warn
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import Ensure4d
from braindecode.modules.attention import (
    CAT,
    CBAM,
    ECA,
    FCA,
    GCT,
    SRM,
    CATLite,
    EncNet,
    GatherExcite,
    GSoP,
    SqueezeAndExcitation,
)


class AttentionBaseNet(EEGModuleMixin, nn.Module):
    """AttentionBaseNet from Wimpff M et al. (2023) [Martin2023]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
       :align: center
       :alt: Attention Base Net

    Neural network from the paper: EEG motor imagery decoding:
    a framework for comparative analysis with channel attention
    mechanisms.

    The paper and the original code, with more details about the
    methodological choices, are available in [Martin2023]_ and
    [MartinCode]_.

    The AttentionBaseNet architecture is composed of four modules:

    - an input block that performs a temporal convolution and a spatial
      convolution,
    - a channel expansion that modifies the number of channels,
    - an attention block that performs channel attention with several
      options,
    - a classification head.

    .. versionadded:: 0.9

    Parameters
    ----------
    n_temporal_filters : int, optional
        Number of temporal convolutional filters in the first layer. This
        defines the number of output channels after the temporal
        convolution. Default is 40.
    temp_filter_length_inp : int, default=25
        The length of the temporal filters in the input block.
    spatial_expansion : int, optional
        Multiplicative factor to expand the spatial dimensions. Used to
        increase the capacity of the model by expanding spatial features.
        Default is 1.
    pool_length_inp : int, optional
        Length of the pooling window in the input layer. Determines how much
        temporal information is aggregated during pooling. Default is 75.
    pool_stride_inp : int, optional
        Stride of the pooling operation in the input layer. Controls the
        downsampling factor in the temporal dimension. Default is 15.
    drop_prob_inp : float, optional
        Dropout rate applied after the input layer. This is the probability
        of zeroing out elements during training to prevent overfitting.
        Default is 0.5.
    ch_dim : int, optional
        Number of channels in the subsequent convolutional layers. This
        controls the depth of the network after the initial layer.
        Default is 16.
    temp_filter_length : int, default=15
        The length of the temporal filters in the attention block.
    pool_length : int, default=8
        The length of the window for the average pooling operation.
    pool_stride : int, default=8
        The stride of the average pooling operation.
    drop_prob_attn : float, default=0.5
        The dropout rate for regularization of the attention layer. Values
        should be between 0 and 1.
    attention_mode : str, optional
        The type of attention mechanism to apply. If ``None``, no attention
        is applied.

        - "se" for Squeeze-and-Excitation network
        - "gsop" for Global Second-Order Pooling
        - "fca" for Frequency Channel Attention network
        - "encnet" for context encoding module
        - "eca" for Efficient Channel Attention for deep convolutional
          neural networks
        - "ge" for Gather-Excite
        - "gct" for Gated Channel Transformation
        - "srm" for Style-based Recalibration Module
        - "cbam" for Convolutional Block Attention Module
        - "cat" for learning to collaborate channel and temporal attention
          from multi-information fusion
        - "catlite" for learning to collaborate channel attention from
          multi-information fusion (lite version, cat without temporal
          attention)
    reduction_rate : int, default=4
        The reduction rate used in the attention mechanism to reduce
        dimensionality and computational complexity.
    use_mlp : bool, default=False
        Flag to indicate whether an MLP (multi-layer perceptron) should be
        used within the attention mechanism for further processing.
    freq_idx : int, default=0
        DCT index used in the fca attention mechanism.
    n_codewords : int, default=4
        The number of codewords (clusters) used in attention mechanisms that
        employ quantization or clustering strategies.
    kernel_size : int, default=9
        The kernel size used in certain types of attention mechanisms for
        convolution operations.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    extra_params : bool, default=False
        Flag to indicate whether additional, custom parameters should be
        passed to the attention mechanism.

    References
    ----------
    .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
       EEG motor imagery decoding: A framework for comparative analysis with
       channel attention mechanisms. arXiv preprint arXiv:2310.11198.
    .. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
       GitHub https://github.com/martinwimpff/channel-attention
       (accessed 2024-03-28)
    """

    def __init__(
        self,
        n_times=None,
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        sfreq=None,
        input_window_seconds=None,
        # Module parameters
        n_temporal_filters: int = 40,
        temp_filter_length_inp: int = 25,
        spatial_expansion: int = 1,
        pool_length_inp: int = 75,
        pool_stride_inp: int = 15,
        drop_prob_inp: float = 0.5,
        ch_dim: int = 16,
        temp_filter_length: int = 15,
        pool_length: int = 8,
        pool_stride: int = 8,
        drop_prob_attn: float = 0.5,
        attention_mode: str | None = None,
        reduction_rate: int = 4,
        use_mlp: bool = False,
        freq_idx: int = 0,
        n_codewords: int = 4,
        kernel_size: int = 9,
        activation: nn.Module = nn.ELU,
        extra_params: bool = False,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            sfreq=sfreq,
            input_window_seconds=input_window_seconds,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        min_n_times_required = self._get_min_n_times(
            pool_length_inp,
            pool_stride_inp,
            pool_length,
        )

        if self.n_times < min_n_times_required:
            scaling_factor = self.n_times / min_n_times_required
            warn(
                f"n_times ({self.n_times}) is smaller than the minimum "
                f"required ({min_n_times_required}) for the current model "
                "parameter configuration. Adjusting parameters to ensure "
                "compatibility: reducing the kernel, pooling, and stride "
                f"sizes accordingly. Scaling factor: {scaling_factor:.2f}",
                UserWarning,
            )
            # Scale down all temporal parameters proportionally.
            # Use max(1, ...) to ensure parameters remain valid.
            temp_filter_length_inp = max(
                1, int(temp_filter_length_inp * scaling_factor)
            )
            pool_length_inp = max(1, int(pool_length_inp * scaling_factor))
            pool_stride_inp = max(1, int(pool_stride_inp * scaling_factor))
            temp_filter_length = max(1, int(temp_filter_length * scaling_factor))
            pool_length = max(1, int(pool_length * scaling_factor))
            pool_stride = max(1, int(pool_stride * scaling_factor))

        self.input_block = _FeatureExtractor(
            n_chans=self.n_chans,
            n_temporal_filters=n_temporal_filters,
            temporal_filter_length=temp_filter_length_inp,
            spatial_expansion=spatial_expansion,
            pool_length=pool_length_inp,
            pool_stride=pool_stride_inp,
            drop_prob=drop_prob_inp,
            activation=activation,
        )

        self.channel_expansion = nn.Sequential(
            nn.Conv2d(
                n_temporal_filters * spatial_expansion, ch_dim, (1, 1), bias=False
            ),
            nn.BatchNorm2d(ch_dim),
            activation(),
        )

        seq_lengths = self._calculate_sequence_lengths(
            self.n_times,
            [temp_filter_length_inp, temp_filter_length],
            [pool_length_inp, pool_length],
            [pool_stride_inp, pool_stride],
        )

        self.channel_attention_block = _ChannelAttentionBlock(
            attention_mode=attention_mode,
            in_channels=ch_dim,
            temp_filter_length=temp_filter_length,
            pool_length=pool_length,
            pool_stride=pool_stride,
            drop_prob=drop_prob_attn,
            reduction_rate=reduction_rate,
            use_mlp=use_mlp,
            seq_len=seq_lengths[0],
            freq_idx=freq_idx,
            n_codewords=n_codewords,
            kernel_size=kernel_size,
            extra_params=extra_params,
            activation=activation,
        )

        self.final_layer = nn.Sequential(
            nn.Flatten(), nn.Linear(seq_lengths[-1] * ch_dim, self.n_outputs)
        )

    def forward(self, x):
        x = self.input_block(x)
        x = self.channel_expansion(x)
        x = self.channel_attention_block(x)
        x = self.final_layer(x)
        return x
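
    # Shape walkthrough (a sketch, assuming the default parameters and an
    # input of shape (batch, n_chans, 1000)):
    #   input_block              -> (batch, 40 * spatial_expansion, 1, 62)
    #   channel_expansion        -> (batch, 16, 1, 62)
    #   channel_attention_block  -> (batch, 16, 1, 7)
    #   final_layer              -> (batch, n_outputs)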

    @staticmethod
    def _calculate_sequence_lengths(
        input_window_samples: int,
        kernel_lengths: list,
        pool_lengths: list,
        pool_strides: list,
    ):
        seq_lengths = []
        out = input_window_samples
        for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
            out = math.floor(out + 2 * (k // 2) - k + 1)
            out = math.floor((out - pl) / ps + 1)
            seq_lengths.append(int(out))
        return seq_lengths
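
    # Worked example with the default parameters and input_window_samples=1000:
    #   stage 1: conv with k=25 (padding 12) keeps 1000 samples,
    #            pooling gives floor((1000 - 75) / 15 + 1) = 62
    #   stage 2: conv with k=15 (padding 7) keeps 62 samples,
    #            pooling gives floor((62 - 8) / 8 + 1) = 7
    # seq_lengths == [62, 7]: the first entry matches the seq_len=62 default
    # of _ChannelAttentionBlock, and the last yields 7 * ch_dim = 112 input
    # features for the final linear layer.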

    @staticmethod
    def _get_min_n_times(
        pool_length_inp: int,
        pool_stride_inp: int,
        pool_length: int,
    ) -> int:
        """
        Calculate the minimum ``n_times`` required for the model to work
        with the given parameters.

        The calculation is based on reversing the pooling operations to
        ensure the input to each one is valid.
        """
        # The input to the second pooling layer must be at least its kernel size.
        min_len_for_second_pool = pool_length

        # Reverse the first pooling operation to find the required input size.
        # Formula: min_L_in = stride * (min_L_out - 1) + kernel
        min_len = pool_stride_inp * (min_len_for_second_pool - 1) + pool_length_inp
        return min_len
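
    # With the default parameters, the minimum input length is
    # 15 * (8 - 1) + 75 = 180 samples; shorter inputs trigger the
    # parameter rescaling (and UserWarning) in __init__.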


class _FeatureExtractor(nn.Module):
    """
    A module for feature extraction from the data with temporal and spatial
    transformations.

    This module sequentially processes the input through a series of layers:
    rearrangement, temporal convolution, batch normalization, spatial
    convolution, another batch normalization, a non-linearity, average
    pooling, and dropout.

    Parameters
    ----------
    n_chans : int
        The number of channels in the input data.
    n_temporal_filters : int, optional
        The number of filters to use in the temporal convolution layer.
        Default is 40.
    temporal_filter_length : int, optional
        The size of each filter in the temporal convolution layer.
        Default is 25.
    spatial_expansion : int, optional
        The expansion factor of the spatial convolution layer, determining
        the number of output channels relative to the number of temporal
        filters. Default is 1.
    pool_length : int, optional
        The size of the window for the average pooling operation.
        Default is 75.
    pool_stride : int, optional
        The stride of the average pooling operation. Default is 15.
    drop_prob : float, optional
        The dropout rate for regularization. Default is 0.5.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    """

    def __init__(
        self,
        n_chans: int,
        n_temporal_filters: int = 40,
        temporal_filter_length: int = 25,
        spatial_expansion: int = 1,
        pool_length: int = 75,
        pool_stride: int = 15,
        drop_prob: float = 0.5,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()

        self.ensure4d = Ensure4d()
        self.rearrange_input = Rearrange("b c t 1 -> b 1 c t")
        self.temporal_conv = nn.Conv2d(
            1,
            n_temporal_filters,
            kernel_size=(1, temporal_filter_length),
            padding=(0, temporal_filter_length // 2),
            bias=False,
        )
        self.intermediate_bn = nn.BatchNorm2d(n_temporal_filters)
        self.spatial_conv = nn.Conv2d(
            n_temporal_filters,
            n_temporal_filters * spatial_expansion,
            kernel_size=(n_chans, 1),
            groups=n_temporal_filters,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(n_temporal_filters * spatial_expansion)
        self.nonlinearity = activation()
        self.pool = nn.AvgPool2d((1, pool_length), stride=(1, pool_stride))
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        x = self.ensure4d(x)
        x = self.rearrange_input(x)
        x = self.temporal_conv(x)
        x = self.intermediate_bn(x)
        x = self.spatial_conv(x)
        x = self.bn(x)
        x = self.nonlinearity(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x
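
    # Shape walkthrough (a sketch, input (b, n_chans, t), odd
    # temporal_filter_length so the padded convolution preserves t):
    #   ensure4d        -> (b, n_chans, t, 1)
    #   rearrange_input -> (b, 1, n_chans, t)
    #   temporal_conv   -> (b, n_temporal_filters, n_chans, t)
    #   spatial_conv    -> (b, n_temporal_filters * spatial_expansion, 1, t)
    #   pool            -> (..., 1, floor((t - pool_length) / pool_stride) + 1)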


class _ChannelAttentionBlock(nn.Module):
    """
    A neural network module implementing channel-wise attention mechanisms
    to enhance feature representations by selectively emphasizing important
    channels and suppressing less useful ones. This block integrates
    convolutional layers, pooling, dropout, and an optional attention
    mechanism that can be customized based on the given mode.

    Parameters
    ----------
    attention_mode : str, optional
        The type of attention mechanism to apply. If ``None``, no attention
        is applied.

        - "se" for Squeeze-and-Excitation network
        - "gsop" for Global Second-Order Pooling
        - "fca" for Frequency Channel Attention network
        - "encnet" for context encoding module
        - "eca" for Efficient Channel Attention for deep convolutional
          neural networks
        - "ge" for Gather-Excite
        - "gct" for Gated Channel Transformation
        - "srm" for Style-based Recalibration Module
        - "cbam" for Convolutional Block Attention Module
        - "cat" for learning to collaborate channel and temporal attention
          from multi-information fusion
        - "catlite" for learning to collaborate channel attention from
          multi-information fusion (lite version, cat without temporal
          attention)
    in_channels : int, default=16
        The number of input channels to the block.
    temp_filter_length : int, default=15
        The length of the temporal filters in the convolutional layers.
    pool_length : int, default=8
        The length of the window for the average pooling operation.
    pool_stride : int, default=8
        The stride of the average pooling operation.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    reduction_rate : int, default=4
        The reduction rate used in the attention mechanism to reduce
        dimensionality and computational complexity.
    use_mlp : bool, default=False
        Flag to indicate whether an MLP (multi-layer perceptron) should be
        used within the attention mechanism for further processing.
    seq_len : int, default=62
        The sequence length, used in certain types of attention mechanisms
        to process temporal dimensions.
    freq_idx : int, default=0
        DCT index used in the fca attention mechanism.
    n_codewords : int, default=4
        The number of codewords (clusters) used in attention mechanisms that
        employ quantization or clustering strategies.
    kernel_size : int, default=9
        The kernel size used in certain types of attention mechanisms for
        convolution operations.
    extra_params : bool, default=False
        Flag to indicate whether additional, custom parameters should be
        passed to the attention mechanism.
    activation : nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Attributes
    ----------
    conv : torch.nn.Sequential
        Sequential model of convolutional layers, batch normalization, and
        the activation, designed to process input features.
    pool : torch.nn.AvgPool2d
        Average pooling layer to reduce the dimensionality of the feature
        maps.
    dropout : torch.nn.Dropout
        Dropout layer for regularization.
    attention_block : torch.nn.Module or None
        The attention mechanism applied to the output of the convolutional
        layers if ``attention_mode`` is not None; otherwise None.

    Examples
    --------
    >>> import torch
    >>> channel_attention_block = _ChannelAttentionBlock(
    ...     attention_mode="cbam", in_channels=16, reduction_rate=4, kernel_size=7
    ... )
    >>> x = torch.randn(1, 16, 64, 64)  # example input tensor
    >>> output = channel_attention_block(x)

    The output tensor can then be further processed or used as input to
    another block.
    """

    def __init__(
        self,
        attention_mode: str | None = None,
        in_channels: int = 16,
        temp_filter_length: int = 15,
        pool_length: int = 8,
        pool_stride: int = 8,
        drop_prob: float = 0.5,
        reduction_rate: int = 4,
        use_mlp: bool = False,
        seq_len: int = 62,
        freq_idx: int = 0,
        n_codewords: int = 4,
        kernel_size: int = 9,
        extra_params: bool = False,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                (1, temp_filter_length),
                padding=(0, temp_filter_length // 2),
                bias=False,
                groups=in_channels,
            ),
            nn.Conv2d(in_channels, in_channels, (1, 1), bias=False),
            nn.BatchNorm2d(in_channels),
            activation(),
        )

        self.pool = nn.AvgPool2d((1, pool_length), stride=(1, pool_stride))
        self.dropout = nn.Dropout(drop_prob)

        if attention_mode is not None:
            self.attention_block = get_attention_block(
                attention_mode,
                ch_dim=in_channels,
                reduction_rate=reduction_rate,
                use_mlp=use_mlp,
                seq_len=seq_len,
                freq_idx=freq_idx,
                n_codewords=n_codewords,
                kernel_size=kernel_size,
                extra_params=extra_params,
            )
        else:
            self.attention_block = None

    def forward(self, x):
        out = self.conv(x)
        if self.attention_block is not None:
            out = self.attention_block(out)
        out = self.pool(out)
        out = self.dropout(out)
        return out


def get_attention_block(
    attention_mode: str,
    ch_dim: int = 16,
    reduction_rate: int = 4,
    use_mlp: bool = False,
    seq_len: int | None = None,
    freq_idx: int = 0,
    n_codewords: int = 4,
    kernel_size: int = 9,
    extra_params: bool = False,
):
    """
    Utility function to build the attention block matching the given
    attention mode.

    Parameters
    ----------
    attention_mode : str
        The type of attention mechanism to apply.
    ch_dim : int
        The number of input channels to the block.
    reduction_rate : int
        The reduction rate used in the attention mechanism to reduce
        dimensionality and computational complexity. Used in all methods
        except encnet and eca.
    use_mlp : bool
        Flag to indicate whether an MLP (multi-layer perceptron) should be
        used within the attention mechanism for further processing. Used in
        the ge and srm attention mechanisms.
    seq_len : int
        The sequence length, used in certain types of attention mechanisms
        to process temporal dimensions. Used in the ge and fca attention
        mechanisms.
    freq_idx : int
        DCT index used in the fca attention mechanism.
    n_codewords : int
        The number of codewords (clusters) used in attention mechanisms that
        employ quantization or clustering strategies (encnet).
    kernel_size : int
        The kernel size used in certain types of attention mechanisms for
        convolution operations; used in the cbam, eca, and cat attention
        mechanisms.
    extra_params : bool
        Parameter to pass additional parameters to the GatherExcite
        mechanism.

    Returns
    -------
    nn.Module
        The attention block matching the attention mode.
    """
    if attention_mode == "se":
        return SqueezeAndExcitation(in_channels=ch_dim, reduction_rate=reduction_rate)
    # improving the squeeze module
    elif attention_mode == "gsop":
        return GSoP(in_channels=ch_dim, reduction_rate=reduction_rate)
    elif attention_mode == "fca":
        assert seq_len is not None
        return FCA(
            in_channels=ch_dim,
            seq_len=seq_len,
            reduction_rate=reduction_rate,
            freq_idx=freq_idx,
        )
    elif attention_mode == "encnet":
        return EncNet(in_channels=ch_dim, n_codewords=n_codewords)
    # improving the excitation module
    elif attention_mode == "eca":
        return ECA(in_channels=ch_dim, kernel_size=kernel_size)
    # improving both the squeeze and the excitation module
    elif attention_mode == "ge":
        assert seq_len is not None
        return GatherExcite(
            in_channels=ch_dim,
            seq_len=seq_len,
            extra_params=extra_params,
            use_mlp=use_mlp,
            reduction_rate=reduction_rate,
        )
    elif attention_mode == "gct":
        return GCT(in_channels=ch_dim)
    elif attention_mode == "srm":
        return SRM(in_channels=ch_dim, use_mlp=use_mlp, reduction_rate=reduction_rate)
    # temporal and channel attention
    elif attention_mode == "cbam":
        return CBAM(
            in_channels=ch_dim, reduction_rate=reduction_rate, kernel_size=kernel_size
        )
    elif attention_mode == "cat":
        return CAT(
            in_channels=ch_dim, reduction_rate=reduction_rate, kernel_size=kernel_size
        )
    elif attention_mode == "catlite":
        return CATLite(ch_dim, reduction_rate=reduction_rate)
    else:
        raise NotImplementedError(f"Unknown attention mode: {attention_mode}")
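
As a closing usage sketch (shapes and the import path are illustrative assumptions, not taken from the diff): the model follows the usual braindecode EEGModuleMixin signature, so it can be constructed from n_chans/n_outputs/n_times and applied to a batch of EEG windows:

import torch

from braindecode.models import AttentionBaseNet  # assumed export in 1.1.0

# Hypothetical data: 8 trials, 22 channels, 1000 samples per window.
model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000, attention_mode="se")
x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
out = model(x)                # raw class scores, shape (8, 4)
print(out.shape)

With n_times=1000 the input comfortably exceeds the 180-sample minimum computed by _get_min_n_times, so no parameter rescaling warning is emitted.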