braindecode 0.8-py3-none-any.whl → 1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +50 -0
- braindecode/augmentation/base.py +222 -0
- braindecode/augmentation/functional.py +1096 -0
- braindecode/augmentation/transforms.py +1274 -0
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +34 -0
- braindecode/datasets/base.py +840 -0
- braindecode/datasets/bbci.py +694 -0
- braindecode/datasets/bcicomp.py +194 -0
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +172 -0
- braindecode/datasets/moabb.py +209 -0
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +125 -0
- braindecode/datasets/tuh.py +588 -0
- braindecode/datasets/xy.py +95 -0
- braindecode/datautil/__init__.py +49 -0
- braindecode/datautil/serialization.py +342 -0
- braindecode/datautil/util.py +41 -0
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +52 -0
- braindecode/models/atcnet.py +652 -0
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +296 -0
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +322 -0
- braindecode/models/deepsleepnet.py +295 -0
- braindecode/models/eegconformer.py +372 -0
- braindecode/models/eeginception_erp.py +304 -0
- braindecode/models/eeginception_mi.py +371 -0
- braindecode/models/eegitnet.py +301 -0
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +473 -0
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +362 -0
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +126 -0
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +208 -0
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +167 -0
- braindecode/models/sleep_stager_chambon_2018.py +157 -0
- braindecode/models/sleep_stager_eldele_2021.py +536 -0
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +273 -0
- braindecode/models/tidnet.py +395 -0
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +340 -0
- braindecode/models/util.py +133 -0
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +37 -0
- braindecode/preprocessing/mne_preprocess.py +77 -0
- braindecode/preprocessing/preprocess.py +478 -0
- braindecode/preprocessing/windowers.py +1031 -0
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +18 -0
- braindecode/samplers/base.py +401 -0
- braindecode/samplers/ssl.py +263 -0
- braindecode/training/__init__.py +23 -0
- braindecode/training/callbacks.py +23 -0
- braindecode/training/losses.py +105 -0
- braindecode/training/scoring.py +483 -0
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +8 -0
- braindecode/visualization/confusion_matrices.py +289 -0
- braindecode/visualization/gradients.py +57 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode-0.8.dist-info/RECORD +0 -11
- {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
braindecode/models/atcnet.py

@@ -0,0 +1,652 @@
+# Authors: Cedric Rommel <cedric.rommel@inria.fr>
+#
+# License: BSD (3-clause)
+import math
+
+import torch
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
+
+
+class ATCNet(EEGModuleMixin, nn.Module):
+    """ATCNet model from Altaheri et al. (2022) [1]_
+
+    PyTorch implementation based on the official TensorFlow code [2]_.
+
+    .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
+       :align: center
+       :alt: ATCNet Architecture
+
+    Parameters
+    ----------
+    input_window_seconds : float, optional
+        Time length of inputs, in seconds. Defaults to 4.5 s, as in the
+        BCI-IV 2a dataset.
+    sfreq : int, optional
+        Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in
+        the BCI-IV 2a dataset.
+    conv_block_n_filters : int
+        Number of temporal filters in the first convolutional layer of the
+        convolutional block, denoted F1 in figure 2 of the paper [1]_. Defaults
+        to 16 as in [1]_.
+    conv_block_kernel_length_1 : int
+        Length of temporal filters in the first convolutional layer of the
+        convolutional block, denoted Kc in table 1 of the paper [1]_. Defaults
+        to 64 as in [1]_.
+    conv_block_kernel_length_2 : int
+        Length of temporal filters in the last convolutional layer of the
+        convolutional block. Defaults to 16 as in [1]_.
+    conv_block_pool_size_1 : int
+        Length of the first average pooling kernel in the convolutional block.
+        Defaults to 8 as in [1]_.
+    conv_block_pool_size_2 : int
+        Length of the second average pooling kernel in the convolutional block,
+        denoted P2 in table 1 of the paper [1]_. Defaults to 7 as in [1]_.
+    conv_block_depth_mult : int
+        Depth multiplier of the depthwise convolution in the convolutional block,
+        denoted D in table 1 of the paper [1]_. Defaults to 2 as in [1]_.
+    conv_block_dropout : float
+        Dropout probability used in the convolutional block, denoted pc in
+        table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
+    n_windows : int
+        Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
+    att_head_dim : int
+        Embedding dimension used in each self-attention head, denoted dh in
+        table 1 of the paper [1]_. Defaults to 8 as in [1]_.
+    att_num_heads : int
+        Number of attention heads, denoted H in table 1 of the paper [1]_.
+        Defaults to 2 as in [1]_.
+    att_drop_prob : float
+        Dropout probability used in the attention block, denoted pa in table 1
+        of the paper [1]_. Defaults to 0.5 as in [1]_.
+    tcn_depth : int
+        Depth of the Temporal Convolutional Network block (i.e. number of TCN
+        residual blocks), denoted L in table 1 of the paper [1]_. Defaults to 2
+        as in [1]_.
+    tcn_kernel_size : int
+        Temporal kernel size used in the TCN block, denoted Kt in table 1 of the
+        paper [1]_. Defaults to 4 as in [1]_.
+    tcn_n_filters : int
+        Number of filters used in the TCN convolutional layers (Ft). Defaults to
+        32 as in [1]_.
+    tcn_drop_prob : float
+        Dropout probability used in the TCN block, denoted pt in table 1
+        of the paper [1]_. Defaults to 0.3 as in [1]_.
+    tcn_activation : torch.nn.Module
+        Nonlinear activation to use. Defaults to nn.ELU().
+    concat : bool
+        When ``True``, concatenates each sliding window embedding before
+        feeding it to a fully-connected layer, as done in [1]_. When ``False``,
+        maps each sliding window to `n_outputs` logits and averages them.
+        Defaults to ``False`` contrary to what is reported in [1]_, but
+        matching what the official code does [2]_.
+    max_norm_const : float
+        Maximum L2-norm constraint imposed on weights of the last
+        fully-connected layer. Defaults to 0.25.
+
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] EEG-ATCNet implementation.
+       https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
+    """
+
+    def __init__(
+        self,
+        n_chans=None,
+        n_outputs=None,
+        input_window_seconds=None,
+        sfreq=250.0,
+        conv_block_n_filters=16,
+        conv_block_kernel_length_1=64,
+        conv_block_kernel_length_2=16,
+        conv_block_pool_size_1=8,
+        conv_block_pool_size_2=7,
+        conv_block_depth_mult=2,
+        conv_block_dropout=0.3,
+        n_windows=5,
+        att_head_dim=8,
+        att_num_heads=2,
+        att_drop_prob=0.5,
+        tcn_depth=2,
+        tcn_kernel_size=4,
+        tcn_n_filters=32,
+        tcn_drop_prob=0.3,
+        tcn_activation: nn.Module = nn.ELU,
+        concat=False,
+        max_norm_const=0.25,
+        chs_info=None,
+        n_times=None,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        self.conv_block_n_filters = conv_block_n_filters
+        self.conv_block_kernel_length_1 = conv_block_kernel_length_1
+        self.conv_block_kernel_length_2 = conv_block_kernel_length_2
+        self.conv_block_pool_size_1 = conv_block_pool_size_1
+        self.conv_block_pool_size_2 = conv_block_pool_size_2
+        self.conv_block_depth_mult = conv_block_depth_mult
+        self.conv_block_dropout = conv_block_dropout
+        self.n_windows = n_windows
+        self.att_head_dim = att_head_dim
+        self.att_num_heads = att_num_heads
+        self.att_dropout = att_drop_prob
+        self.tcn_depth = tcn_depth
+        self.tcn_kernel_size = tcn_kernel_size
+        self.tcn_n_filters = tcn_n_filters
+        self.tcn_dropout = tcn_drop_prob
+        self.tcn_activation = tcn_activation
+        self.concat = concat
+        self.max_norm_const = max_norm_const
+
+        map = dict()
+        for w in range(self.n_windows):
+            map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
+            map[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
+        self.mapping = map
+
+        # Check later if we want to keep the Ensure4d. Not sure if we can
+        # remove it or replace it with einsum.
+        self.ensuredims = Ensure4d()
+        self.dimshuffle = Rearrange("batch C T 1 -> batch 1 T C")
+
+        self.conv_block = _ConvBlock(
+            n_channels=self.n_chans,  # input shape: (batch_size, 1, T, C)
+            n_filters=conv_block_n_filters,
+            kernel_length_1=conv_block_kernel_length_1,
+            kernel_length_2=conv_block_kernel_length_2,
+            pool_size_1=conv_block_pool_size_1,
+            pool_size_2=conv_block_pool_size_2,
+            depth_mult=conv_block_depth_mult,
+            dropout=conv_block_dropout,
+        )
+
+        self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
+        self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
+        self.Tw = self.Tc - self.n_windows + 1
+
+        self.attention_blocks = nn.ModuleList(
+            [
+                _AttentionBlock(
+                    in_shape=self.F2,
+                    head_dim=self.att_head_dim,
+                    num_heads=att_num_heads,
+                    dropout=att_drop_prob,
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        self.temporal_conv_nets = nn.ModuleList(
+            [
+                nn.Sequential(
+                    *[
+                        _TCNResidualBlock(
+                            in_channels=self.F2,
+                            kernel_size=tcn_kernel_size,
+                            n_filters=tcn_n_filters,
+                            dropout=tcn_drop_prob,
+                            activation=tcn_activation,
+                            dilation=2**i,
+                        )
+                        for i in range(tcn_depth)
+                    ]
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        if self.concat:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2 * self.n_windows,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                ]
+            )
+        else:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                    for _ in range(self.n_windows)
+                ]
+            )
+
+        self.out_fun = nn.Identity()
+
+    def forward(self, X):
+        # Dimension: (batch_size, C, T)
+        X = self.ensuredims(X)
+        # Dimension: (batch_size, C, T, 1)
+        X = self.dimshuffle(X)
+        # Dimension: (batch_size, 1, T, C)
+
+        # ----- Convolutional block -----
+        conv_feat = self.conv_block(X)
+        # Dimension: (batch_size, F2, Tc, 1)
+        conv_feat = conv_feat.view(-1, self.F2, self.Tc)
+        # Dimension: (batch_size, F2, Tc)
+
+        # ----- Sliding window -----
+        sw_concat: list[torch.Tensor] = []  # to store sliding window outputs
+        # for w in range(self.n_windows):
+        for idx, (attention, tcn_module, final_layer) in enumerate(
+            zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+        ):
+            conv_feat_w = conv_feat[..., idx : idx + self.Tw]
+            # Dimension: (batch_size, F2, Tw)
+
+            # ----- Attention block -----
+            att_feat = attention(conv_feat_w)
+            # Dimension: (batch_size, F2, Tw)
+
+            # ----- Temporal convolutional network (TCN) -----
+            tcn_feat = tcn_module(att_feat)[..., -1]
+            # Dimension: (batch_size, F2)
+
+            # Outputs of sliding window can be either averaged after being
+            # mapped by dense layer or concatenated then mapped by a dense
+            # layer
+            if not self.concat:
+                tcn_feat = final_layer(tcn_feat)
+
+            sw_concat.append(tcn_feat)
+
+        # ----- Aggregation and prediction -----
+        if self.concat:
+            sw_concat_agg = torch.cat(sw_concat, dim=1)
+            sw_concat_agg = self.final_layer[0](sw_concat_agg)
+        else:
+            if len(sw_concat) > 1:  # more than one window
+                sw_concat_agg = torch.stack(sw_concat, dim=0)
+                sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
+            else:  # one window (# windows = 1)
+                sw_concat_agg = sw_concat[0]
+
+        return self.out_fun(sw_concat_agg)
+
+
+class _ConvBlock(nn.Module):
+    """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
+    architecture [2]_.
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
+       S. M., Hung, C. P., & Lance, B. J. (2018).
+       EEGNet: A Compact Convolutional Network for EEG-based
+       Brain-Computer Interfaces.
+       arXiv preprint arXiv:1611.08024.
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        n_filters=16,
+        kernel_length_1=64,
+        kernel_length_2=16,
+        pool_size_1=8,
+        pool_size_2=7,
+        depth_mult=2,
+        dropout=0.3,
+    ):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=1,
+            out_channels=n_filters,
+            kernel_size=(kernel_length_1, 1),
+            padding="same",
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(num_features=n_filters, eps=1e-4)
+
+        n_depth_kernels = n_filters * depth_mult
+        self.conv2 = nn.Conv2d(
+            in_channels=n_filters,
+            out_channels=n_depth_kernels,
+            groups=n_filters,
+            kernel_size=(1, n_channels),
+            padding="valid",
+            bias=False,
+        )
+
+        self.bn2 = nn.BatchNorm2d(num_features=n_depth_kernels, eps=1e-4)
+
+        self.activation2 = nn.ELU()
+
+        self.pool2 = nn.AvgPool2d(kernel_size=(pool_size_1, 1))
+
+        self.drop2 = nn.Dropout2d(dropout)
+
+        self.conv3 = nn.Conv2d(
+            in_channels=n_depth_kernels,
+            out_channels=n_depth_kernels,
+            kernel_size=(kernel_length_2, 1),
+            padding="same",
+            bias=False,
+        )
+
+        self.bn3 = nn.BatchNorm2d(num_features=n_depth_kernels, eps=1e-4)
+
+        self.activation3 = nn.ELU()
+
+        self.pool3 = nn.AvgPool2d(kernel_size=(pool_size_2, 1))
+
+        self.drop3 = nn.Dropout2d(dropout)
+
+    def forward(self, X):
+        # ----- Temporal convolution -----
+        # Dimension: (batch_size, 1, T, C)
+        X = self.conv1(X)
+        X = self.bn1(X)
+        # Dimension: (batch_size, F1, T, C)
+
+        # ----- Depthwise channels convolution -----
+        X = self.conv2(X)
+        X = self.bn2(X)
+        X = self.activation2(X)
+        # Dimension: (batch_size, F1*D, T, 1)
+        X = self.pool2(X)
+        X = self.drop2(X)
+        # Dimension: (batch_size, F1*D, T/P1, 1)
+
+        # ----- "Spatial" convolution -----
+        X = self.conv3(X)
+        X = self.bn3(X)
+        X = self.activation3(X)
+        # Dimension: (batch_size, F1*D, T/P1, 1)
+        X = self.pool3(X)
+        X = self.drop3(X)
+        # Dimension: (batch_size, F1*D, T/(P1*P2), 1)
+
+        return X
+
+
+class _AttentionBlock(nn.Module):
+    """Multi Head self Attention (MHA) block used in ATCNet [1]_, inspired by
+    [2]_.
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] Vaswani, A. et al., "Attention is all you need",
+       in Advances in neural information processing systems, 2017.
+    """
+
+    def __init__(
+        self,
+        in_shape=32,
+        head_dim=8,
+        num_heads=2,
+        dropout=0.5,
+    ):
+        super().__init__()
+        self.in_shape = in_shape
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+
+        # Puts time dimension at -2 and feature dim at -1
+        self.dimshuffle = Rearrange("batch C T -> batch T C")
+
+        # Layer normalization
+        self.ln = nn.LayerNorm(normalized_shape=in_shape, eps=1e-6)
+
+        # Multi-head self-attention layer
+        # (We had to reimplement it since the original code is in tensorflow,
+        # where it is possible to have an embedding dimension different than
+        # the input and output dimensions, which is not possible in pytorch.)
+        self.mha = _MHA(
+            input_dim=in_shape,
+            head_dim=head_dim,
+            output_dim=in_shape,
+            num_heads=num_heads,
+            dropout=dropout,
+        )
+
+        # XXX: This dropout in the official code is a bit odd, as there is
+        # already dropout inside the MultiheadAttention layer, and the paper
+        # does not mention any additional dropout between the attention block
+        # and the TCN. We nevertheless keep it here to match the official
+        # implementation.
+        self.drop = nn.Dropout(0.3)
+
+    def forward(self, X):
+        # Dimension: (batch_size, F2, Tw)
+        X = self.dimshuffle(X)
+        # Dimension: (batch_size, Tw, F2)
+
+        # ----- Layer norm -----
+        out = self.ln(X)
+
+        # ----- Self-Attention -----
+        out = self.mha(out, out, out)
+        # Dimension: (batch_size, Tw, F2)
+
+        # XXX In the paper fig. 1, it is drawn that layer normalization is
+        # performed before the skip connection, while it is done afterwards
+        # in the official code. Here we follow the code.
+
+        # ----- Skip connection -----
+        out = X + self.drop(out)
+
+        # Move back to shape (batch_size, F2, Tw) from the beginning
+        return self.dimshuffle(out)
+
+
+class _TCNResidualBlock(nn.Module):
+    """Modified TCN residual block as proposed in [1]_, inspired by
+    Temporal Convolutional Networks (TCN) [2]_.
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] Bai, S., Kolter, J. Z., & Koltun, V.
+       "An empirical evaluation of generic convolutional and recurrent
+       networks for sequence modeling", 2018.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        kernel_size=4,
+        n_filters=32,
+        dropout=0.3,
+        activation: nn.Module = nn.ELU,
+        dilation=1,
+    ):
+        super().__init__()
+        self.activation = activation()
+        self.dilation = dilation
+        self.dropout = dropout
+        self.n_filters = n_filters
+        self.kernel_size = kernel_size
+        self.in_channels = in_channels
+
+        self.conv1 = CausalConv1d(
+            in_channels=in_channels,
+            out_channels=n_filters,
+            kernel_size=kernel_size,
+            dilation=dilation,
+        )
+        nn.init.kaiming_uniform_(self.conv1.weight)
+
+        self.bn1 = nn.BatchNorm1d(n_filters)
+
+        self.drop1 = nn.Dropout(dropout)
+
+        self.conv2 = CausalConv1d(
+            in_channels=n_filters,
+            out_channels=n_filters,
+            kernel_size=kernel_size,
+            dilation=dilation,
+        )
+        nn.init.kaiming_uniform_(self.conv2.weight)
+
+        self.bn2 = nn.BatchNorm1d(n_filters)
+
+        self.drop2 = nn.Dropout(dropout)
+
+        # Reshape the input for the residual connection when necessary
+        if in_channels != n_filters:
+            self.reshaping_conv = nn.Conv1d(
+                n_filters,
+                kernel_size=1,
+                padding="same",
+            )
+        else:
+            self.reshaping_conv = nn.Identity()
+
+    def forward(self, X):
+        # Dimension: (batch_size, F2, Tw)
+        # ----- Double dilated convolutions -----
+        out = self.conv1(X)
+        out = self.bn1(out)
+        out = self.activation(out)
+        out = self.drop1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.activation(out)
+        out = self.drop2(out)
+
+        out = self.reshaping_conv(out)
+
+        # ----- Residual connection -----
+        out = X + out
+
+        return self.activation(out)
+
+
+class _MHA(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        head_dim: int,
+        output_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ):
+        """Multi-head Attention
+
+        The difference between this module and torch.nn.MultiheadAttention is
+        that this module supports embedding dimensions different from the input
+        and output ones. It also does not support sequences of different
+        lengths.
+
+        Parameters
+        ----------
+        input_dim : int
+            Dimension of query, key and value inputs.
+        head_dim : int
+            Dimension of the embedded query, key and value in each head,
+            before computing attention.
+        output_dim : int
+            Output dimension.
+        num_heads : int
+            Number of heads in the multi-head architecture.
+        dropout : float, optional
+            Dropout probability on output weights. Default: 0.0 (no dropout).
+        """
+
+        super(_MHA, self).__init__()
+
+        self.input_dim = input_dim
+        self.head_dim = head_dim
+        # typical choice for the split dimension of the heads
+        self.embed_dim = head_dim * num_heads
+
+        # embeddings for multi-head projections
+        self.fc_q = nn.Linear(input_dim, self.embed_dim)
+        self.fc_k = nn.Linear(input_dim, self.embed_dim)
+        self.fc_v = nn.Linear(input_dim, self.embed_dim)
+
+        # output mapping
+        self.fc_o = nn.Linear(self.embed_dim, output_dim)
+
+        # dropout
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
+    ) -> torch.Tensor:
+        """Compute MHA(Q, K, V)
+
+        Parameters
+        ----------
+        Q: torch.Tensor of size (batch_size, seq_len, input_dim)
+            Input query (Q) sequence.
+        K: torch.Tensor of size (batch_size, seq_len, input_dim)
+            Input key (K) sequence.
+        V: torch.Tensor of size (batch_size, seq_len, input_dim)
+            Input value (V) sequence.
+
+        Returns
+        -------
+        O: torch.Tensor of size (batch_size, seq_len, output_dim)
+            Output MHA(Q, K, V)
+        """
+        assert Q.shape[-1] == K.shape[-1] == V.shape[-1] == self.input_dim
+
+        batch_size, _, _ = Q.shape
+
+        # embedding for multi-head projections (masked or not)
+        Q = self.fc_q(Q)  # (B, S, D)
+        K, V = self.fc_k(K), self.fc_v(V)  # (B, S, D)
+
+        # Split into num_heads vectors (num_heads * batch_size, n/m, head_dim)
+        Q_ = torch.cat(Q.split(self.head_dim, -1), 0)  # (B', S, D')
+        K_ = torch.cat(K.split(self.head_dim, -1), 0)  # (B', S, D')
+        V_ = torch.cat(V.split(self.head_dim, -1), 0)  # (B', S, D')
+
+        # Attention weights of size (num_heads * batch_size, n, m):
+        # measures how similar each pair of Q and K is.
+        W = torch.softmax(
+            Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
+            -1,  # (B', D', S)
+        )  # (B', N, M)
+
+        # Multihead output (batch_size, seq_len, dim):
+        # weighted sum of V where a value gets more weight if its corresponding
+        # key has larger dot product with the query.
+        H = torch.cat(
+            (W.bmm(V_)).split(  # (B', S, S)  # (B', S, D')
+                batch_size, 0
+            ),  # [(B, S, D')] * num_heads
+            -1,
+        )  # (B, S, D)
+
+        out = self.fc_o(H)
+
+        return self.dropout(out)
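For orientation, here is a minimal usage sketch of the ATCNet class added by this hunk. It is based only on the constructor signature and docstring shown above; the channel count, class count, batch size, and random input are illustrative assumptions, not values taken from the diff.

    import torch

    from braindecode.models import ATCNet

    # Assumed example setup: 22 EEG channels, 4 classes, 4.5 s windows at
    # 250 Hz (the BCI-IV 2a defaults mentioned in the docstring above).
    model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)

    # Dummy batch of shape (batch_size, n_chans, n_times),
    # with n_times = input_window_seconds * sfreq = 1125.
    X = torch.randn(8, 22, 1125)
    with torch.no_grad():
        logits = model(X)
    print(logits.shape)  # torch.Size([8, 4])

With the default concat=False, each of the n_windows sliding windows is mapped to n_outputs logits and the results are averaged, so the output keeps one row of logits per trial.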