braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,550 @@
from __future__ import annotations

import math

from einops.layers.torch import Rearrange
from torch import nn

from braindecode.models.base import EEGModuleMixin
from braindecode.modules import Ensure4d
from braindecode.modules.attention import (
    CAT,
    CBAM,
    ECA,
    FCA,
    GCT,
    SRM,
    CATLite,
    EncNet,
    GatherExcite,
    GSoP,
    SqueezeAndExcitation,
)


class AttentionBaseNet(EEGModuleMixin, nn.Module):
    """AttentionBaseNet from Wimpff M et al. (2023) [Martin2023]_.

    .. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
       :align: center
       :alt: Attention Base Net

    Neural Network from the paper: EEG motor imagery decoding:
    A framework for comparative analysis with channel attention
    mechanisms

    The paper and original code with more details about the methodological
    choices are available at the [Martin2023]_ and [MartinCode]_.

    The AttentionBaseNet architecture is composed of four modules:
    - Input Block that performs a temporal convolution and a spatial
      convolution.
    - Channel Expansion that modifies the number of channels.
    - An attention block that performs channel attention with several
      options
    - ClassificationHead

    .. versionadded:: 0.9

    Parameters
    ----------
    n_temporal_filters : int, optional
        Number of temporal convolutional filters in the first layer. This defines
        the number of output channels after the temporal convolution.
        Default is 40.
    temp_filter_length : int, default=15
        The length of the temporal filters in the convolutional layers.
    spatial_expansion : int, optional
        Multiplicative factor to expand the spatial dimensions. Used to increase
        the capacity of the model by expanding spatial features. Default is 1.
    pool_length_inp : int, optional
        Length of the pooling window in the input layer. Determines how much
        temporal information is aggregated during pooling. Default is 75.
    pool_stride_inp : int, optional
        Stride of the pooling operation in the input layer. Controls the
        downsampling factor in the temporal dimension. Default is 15.
    drop_prob_inp : float, optional
        Dropout rate applied after the input layer. This is the probability of
        zeroing out elements during training to prevent overfitting.
        Default is 0.5.
    ch_dim : int, optional
        Number of channels in the subsequent convolutional layers. This controls
        the depth of the network after the initial layer. Default is 16.
    attention_mode : str, optional
        The type of attention mechanism to apply. If `None`, no attention is applied.
        - "se" for Squeeze-and-excitation network
        - "gsop" for Global Second-Order Pooling
        - "fca" for Frequency Channel Attention Network
        - "encnet" for context encoding module
        - "eca" for Efficient channel attention for deep convolutional neural networks
        - "ge" for Gather-Excite
        - "gct" for Gated Channel Transformation
        - "srm" for Style-based Recalibration Module
        - "cbam" for Convolutional Block Attention Module
        - "cat" for Learning to collaborate channel and temporal attention
          from multi-information fusion
        - "catlite" for Learning to collaborate channel attention
          from multi-information fusion (lite version, cat w/o temporal attention)
    pool_length : int, default=8
        The length of the window for the average pooling operation.
    pool_stride : int, default=8
        The stride of the average pooling operation.
    drop_prob_attn : float, default=0.5
        The dropout rate for regularization for the attention layer. Values should be between 0 and 1.
    reduction_rate : int, default=4
        The reduction rate used in the attention mechanism to reduce dimensionality
        and computational complexity.
    use_mlp : bool, default=False
        Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
        the attention mechanism for further processing.
    freq_idx : int, default=0
        DCT index used in fca attention mechanism.
    n_codewords : int, default=4
        The number of codewords (clusters) used in attention mechanisms that employ
        quantization or clustering strategies.
    kernel_size : int, default=9
        The kernel size used in certain types of attention mechanisms for convolution
        operations.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    extra_params : bool, default=False
        Flag to indicate whether additional, custom parameters should be passed to
        the attention mechanism.

    References
    ----------
    .. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
        EEG motor imagery decoding: A framework for comparative analysis with
        channel attention mechanisms. arXiv preprint arXiv:2310.11198.
    .. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
        GitHub https://github.com/martinwimpff/channel-attention (accessed 2024-03-28)
    """

    def __init__(
        self,
        n_times=None,
        n_chans=None,
        n_outputs=None,
        chs_info=None,
        sfreq=None,
        input_window_seconds=None,
        # Module parameters
        n_temporal_filters: int = 40,
        temp_filter_length_inp: int = 25,
        spatial_expansion: int = 1,
        pool_length_inp: int = 75,
        pool_stride_inp: int = 15,
        drop_prob_inp: float = 0.5,
        ch_dim: int = 16,
        temp_filter_length: int = 15,
        pool_length: int = 8,
        pool_stride: int = 8,
        drop_prob_attn: float = 0.5,
        attention_mode: str | None = None,
        reduction_rate: int = 4,
        use_mlp: bool = False,
        freq_idx: int = 0,
        n_codewords: int = 4,
        kernel_size: int = 9,
        activation: nn.Module = nn.ELU,
        extra_params: bool = False,
    ):
        super(AttentionBaseNet, self).__init__()

        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            sfreq=sfreq,
            input_window_seconds=input_window_seconds,
        )
        del n_outputs, n_chans, chs_info, n_times, sfreq, input_window_seconds

        self.input_block = _FeatureExtractor(
            n_chans=self.n_chans,
            n_temporal_filters=n_temporal_filters,
            temporal_filter_length=temp_filter_length_inp,
            spatial_expansion=spatial_expansion,
            pool_length=pool_length_inp,
            pool_stride=pool_stride_inp,
            drop_prob=drop_prob_inp,
            activation=activation,
        )

        self.channel_expansion = nn.Sequential(
            nn.Conv2d(
                n_temporal_filters * spatial_expansion, ch_dim, (1, 1), bias=False
            ),
            nn.BatchNorm2d(ch_dim),
            activation(),
        )

        seq_lengths = self._calculate_sequence_lengths(
            self.n_times,
            [temp_filter_length_inp, temp_filter_length],
            [pool_length_inp, pool_length],
            [pool_stride_inp, pool_stride],
        )

        self.channel_attention_block = _ChannelAttentionBlock(
            attention_mode=attention_mode,
            in_channels=ch_dim,
            temp_filter_length=temp_filter_length,
            pool_length=pool_length,
            pool_stride=pool_stride,
            drop_prob=drop_prob_attn,
            reduction_rate=reduction_rate,
            use_mlp=use_mlp,
            seq_len=seq_lengths[0],
            freq_idx=freq_idx,
            n_codewords=n_codewords,
            kernel_size=kernel_size,
            extra_params=extra_params,
            activation=activation,
        )

        self.final_layer = nn.Sequential(
            nn.Flatten(), nn.Linear(seq_lengths[-1] * ch_dim, self.n_outputs)
        )

    def forward(self, x):
        x = self.input_block(x)
        x = self.channel_expansion(x)
        x = self.channel_attention_block(x)
        x = self.final_layer(x)
        return x

    @staticmethod
    def _calculate_sequence_lengths(
        input_window_samples: int,
        kernel_lengths: list,
        pool_lengths: list,
        pool_strides: list,
    ):
        seq_lengths = []
        out = input_window_samples
        for k, pl, ps in zip(kernel_lengths, pool_lengths, pool_strides):
            out = math.floor(out + 2 * (k // 2) - k + 1)
            out = math.floor((out - pl) / ps + 1)
            seq_lengths.append(int(out))
        return seq_lengths


class _FeatureExtractor(nn.Module):
    """
    A module for feature extraction of the data with temporal and spatial
    transformations.

    This module sequentially processes the input through a series of layers:
    rearrangement, temporal convolution, batch normalization, spatial convolution,
    another batch normalization, an ELU non-linearity, average pooling, and dropout.


    Parameters
    ----------
    n_chans : int
        The number of channels in the input data.
    n_temporal_filters : int, optional
        The number of filters to use in the temporal convolution layer. Default is 40.
    temporal_filter_length : int, optional
        The size of each filter in the temporal convolution layer. Default is 25.
    spatial_expansion : int, optional
        The expansion factor of the spatial convolution layer, determining the number
        of output channels relative to the number of temporal filters. Default is 1.
    pool_length : int, optional
        The size of the window for the average pooling operation. Default is 75.
    pool_stride : int, optional
        The stride of the average pooling operation. Default is 15.
    drop_prob : float, optional
        The dropout rate for regularization. Default is 0.5.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
    """

    def __init__(
        self,
        n_chans: int,
        n_temporal_filters: int = 40,
        temporal_filter_length: int = 25,
        spatial_expansion: int = 1,
        pool_length: int = 75,
        pool_stride: int = 15,
        drop_prob: float = 0.5,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()

        self.ensure4d = Ensure4d()
        self.rearrange_input = Rearrange("b c t 1 -> b 1 c t")
        self.temporal_conv = nn.Conv2d(
            1,
            n_temporal_filters,
            kernel_size=(1, temporal_filter_length),
            padding=(0, temporal_filter_length // 2),
            bias=False,
        )
        self.intermediate_bn = nn.BatchNorm2d(n_temporal_filters)
        self.spatial_conv = nn.Conv2d(
            n_temporal_filters,
            n_temporal_filters * spatial_expansion,
            kernel_size=(n_chans, 1),
            groups=n_temporal_filters,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(n_temporal_filters * spatial_expansion)
        self.nonlinearity = activation()
        self.pool = nn.AvgPool2d((1, pool_length), stride=(1, pool_stride))
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        x = self.ensure4d(x)
        x = self.rearrange_input(x)
        x = self.temporal_conv(x)
        x = self.intermediate_bn(x)
        x = self.spatial_conv(x)
        x = self.bn(x)
        x = self.nonlinearity(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class _ChannelAttentionBlock(nn.Module):
    """
    A neural network module implementing channel-wise attention mechanisms to enhance
    feature representations by selectively emphasizing important channels and suppressing
    less useful ones. This block integrates convolutional layers, pooling, dropout, and
    an optional attention mechanism that can be customized based on the given mode.

    Parameters
    ----------
    attention_mode : str, optional
        The type of attention mechanism to apply. If `None`, no attention is applied.
        - "se" for Squeeze-and-excitation network
        - "gsop" for Global Second-Order Pooling
        - "fca" for Frequency Channel Attention Network
        - "encnet" for context encoding module
        - "eca" for Efficient channel attention for deep convolutional neural networks
        - "ge" for Gather-Excite
        - "gct" for Gated Channel Transformation
        - "srm" for Style-based Recalibration Module
        - "cbam" for Convolutional Block Attention Module
        - "cat" for Learning to collaborate channel and temporal attention
          from multi-information fusion
        - "catlite" for Learning to collaborate channel attention
          from multi-information fusion (lite version, cat w/o temporal attention)

    in_channels : int, default=16
        The number of input channels to the block.
    temp_filter_length : int, default=15
        The length of the temporal filters in the convolutional layers.
    pool_length : int, default=8
        The length of the window for the average pooling operation.
    pool_stride : int, default=8
        The stride of the average pooling operation.
    drop_prob : float, default=0.5
        The dropout rate for regularization. Values should be between 0 and 1.
    reduction_rate : int, default=4
        The reduction rate used in the attention mechanism to reduce dimensionality
        and computational complexity.
    use_mlp : bool, default=False
        Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
        the attention mechanism for further processing.
    seq_len : int, default=62
        The sequence length, used in certain types of attention mechanisms to process
        temporal dimensions.
    freq_idx : int, default=0
        DCT index used in fca attention mechanism.
    n_codewords : int, default=4
        The number of codewords (clusters) used in attention mechanisms that employ
        quantization or clustering strategies.
    kernel_size : int, default=9
        The kernel size used in certain types of attention mechanisms for convolution
        operations.
    extra_params : bool, default=False
        Flag to indicate whether additional, custom parameters should be passed to
        the attention mechanism.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Attributes
    ----------
    conv : torch.nn.Sequential
        Sequential model of convolutional layers, batch normalization, and ELU
        activation, designed to process input features.
    pool : torch.nn.AvgPool2d
        Average pooling layer to reduce the dimensionality of the feature maps.
    dropout : torch.nn.Dropout
        Dropout layer for regularization.
    attention_block : torch.nn.Module or None
        The attention mechanism applied to the output of the convolutional layers,
        if `attention_mode` is not None. Otherwise, it's set to None.
    activation: nn.Module, default=nn.ELU
        Activation function class to apply. Should be a PyTorch activation
        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.

    Examples
    --------
    >>> channel_attention_block = _ChannelAttentionBlock(attention_mode='cbam', in_channels=16, reduction_rate=4, kernel_size=7)
    >>> x = torch.randn(1, 16, 64, 64)  # Example input tensor
    >>> output = channel_attention_block(x)
    The output tensor then can be further processed or used as input to another block.

    """

    def __init__(
        self,
        attention_mode: str | None = None,
        in_channels: int = 16,
        temp_filter_length: int = 15,
        pool_length: int = 8,
        pool_stride: int = 8,
        drop_prob: float = 0.5,
        reduction_rate: int = 4,
        use_mlp: bool = False,
        seq_len: int = 62,
        freq_idx: int = 0,
        n_codewords: int = 4,
        kernel_size: int = 9,
        extra_params: bool = False,
        activation: nn.Module = nn.ELU,
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                (1, temp_filter_length),
                padding=(0, temp_filter_length // 2),
                bias=False,
                groups=in_channels,
            ),
            nn.Conv2d(in_channels, in_channels, (1, 1), bias=False),
            nn.BatchNorm2d(in_channels),
            activation(),
        )

        self.pool = nn.AvgPool2d((1, pool_length), stride=(1, pool_stride))
        self.dropout = nn.Dropout(drop_prob)

        if attention_mode is not None:
            self.attention_block = get_attention_block(
                attention_mode,
                ch_dim=in_channels,
                reduction_rate=reduction_rate,
                use_mlp=use_mlp,
                seq_len=seq_len,
                freq_idx=freq_idx,
                n_codewords=n_codewords,
                kernel_size=kernel_size,
                extra_params=extra_params,
            )
        else:
            self.attention_block = None

    def forward(self, x):
        out = self.conv(x)
        if self.attention_block is not None:
            out = self.attention_block(out)
        out = self.pool(out)
        out = self.dropout(out)
        return out


def get_attention_block(
    attention_mode: str,
    ch_dim: int = 16,
    reduction_rate: int = 4,
    use_mlp: bool = False,
    seq_len: int | None = None,
    freq_idx: int = 0,
    n_codewords: int = 4,
    kernel_size: int = 9,
    extra_params: bool = False,
):
    """
    Util function to the attention block based on the attention mode.

    Parameters
    ----------
    attention_mode: str
        The type of attention mechanism to apply.
    ch_dim: int
        The number of input channels to the block.
    reduction_rate: int
        The reduction rate used in the attention mechanism to reduce
        dimensionality and computational complexity.
        Used in all the methods, except for the
        encnet and eca.
    use_mlp: bool
        Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used
        within the attention mechanism for further processing. Used in the ge
        and srm attention mechanism.
    seq_len: int
        The sequence length, used in certain types of attention mechanisms to
        process temporal dimensions. Used in the ge or fca attention mechanism.
    freq_idx: int
        DCT index used in fca attention mechanism.
    n_codewords: int
        The number of codewords (clusters) used in attention mechanisms
        that employ quantization or clustering strategies, encnet.
    kernel_size: int
        The kernel size used in certain types of attention mechanisms for convolution
        operations, used in the cbam, eca, and cat attention mechanisms.
    extra_params: bool
        Parameter to pass additional parameters to the GatherExcite mechanism.

    Returns
    -------
    nn.Module
        The attention block based on the attention mode.
    """
    if attention_mode == "se":
        return SqueezeAndExcitation(in_channels=ch_dim, reduction_rate=reduction_rate)
    # improving the squeeze module
    elif attention_mode == "gsop":
        return GSoP(in_channels=ch_dim, reduction_rate=reduction_rate)
    elif attention_mode == "fca":
        assert seq_len is not None
        return FCA(
            in_channels=ch_dim,
            seq_len=seq_len,
            reduction_rate=reduction_rate,
            freq_idx=freq_idx,
        )
    elif attention_mode == "encnet":
        return EncNet(in_channels=ch_dim, n_codewords=n_codewords)
    # improving the excitation module
    elif attention_mode == "eca":
        return ECA(in_channels=ch_dim, kernel_size=kernel_size)
    # improving the squeeze and the excitation module
    elif attention_mode == "ge":
        assert seq_len is not None
        return GatherExcite(
            in_channels=ch_dim,
            seq_len=seq_len,
            extra_params=extra_params,
            use_mlp=use_mlp,
            reduction_rate=reduction_rate,
        )
    elif attention_mode == "gct":
        return GCT(in_channels=ch_dim)
    elif attention_mode == "srm":
        return SRM(in_channels=ch_dim, use_mlp=use_mlp, reduction_rate=reduction_rate)
    # temporal and channel attention
    elif attention_mode == "cbam":
        return CBAM(
            in_channels=ch_dim, reduction_rate=reduction_rate, kernel_size=kernel_size
        )
    elif attention_mode == "cat":
        return CAT(
            in_channels=ch_dim, reduction_rate=reduction_rate, kernel_size=kernel_size
        )
    elif attention_mode == "catlite":
        return CATLite(ch_dim, reduction_rate=reduction_rate)
    else:
        raise NotImplementedError
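For orientation, the snippet below is a minimal usage sketch of the model added by this hunk; it is not part of the release itself. The import path follows the new file braindecode/models/attentionbasenet.py, while the channel count, window length, number of classes, batch size, and attention_mode value are illustrative assumptions chosen only to exercise the constructor signature shown above.

import torch

from braindecode.models.attentionbasenet import AttentionBaseNet

# Hypothetical shapes: 22 EEG channels, 1000 samples per window, 4 classes.
model = AttentionBaseNet(
    n_chans=22,           # assumed montage size
    n_times=1000,         # assumed window length in samples
    n_outputs=4,          # assumed number of classes
    attention_mode="se",  # one of the modes listed in the docstring, or None
)

x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
y = model(x)                  # class scores with shape (8, 4)

With attention_mode=None the _ChannelAttentionBlock skips the attention step entirely; passing "eca", "cbam", "cat", etc. instead selects the corresponding block from braindecode.modules.attention via get_attention_block.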