braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +326 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +34 -18
- braindecode/datautil/serialization.py +98 -71
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +10 -0
- braindecode/functional/functions.py +251 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +36 -14
- braindecode/models/atcnet.py +153 -159
- braindecode/models/attentionbasenet.py +550 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +483 -0
- braindecode/models/contrawr.py +296 -0
- braindecode/models/ctnet.py +450 -0
- braindecode/models/deep4.py +64 -75
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +111 -171
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +155 -97
- braindecode/models/eegitnet.py +215 -151
- braindecode/models/eegminer.py +255 -0
- braindecode/models/eegnet.py +229 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +325 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1166 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +182 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1012 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +248 -141
- braindecode/models/sparcnet.py +378 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +258 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -141
- braindecode/modules/__init__.py +38 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +632 -0
- braindecode/modules/layers.py +133 -0
- braindecode/modules/linear.py +50 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +77 -0
- braindecode/modules/wrapper.py +75 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +148 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
- braindecode-1.0.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
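The reusable building blocks that lived under braindecode/models in 0.8.1 (modules.py and functions.py, both deleted above) are split out into the new braindecode.modules and braindecode.functional packages in 1.0.0. A minimal migration sketch for downstream imports; the 0.8.1 path is inferred from the deleted file, and the imported names are the ones atcnet.py pulls in below:

# 0.8.1 (braindecode/models/modules.py, deleted in this release):
#   from braindecode.models.modules import Ensure4d, MaxNormLinear
# 1.0.0 (new braindecode/modules package):
from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear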
braindecode/models/atcnet.py
CHANGED
@@ -1,21 +1,25 @@
 # Authors: Cedric Rommel <cedric.rommel@inria.fr>
 #
 # License: BSD (3-clause)
-import numpy as np
+import math
 
 import torch
-from torch import nn
 from einops.layers.torch import Rearrange
+from torch import nn
 
-from .base import EEGModuleMixin, deprecated_args
-from .modules import Ensure4d, CausalConv1d, MaxNormLinear
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
 
 
 class ATCNet(EEGModuleMixin, nn.Module):
-    """ATCNet model from [1]_
+    """ATCNet model from Altaheri et al. (2022) [1]_
 
     Pytorch implementation based on official tensorflow code [2]_.
 
+    .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
+       :align: center
+       :alt: ATCNet Architecture
+
     Parameters
     ----------
     input_window_seconds : float, optional
@@ -54,7 +58,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
         table 1 of the paper [1]_. Defaults to 8 as in [1]_.
     att_num_heads : int
         Number of attention heads, denoted H in table 1 of the paper [1]_.
-        Defaults to 2 as in [
+        Defaults to 2 as in [1]_.
     att_dropout : float
         Dropout probability used in the attention block, denoted pa in table 1
         of the paper [1]_. Defaults to 0.5 as in [1]_.
@@ -82,59 +86,45 @@ class ATCNet(EEGModuleMixin, nn.Module):
     max_norm_const : float
         Maximum L2-norm constraint imposed on weights of the last
         fully-connected layer. Defaults to 0.25.
-    n_channels:
-        Alias for n_chans.
-    n_classes:
-        Alias for n_outputs.
-    input_size_s:
-        Alias for input_window_seconds.
+
 
     References
     ----------
-    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
-        Physics-informed attention temporal convolutional network for EEG-based
-        motor imagery classification in IEEE Transactions on Industrial Informatics,
-        2022, doi: 10.1109/TII.2022.3197419.
-    .. [2] https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
+       Physics-informed attention temporal convolutional network for EEG-based
+       motor imagery classification in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] EEG-ATCNet implementation.
+       https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
     """
 
     def __init__(
-            self,
-            n_chans=None,
-            n_outputs=None,
-            input_window_seconds=None,
-            sfreq=250.,
-            conv_block_n_filters=16,
-            conv_block_kernel_length_1=64,
-            conv_block_kernel_length_2=16,
-            conv_block_pool_size_1=8,
-            conv_block_pool_size_2=7,
-            conv_block_depth_mult=2,
-            conv_block_dropout=0.3,
-            n_windows=5,
-            att_head_dim=8,
-            att_num_heads=2,
-            att_dropout=0.5,
-            tcn_depth=2,
-            tcn_kernel_size=4,
-            tcn_n_filters=32,
-            tcn_dropout=0.3,
-            tcn_activation=nn.ELU(),
-            concat=False,
-            max_norm_const=0.25,
-            chs_info=None,
-            n_times=None,
-            n_channels=None,
-            n_classes=None,
-            input_size_s=None,
-            add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        input_window_seconds=None,
+        sfreq=250.0,
+        conv_block_n_filters=16,
+        conv_block_kernel_length_1=64,
+        conv_block_kernel_length_2=16,
+        conv_block_pool_size_1=8,
+        conv_block_pool_size_2=7,
+        conv_block_depth_mult=2,
+        conv_block_dropout=0.3,
+        n_windows=5,
+        att_head_dim=8,
+        att_num_heads=2,
+        att_drop_prob=0.5,
+        tcn_depth=2,
+        tcn_kernel_size=4,
+        tcn_n_filters=32,
+        tcn_drop_prob=0.3,
+        tcn_activation: nn.Module = nn.ELU,
+        concat=False,
+        max_norm_const=0.25,
+        chs_info=None,
+        n_times=None,
     ):
-        n_chans, n_outputs, input_window_seconds = deprecated_args(
-            self,
-            ('n_channels', 'n_chans', n_channels, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_size_s', 'input_window_seconds', input_size_s, input_window_seconds),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -142,10 +132,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_channels, n_classes, input_size_s
         self.conv_block_n_filters = conv_block_n_filters
         self.conv_block_kernel_length_1 = conv_block_kernel_length_1
         self.conv_block_kernel_length_2 = conv_block_kernel_length_2
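The two hunks above remove the 0.8.1 compatibility layer: the deprecated_args shim and its n_channels / n_classes / input_size_s aliases are gone, add_log_softmax is dropped, and the dropout arguments are renamed to att_drop_prob and tcn_drop_prob (the 0.8.1 names are inferred from the assignments in the next hunk). A sketch of the call-site migration, with illustrative values:

from braindecode.models import ATCNet

# 0.8.1 still accepted the deprecated aliases:
#   model = ATCNet(n_channels=22, n_classes=4, input_size_s=4.5)
# 1.0.0 accepts only the EEGModuleMixin names:
model = ATCNet(
    n_chans=22,                # was n_channels
    n_outputs=4,               # was n_classes
    input_window_seconds=4.5,  # was input_size_s
    sfreq=250.0,
    att_drop_prob=0.5,         # 0.8.1: att_dropout
    tcn_drop_prob=0.3,         # 0.8.1: tcn_dropout
)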
@@ -156,19 +144,19 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.n_windows = n_windows
         self.att_head_dim = att_head_dim
         self.att_num_heads = att_num_heads
-        self.att_dropout = att_dropout
+        self.att_dropout = att_drop_prob
         self.tcn_depth = tcn_depth
         self.tcn_kernel_size = tcn_kernel_size
         self.tcn_n_filters = tcn_n_filters
-        self.tcn_dropout = tcn_dropout
+        self.tcn_dropout = tcn_drop_prob
         self.tcn_activation = tcn_activation
         self.concat = concat
         self.max_norm_const = max_norm_const
 
         map = dict()
         for w in range(self.n_windows):
-            map[f'max_norm_linears.[{w}].weight'] = f'final_layer.[{w}].weight'
-            map[f'max_norm_linears.[{w}].bias'] = f'final_layer.[{w}].bias'
+            map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
+            map[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
         self.mapping = map
 
         # Check later if we want to keep the Ensure4d. Not sure if we can
@@ -184,57 +172,67 @@ class ATCNet(EEGModuleMixin, nn.Module):
             pool_size_1=conv_block_pool_size_1,
             pool_size_2=conv_block_pool_size_2,
             depth_mult=conv_block_depth_mult,
-            dropout=conv_block_dropout
+            dropout=conv_block_dropout,
         )
 
         self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
-        self.Tc = int(self.n_times / (
-            conv_block_pool_size_1 * conv_block_pool_size_2))
+        self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
         self.Tw = self.Tc - self.n_windows + 1
 
-        self.attention_blocks = nn.ModuleList([
-            _AttentionBlock(
-                in_shape=self.F2,
-                head_dim=self.att_head_dim,
-                num_heads=att_num_heads,
-                dropout=att_dropout,
-            ) for _ in range(self.n_windows)
-        ])
-
-        self.temporal_conv_nets = nn.ModuleList([
-            nn.Sequential(
-                *[_TCNResidualBlock(
-                    in_channels=self.F2,
-                    kernel_size=tcn_kernel_size,
-                    n_filters=tcn_n_filters,
-                    dropout=tcn_dropout,
-                    activation=tcn_activation,
-                    dilation=2 ** i
-                ) for i in range(tcn_depth)]
-            ) for _ in range(self.n_windows)
-        ])
-
-        if self.concat:
-            self.max_norm_linears = nn.ModuleList([
-                MaxNormLinear(
-                    in_features=self.F2 * self.n_windows,
-                    out_features=self.n_outputs,
-                    max_norm_val=self.max_norm_const
-                )
-            ])
-        else:
-            self.max_norm_linears = nn.ModuleList([
-                MaxNormLinear(
-                    in_features=self.F2,
-                    out_features=self.n_outputs,
-                    max_norm_val=self.max_norm_const
-                ) for _ in range(self.n_windows)
-            ])
-
-        if self.add_log_softmax:
-            self.out_fun = nn.LogSoftmax(dim=1)
-        else:
-            self.out_fun = nn.Identity()
+        self.attention_blocks = nn.ModuleList(
+            [
+                _AttentionBlock(
+                    in_shape=self.F2,
+                    head_dim=self.att_head_dim,
+                    num_heads=att_num_heads,
+                    dropout=att_drop_prob,
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        self.temporal_conv_nets = nn.ModuleList(
+            [
+                nn.Sequential(
+                    *[
+                        _TCNResidualBlock(
+                            in_channels=self.F2,
+                            kernel_size=tcn_kernel_size,
+                            n_filters=tcn_n_filters,
+                            dropout=tcn_drop_prob,
+                            activation=tcn_activation,
+                            dilation=2**i,
+                        )
+                        for i in range(tcn_depth)
+                    ]
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        if self.concat:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2 * self.n_windows,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                ]
+            )
+        else:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                    for _ in range(self.n_windows)
+                ]
+            )
+
+        self.out_fun = nn.Identity()
 
     def forward(self, X):
         # Dimension: (batch_size, C, T)
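The Tc/Tw arithmetic reformatted above fixes how many pooled time steps each sliding window sees. A quick sanity check with illustrative values (n_times = 1125, i.e. 4.5 s at 250 Hz, is an assumption, not part of this diff):

n_times, pool_size_1, pool_size_2, n_windows = 1125, 8, 7, 5
Tc = int(n_times / (pool_size_1 * pool_size_2))  # int(1125 / 56) -> 20
Tw = Tc - n_windows + 1                          # 20 - 5 + 1 -> 16
# Five windows of 16 steps each slide one step at a time over the 20 pooled
# steps, matching the conv_feat[..., idx : idx + self.Tw] slices in forward().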
@@ -250,43 +248,46 @@ class ATCNet(EEGModuleMixin, nn.Module):
         # Dimension: (batch_size, F2, Tc)
 
         # ----- Sliding window -----
-        sw_concat = []  # to store sliding window outputs
-        for w in range(self.n_windows):
-            conv_feat_w = conv_feat[..., w:w + self.Tw]
+        sw_concat: list[torch.Tensor] = []  # to store sliding window outputs
+        # for w in range(self.n_windows):
+        for idx, (attention, tcn_module, final_layer) in enumerate(
+            zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+        ):
+            conv_feat_w = conv_feat[..., idx : idx + self.Tw]
             # Dimension: (batch_size, F2, Tw)
 
             # ----- Attention block -----
-            att_feat = self.attention_blocks[w](conv_feat_w)
+            att_feat = attention(conv_feat_w)
             # Dimension: (batch_size, F2, Tw)
 
             # ----- Temporal convolutional network (TCN) -----
-            tcn_feat = self.temporal_conv_nets[w](att_feat)[..., -1]
+            tcn_feat = tcn_module(att_feat)[..., -1]
             # Dimension: (batch_size, F2)
 
             # Outputs of sliding window can be either averaged after being
             # mapped by dense layer or concatenated then mapped by a dense
             # layer
             if not self.concat:
-                tcn_feat = self.max_norm_linears[w](tcn_feat)
+                tcn_feat = final_layer(tcn_feat)
 
             sw_concat.append(tcn_feat)
 
         # ----- Aggregation and prediction -----
         if self.concat:
-            sw_concat = torch.cat(sw_concat, dim=1)
-            sw_concat = self.max_norm_linears[0](sw_concat)
+            sw_concat_agg = torch.cat(sw_concat, dim=1)
+            sw_concat_agg = self.final_layer[0](sw_concat_agg)
         else:
             if len(sw_concat) > 1:  # more than one window
-                sw_concat = torch.stack(sw_concat, dim=0)
-                sw_concat = torch.mean(sw_concat, dim=0)
+                sw_concat_agg = torch.stack(sw_concat, dim=0)
+                sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
             else:  # one window (# windows = 1)
-                sw_concat = sw_concat[0]
+                sw_concat_agg = sw_concat[0]
 
-        return self.out_fun(sw_concat)
+        return self.out_fun(sw_concat_agg)
 
 
 class _ConvBlock(nn.Module):
-    """
+    """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
     architecture [2]_.
 
     References
@@ -303,15 +304,15 @@ class _ConvBlock(nn.Module):
     """
 
     def __init__(
-            self,
-            n_channels,
-            n_filters=16,
-            kernel_length_1=64,
-            kernel_length_2=16,
-            pool_size_1=8,
-            pool_size_2=7,
-            depth_mult=2,
-            dropout=0.3
+        self,
+        n_channels,
+        n_filters=16,
+        kernel_length_1=64,
+        kernel_length_2=16,
+        pool_size_1=8,
+        pool_size_2=7,
+        depth_mult=2,
+        dropout=0.3,
     ):
         super().__init__()
 
@@ -402,11 +403,11 @@ class _AttentionBlock(nn.Module):
     """
 
     def __init__(
-            self,
-            in_shape=32,
-            head_dim=8,
-            num_heads=2,
-            dropout=0.5
+        self,
+        in_shape=32,
+        head_dim=8,
+        num_heads=2,
+        dropout=0.5,
    ):
         super().__init__()
         self.in_shape = in_shape
@@ -462,7 +463,7 @@ class _AttentionBlock(nn.Module):
 
 
 class _TCNResidualBlock(nn.Module):
-    """
+    """Modified TCN Residual block as proposed in [1]_. Inspired from
     Temporal Convolutional Networks (TCN) [2]_.
 
     References
@@ -477,16 +478,16 @@ class _TCNResidualBlock(nn.Module):
     """
 
     def __init__(
-            self,
-            in_channels,
-            kernel_size=4,
-            n_filters=32,
-            dropout=0.3,
-            activation=nn.ELU(),
-            dilation=1
+        self,
+        in_channels,
+        kernel_size=4,
+        n_filters=32,
+        dropout=0.3,
+        activation: nn.Module = nn.ELU,
+        dilation=1,
     ):
         super().__init__()
-        self.activation = activation
+        self.activation = activation()
         self.dilation = dilation
         self.dropout = dropout
         self.n_filters = n_filters
@@ -522,7 +523,7 @@ class _TCNResidualBlock(nn.Module):
             self.reshaping_conv = nn.Conv1d(
                 n_filters,
                 kernel_size=1,
-                padding='same'
+                padding="same",
             )
         else:
             self.reshaping_conv = nn.Identity()
@@ -550,12 +551,12 @@ class _MHA(nn.Module):
 
 class _MHA(nn.Module):
     def __init__(
-            self,
-            input_dim: int,
-            head_dim: int,
-            output_dim: int,
-            num_heads: int,
-            dropout: float = 0.
+        self,
+        input_dim: int,
+        head_dim: int,
+        output_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
     ):
         """Multi-head Attention
 
@@ -598,12 +599,9 @@ class _MHA(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(
-            self,
-            Q: torch.Tensor,
-            K: torch.Tensor,
-            V: torch.Tensor
+        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
     ) -> torch.Tensor:
-        """
+        """Compute MHA(Q, K, V)
 
         Parameters
         ----------
@@ -635,22 +633,18 @@ class _MHA(nn.Module):
         # Attention weights of size (num_heads * batch_size, n, m):
         # measures how similar each pair of Q and K is.
         W = torch.softmax(
-            Q_.bmm(
-                K_.transpose(-2, -1)
-            )
-            / np.sqrt(self.head_dim),
-            -1
+            Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
+            -1,  # (B', D', S)
         )  # (B', N, M)
 
         # Multihead output (batch_size, seq_len, dim):
         # weighted sum of V where a value gets more weight if its corresponding
         # key has larger dot product with the query.
         H = torch.cat(
-            (
-                W.bmm(V_)
-            ).split(batch_size, 0),
-            -1
+            (W.bmm(V_)).split(  # (B', S, S)  # (B', S, D')
+                batch_size, 0
+            ),  # [(B, S, D')] * num_heads
+            -1,
         )  # (B, S, D)
 
         out = self.fc_o(H)