braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- braindecode/__init__.py +1 -2
- braindecode/augmentation/__init__.py +39 -19
- braindecode/augmentation/base.py +25 -28
- braindecode/augmentation/functional.py +237 -100
- braindecode/augmentation/transforms.py +325 -158
- braindecode/classifier.py +26 -24
- braindecode/datasets/__init__.py +28 -10
- braindecode/datasets/base.py +220 -134
- braindecode/datasets/bbci.py +43 -52
- braindecode/datasets/bcicomp.py +47 -32
- braindecode/datasets/bids.py +245 -0
- braindecode/datasets/mne.py +45 -24
- braindecode/datasets/moabb.py +87 -27
- braindecode/datasets/nmt.py +311 -0
- braindecode/datasets/sleep_physio_challe_18.py +412 -0
- braindecode/datasets/sleep_physionet.py +43 -26
- braindecode/datasets/tuh.py +324 -140
- braindecode/datasets/xy.py +27 -12
- braindecode/datautil/__init__.py +37 -18
- braindecode/datautil/serialization.py +110 -72
- braindecode/eegneuralnet.py +63 -47
- braindecode/functional/__init__.py +22 -0
- braindecode/functional/functions.py +250 -0
- braindecode/functional/initialization.py +47 -0
- braindecode/models/__init__.py +84 -14
- braindecode/models/atcnet.py +193 -164
- braindecode/models/attentionbasenet.py +599 -0
- braindecode/models/base.py +86 -102
- braindecode/models/biot.py +504 -0
- braindecode/models/contrawr.py +317 -0
- braindecode/models/ctnet.py +536 -0
- braindecode/models/deep4.py +116 -77
- braindecode/models/deepsleepnet.py +149 -119
- braindecode/models/eegconformer.py +112 -173
- braindecode/models/eeginception_erp.py +109 -118
- braindecode/models/eeginception_mi.py +161 -97
- braindecode/models/eegitnet.py +215 -152
- braindecode/models/eegminer.py +254 -0
- braindecode/models/eegnet.py +228 -161
- braindecode/models/eegnex.py +247 -0
- braindecode/models/eegresnet.py +234 -152
- braindecode/models/eegsimpleconv.py +199 -0
- braindecode/models/eegtcnet.py +335 -0
- braindecode/models/fbcnet.py +221 -0
- braindecode/models/fblightconvnet.py +313 -0
- braindecode/models/fbmsnet.py +324 -0
- braindecode/models/hybrid.py +52 -71
- braindecode/models/ifnet.py +441 -0
- braindecode/models/labram.py +1186 -0
- braindecode/models/msvtnet.py +375 -0
- braindecode/models/sccnet.py +207 -0
- braindecode/models/shallow_fbcsp.py +50 -56
- braindecode/models/signal_jepa.py +1011 -0
- braindecode/models/sinc_shallow.py +337 -0
- braindecode/models/sleep_stager_blanco_2020.py +55 -46
- braindecode/models/sleep_stager_chambon_2018.py +54 -53
- braindecode/models/sleep_stager_eldele_2021.py +247 -141
- braindecode/models/sparcnet.py +424 -0
- braindecode/models/summary.csv +41 -0
- braindecode/models/syncnet.py +232 -0
- braindecode/models/tcn.py +158 -88
- braindecode/models/tidnet.py +280 -167
- braindecode/models/tsinception.py +283 -0
- braindecode/models/usleep.py +190 -177
- braindecode/models/util.py +109 -145
- braindecode/modules/__init__.py +84 -0
- braindecode/modules/activation.py +60 -0
- braindecode/modules/attention.py +757 -0
- braindecode/modules/blocks.py +108 -0
- braindecode/modules/convolution.py +274 -0
- braindecode/modules/filter.py +628 -0
- braindecode/modules/layers.py +131 -0
- braindecode/modules/linear.py +49 -0
- braindecode/modules/parametrization.py +38 -0
- braindecode/modules/stats.py +77 -0
- braindecode/modules/util.py +76 -0
- braindecode/modules/wrapper.py +73 -0
- braindecode/preprocessing/__init__.py +36 -11
- braindecode/preprocessing/mne_preprocess.py +13 -7
- braindecode/preprocessing/preprocess.py +139 -75
- braindecode/preprocessing/windowers.py +576 -187
- braindecode/regressor.py +23 -12
- braindecode/samplers/__init__.py +16 -8
- braindecode/samplers/base.py +146 -32
- braindecode/samplers/ssl.py +162 -17
- braindecode/training/__init__.py +18 -10
- braindecode/training/callbacks.py +2 -4
- braindecode/training/losses.py +3 -8
- braindecode/training/scoring.py +76 -68
- braindecode/util.py +55 -59
- braindecode/version.py +1 -1
- braindecode/visualization/__init__.py +2 -3
- braindecode/visualization/confusion_matrices.py +117 -73
- braindecode/visualization/gradients.py +14 -10
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
- braindecode-1.1.0.dist-info/RECORD +101 -0
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
- braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
- braindecode/datautil/mne.py +0 -9
- braindecode/datautil/preprocess.py +0 -12
- braindecode/datautil/windowers.py +0 -6
- braindecode/datautil/xy.py +0 -9
- braindecode/models/eeginception.py +0 -317
- braindecode/models/functions.py +0 -47
- braindecode/models/modules.py +0 -358
- braindecode-0.8.1.dist-info/RECORD +0 -68
- {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
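A hedged migration sketch for the constructor changes visible in the atcnet.py diff below (the old keyword names come from the deleted deprecated_args block, the new ones from the 1.1.0 signature; 22 channels and 4 classes are arbitrary example values):

from braindecode.models import ATCNet

# 0.8.1 accepted alias keywords that 1.1.0 removes:
# model = ATCNet(n_channels=22, n_classes=4, input_size_s=4.5)
model = ATCNet(
    n_chans=22,                # was: n_channels
    n_outputs=4,               # was: n_classes
    input_window_seconds=4.5,  # was: input_size_s
    sfreq=250.0,
    att_drop_prob=0.5,         # was: att_dropout
    tcn_drop_prob=0.3,         # was: tcn_dropout
    # tcn_n_filters is gone: it is now derived internally as
    # conv_block_depth_mult * conv_block_n_filters (2 * 16 = 32 by default)
)

Note also that add_log_softmax is gone from the signature and out_fun is now nn.Identity(), so code relying on log-probability outputs must apply the log-softmax itself.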
braindecode/models/atcnet.py
CHANGED
@@ -1,21 +1,26 @@
 # Authors: Cedric Rommel <cedric.rommel@inria.fr>
 #
 # License: BSD (3-clause)
-import numpy as np
+import math
 
 import torch
-from torch import nn
 from einops.layers.torch import Rearrange
+from mne.utils import warn
+from torch import nn
 
-from .base import EEGModuleMixin, deprecated_args
-from .modules import Ensure4d, CausalConv1d, MaxNormLinear
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
 
 
 class ATCNet(EEGModuleMixin, nn.Module):
-    """ATCNet model from [1]_
+    """ATCNet model from Altaheri et al. (2022) [1]_
 
     Pytorch implementation based on official tensorflow code [2]_.
 
+    .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
+       :align: center
+       :alt: ATCNet Architecture
+
     Parameters
     ----------
     input_window_seconds : float, optional
@@ -54,7 +59,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
         table 1 of the paper [1]_. Defaults to 8 as in [1]_.
     att_num_heads : int
         Number of attention heads, denoted H in table 1 of the paper [1]_.
-        Defaults to 2 as in [
+        Defaults to 2 as in [1]_.
     att_dropout : float
         Dropout probability used in the attention block, denoted pa in table 1
         of the paper [1]_. Defaults to 0.5 as in [1]_.
@@ -65,9 +70,6 @@ class ATCNet(EEGModuleMixin, nn.Module):
     tcn_kernel_size : int
         Temporal kernel size used in TCN block, denoted Kt in table 1 of the
         paper [1]_. Defaults to 4 as in [1]_.
-    tcn_n_filters : int
-        Number of filters used in TCN convolutional layers (Ft). Defaults to
-        32 as in [1]_.
     tcn_dropout : float
         Dropout probability used in the TCN block, denoted pt in table 1
         of the paper [1]_. Defaults to 0.3 as in [1]_.
@@ -82,59 +84,44 @@ class ATCNet(EEGModuleMixin, nn.Module):
     max_norm_const : float
         Maximum L2-norm constraint imposed on weights of the last
         fully-connected layer. Defaults to 0.25.
-    n_channels:
-        Alias for n_chans.
-    n_classes:
-        Alias for n_outputs.
-    input_size_s:
-        Alias for input_window_seconds.
+
 
     References
     ----------
-    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
-
-
-
-    .. [2]
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
+       Physics-informed attention temporal convolutional network for EEG-based
+       motor imagery classification in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] EEG-ATCNet implementation.
+       https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
     """
 
     def __init__(
-            self,
-            n_chans=None,
-            n_outputs=None,
-            input_window_seconds=4.5,
-            sfreq=250.,
-            conv_block_n_filters=16,
-            conv_block_kernel_length_1=64,
-            conv_block_kernel_length_2=16,
-            conv_block_pool_size_1=8,
-            conv_block_pool_size_2=7,
-            conv_block_depth_mult=2,
-            conv_block_dropout=0.3,
-            n_windows=5,
-            att_head_dim=8,
-            att_num_heads=2,
-            att_dropout=0.5,
-            tcn_depth=2,
-            tcn_kernel_size=4,
-            tcn_n_filters=32,
-            tcn_dropout=0.3,
-            tcn_activation=nn.ELU(),
-            concat=False,
-            max_norm_const=0.25,
-            chs_info=None,
-            n_times=None,
-            n_channels=None,
-            n_classes=None,
-            input_size_s=None,
-            add_log_softmax=True,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        input_window_seconds=None,
+        sfreq=250.0,
+        conv_block_n_filters=16,
+        conv_block_kernel_length_1=64,
+        conv_block_kernel_length_2=16,
+        conv_block_pool_size_1=8,
+        conv_block_pool_size_2=7,
+        conv_block_depth_mult=2,
+        conv_block_dropout=0.3,
+        n_windows=5,
+        att_head_dim=8,
+        att_num_heads=2,
+        att_drop_prob=0.5,
+        tcn_depth=2,
+        tcn_kernel_size=4,
+        tcn_drop_prob=0.3,
+        tcn_activation: nn.Module = nn.ELU,
+        concat=False,
+        max_norm_const=0.25,
+        chs_info=None,
+        n_times=None,
     ):
-        n_chans, n_outputs, input_window_seconds = deprecated_args(
-            self,
-            ('n_channels', 'n_chans', n_channels, n_chans),
-            ('n_classes', 'n_outputs', n_classes, n_outputs),
-            ('input_size_s', 'input_window_seconds', input_size_s, input_window_seconds),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -142,10 +129,47 @@ class ATCNet(EEGModuleMixin, nn.Module):
             n_times=n_times,
             input_window_seconds=input_window_seconds,
             sfreq=sfreq,
-            add_log_softmax=add_log_softmax,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-
+
+        # Validate and adjust parameters based on input size
+
+        min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1
+        # Minimum length required to get at least one sliding window
+        min_len_sliding = n_windows + min_len_tcn - 1
+        # Minimum input size that produces the required feature map length
+        min_n_times = min_len_sliding * conv_block_pool_size_1 * conv_block_pool_size_2
+
+        # 2. If the input is shorter, calculate a scaling factor
+        if self.n_times < min_n_times:
+            scaling_factor = self.n_times / min_n_times
+            warn(
+                f"n_times ({self.n_times}) is smaller than the minimum required "
+                f"({min_n_times}) for the current model parameters configuration. "
+                "Adjusting parameters to ensure compatibility."
+                "Reducing the kernel, pooling, and stride sizes accordingly."
+                "Scaling factor: {:.2f}".format(scaling_factor),
+                UserWarning,
+            )
+            conv_block_kernel_length_1 = max(
+                1, int(conv_block_kernel_length_1 * scaling_factor)
+            )
+            conv_block_kernel_length_2 = max(
+                1, int(conv_block_kernel_length_2 * scaling_factor)
+            )
+            conv_block_pool_size_1 = max(
+                1, int(conv_block_pool_size_1 * scaling_factor)
+            )
+            conv_block_pool_size_2 = max(
+                1, int(conv_block_pool_size_2 * scaling_factor)
+            )
+
+            # n_windows should be at least 1
+            n_windows = max(1, int(n_windows * scaling_factor))
+
+            # tcn_kernel_size must be at least 2 for dilation to work
+            tcn_kernel_size = max(2, int(tcn_kernel_size * scaling_factor))
+
         self.conv_block_n_filters = conv_block_n_filters
         self.conv_block_kernel_length_1 = conv_block_kernel_length_1
         self.conv_block_kernel_length_2 = conv_block_kernel_length_2
@@ -156,19 +180,18 @@ class ATCNet(EEGModuleMixin, nn.Module):
         self.n_windows = n_windows
         self.att_head_dim = att_head_dim
         self.att_num_heads = att_num_heads
-        self.att_dropout = att_dropout
+        self.att_dropout = att_drop_prob
         self.tcn_depth = tcn_depth
         self.tcn_kernel_size = tcn_kernel_size
-        self.tcn_n_filters = tcn_n_filters
-        self.tcn_dropout = tcn_dropout
+        self.tcn_dropout = tcn_drop_prob
         self.tcn_activation = tcn_activation
         self.concat = concat
         self.max_norm_const = max_norm_const
-
+        self.tcn_n_filters = int(self.conv_block_depth_mult * self.conv_block_n_filters)
         map = dict()
         for w in range(self.n_windows):
-            map[f'max_norm_linears.[{w}].weight'] = f'final_layer.[{w}].weight'
-            map[f'max_norm_linears.[{w}].bias'] = f'final_layer.[{w}].bias'
+            map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
+            map[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
         self.mapping = map
 
         # Check later if we want to keep the Ensure4d. Not sure if we can
@@ -184,57 +207,67 @@ class ATCNet(EEGModuleMixin, nn.Module):
             pool_size_1=conv_block_pool_size_1,
             pool_size_2=conv_block_pool_size_2,
             depth_mult=conv_block_depth_mult,
-            dropout=conv_block_dropout
+            dropout=conv_block_dropout,
         )
 
         self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
-        self.Tc = int(self.n_times / (
-            conv_block_pool_size_1 * conv_block_pool_size_2))
+        self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
         self.Tw = self.Tc - self.n_windows + 1
 
-        self.attention_blocks = nn.ModuleList([
-            _AttentionBlock(
-                in_shape=self.F2,
-                head_dim=self.att_head_dim,
-                num_heads=att_num_heads,
-                dropout=att_dropout,
-            ) for _ in range(self.n_windows)
-        ])
-
-        self.temporal_conv_nets = nn.ModuleList([
-            nn.Sequential(
-                *[_TCNResidualBlock(
-                    in_channels=self.F2,
-                    kernel_size=tcn_kernel_size,
-                    n_filters=tcn_n_filters,
-                    dropout=tcn_dropout,
-                    activation=tcn_activation,
-                    dilation=2 ** i
-                ) for i in range(tcn_depth)]
-            ) for _ in range(self.n_windows)
-        ])
+        self.attention_blocks = nn.ModuleList(
+            [
+                _AttentionBlock(
+                    in_shape=self.F2,
+                    head_dim=self.att_head_dim,
+                    num_heads=att_num_heads,
+                    dropout=att_drop_prob,
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
 
+        self.temporal_conv_nets = nn.ModuleList(
+            [
+                nn.Sequential(
+                    *[
+                        _TCNResidualBlock(
+                            in_channels=self.F2,
+                            kernel_size=self.tcn_kernel_size,
+                            n_filters=self.tcn_n_filters,
+                            dropout=self.tcn_dropout,
+                            activation=self.tcn_activation,
+                            dilation=2**i,
+                        )
+                        for i in range(self.tcn_depth)
+                    ]
-
-
-
-
-
-
                 )
-
-
-
-
-
-
-
-
-
-
-
-
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        if self.concat:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2 * self.n_windows,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                ]
+            )
         else:
-            self.
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                    for _ in range(self.n_windows)
+                ]
+            )
+
+            self.out_fun = nn.Identity()
 
     def forward(self, X):
         # Dimension: (batch_size, C, T)
@@ -250,43 +283,46 @@ class ATCNet(EEGModuleMixin, nn.Module):
         # Dimension: (batch_size, F2, Tc)
 
         # ----- Sliding window -----
-        sw_concat = []  # to store sliding window outputs
-        for w in range(self.n_windows):
-            conv_feat_w = conv_feat[..., w:w + self.Tw]
+        sw_concat: list[torch.Tensor] = []  # to store sliding window outputs
+        # for w in range(self.n_windows):
+        for idx, (attention, tcn_module, final_layer) in enumerate(
+            zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+        ):
+            conv_feat_w = conv_feat[..., idx : idx + self.Tw]
             # Dimension: (batch_size, F2, Tw)
 
             # ----- Attention block -----
-            att_feat = self.attention_blocks[w](conv_feat_w)
+            att_feat = attention(conv_feat_w)
             # Dimension: (batch_size, F2, Tw)
 
             # ----- Temporal convolutional network (TCN) -----
-            tcn_feat = self.temporal_conv_nets[w](att_feat)[..., -1]
+            tcn_feat = tcn_module(att_feat)[..., -1]
            # Dimension: (batch_size, F2)
 
             # Outputs of sliding window can be either averaged after being
             # mapped by dense layer or concatenated then mapped by a dense
             # layer
             if not self.concat:
-                tcn_feat = self.max_norm_linears[w](tcn_feat)
+                tcn_feat = final_layer(tcn_feat)
 
             sw_concat.append(tcn_feat)
 
         # ----- Aggregation and prediction -----
         if self.concat:
-            sw_concat = torch.cat(sw_concat, dim=1)
-            sw_concat = self.max_norm_linears[0](sw_concat)
+            sw_concat_agg = torch.cat(sw_concat, dim=1)
+            sw_concat_agg = self.final_layer[0](sw_concat_agg)
         else:
             if len(sw_concat) > 1:  # more than one window
-                sw_concat = torch.stack(sw_concat, dim=0)
-                sw_concat = torch.mean(sw_concat, dim=0)
+                sw_concat_agg = torch.stack(sw_concat, dim=0)
+                sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
             else:  # one window (# windows = 1)
-                sw_concat = sw_concat[0]
+                sw_concat_agg = sw_concat[0]
 
-        return self.out_fun(sw_concat)
+        return self.out_fun(sw_concat_agg)
 
 
 class _ConvBlock(nn.Module):
-    """
+    """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
     architecture [2]_.
 
     References
@@ -303,15 +339,15 @@ class _ConvBlock(nn.Module):
     """
 
     def __init__(
-            self,
-            n_channels,
-            n_filters=16,
-            kernel_length_1=64,
-            kernel_length_2=16,
-            pool_size_1=8,
-            pool_size_2=7,
-            depth_mult=2,
-            dropout=0.3
+        self,
+        n_channels,
+        n_filters=16,
+        kernel_length_1=64,
+        kernel_length_2=16,
+        pool_size_1=8,
+        pool_size_2=7,
+        depth_mult=2,
+        dropout=0.3,
     ):
         super().__init__()
 
@@ -402,11 +438,11 @@ class _AttentionBlock(nn.Module):
     """
 
     def __init__(
-            self,
-            in_shape=32,
-            head_dim=8,
-            num_heads=2,
-            dropout=0.5
+        self,
+        in_shape=32,
+        head_dim=8,
+        num_heads=2,
+        dropout=0.5,
     ):
         super().__init__()
         self.in_shape = in_shape
@@ -462,7 +498,7 @@ class _AttentionBlock(nn.Module):
 
 
 class _TCNResidualBlock(nn.Module):
-    """
+    """Modified TCN Residual block as proposed in [1]_. Inspired from
     Temporal Convolutional Networks (TCN) [2]_.
 
     References
@@ -477,16 +513,16 @@ class _TCNResidualBlock(nn.Module):
     """
 
     def __init__(
-            self,
-            in_channels,
-            kernel_size=4,
-            n_filters=32,
-            dropout=0.3,
-            activation=nn.ELU(),
-            dilation=1
+        self,
+        in_channels,
+        kernel_size=4,
+        n_filters=32,
+        dropout=0.3,
+        activation: nn.Module = nn.ELU,
+        dilation=1,
     ):
         super().__init__()
-        self.activation = activation
+        self.activation = activation()
         self.dilation = dilation
         self.dropout = dropout
         self.n_filters = n_filters
@@ -522,7 +558,7 @@ class _TCNResidualBlock(nn.Module):
             self.reshaping_conv = nn.Conv1d(
                 n_filters,
                 kernel_size=1,
-                padding='same',
+                padding="same",
             )
         else:
             self.reshaping_conv = nn.Identity()
@@ -550,12 +586,12 @@ class _TCNResidualBlock(nn.Module):
 
 class _MHA(nn.Module):
     def __init__(
-            self,
-            input_dim: int,
-            head_dim: int,
-            output_dim: int,
-            num_heads: int,
-            dropout: float = 0.
+        self,
+        input_dim: int,
+        head_dim: int,
+        output_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
     ):
         """Multi-head Attention
 
@@ -598,12 +634,9 @@ class _MHA(nn.Module):
         self.dropout = nn.Dropout(dropout)
 
     def forward(
-            self,
-            Q: torch.Tensor,
-            K: torch.Tensor,
-            V: torch.Tensor
+        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
     ) -> torch.Tensor:
-        """
+        """Compute MHA(Q, K, V)
 
         Parameters
         ----------
@@ -635,22 +668,18 @@ class _MHA(nn.Module):
         # Attention weights of size (num_heads * batch_size, n, m):
         # measures how similar each pair of Q and K is.
         W = torch.softmax(
-            Q_.bmm(
-                K_.transpose(-2, -1)
-            )
-            / np.sqrt(self.head_dim),
-            -1
+            Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
+            -1,  # (B', D', S)
         )  # (B', N, M)
 
         # Multihead output (batch_size, seq_len, dim):
         # weighted sum of V where a value gets more weight if its corresponding
         # key has larger dot product with the query.
         H = torch.cat(
-            (
-
-
-
-            -1
+            (W.bmm(V_)).split(  # (B', S, S)  # (B', S, D')
+                batch_size, 0
+            ),  # [(B, S, D')] * num_heads
+            -1,
         )  # (B, S, D)
 
         out = self.fc_o(H)