braindecode 0.8__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of braindecode might be problematic.
Files changed (102)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +50 -0
  3. braindecode/augmentation/base.py +222 -0
  4. braindecode/augmentation/functional.py +1096 -0
  5. braindecode/augmentation/transforms.py +1274 -0
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +34 -0
  8. braindecode/datasets/base.py +840 -0
  9. braindecode/datasets/bbci.py +694 -0
  10. braindecode/datasets/bcicomp.py +194 -0
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +172 -0
  13. braindecode/datasets/moabb.py +209 -0
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +125 -0
  17. braindecode/datasets/tuh.py +588 -0
  18. braindecode/datasets/xy.py +95 -0
  19. braindecode/datautil/__init__.py +49 -0
  20. braindecode/datautil/serialization.py +342 -0
  21. braindecode/datautil/util.py +41 -0
  22. braindecode/eegneuralnet.py +63 -47
  23. braindecode/functional/__init__.py +10 -0
  24. braindecode/functional/functions.py +251 -0
  25. braindecode/functional/initialization.py +47 -0
  26. braindecode/models/__init__.py +52 -0
  27. braindecode/models/atcnet.py +652 -0
  28. braindecode/models/attentionbasenet.py +550 -0
  29. braindecode/models/base.py +296 -0
  30. braindecode/models/biot.py +483 -0
  31. braindecode/models/contrawr.py +296 -0
  32. braindecode/models/ctnet.py +450 -0
  33. braindecode/models/deep4.py +322 -0
  34. braindecode/models/deepsleepnet.py +295 -0
  35. braindecode/models/eegconformer.py +372 -0
  36. braindecode/models/eeginception_erp.py +304 -0
  37. braindecode/models/eeginception_mi.py +371 -0
  38. braindecode/models/eegitnet.py +301 -0
  39. braindecode/models/eegminer.py +255 -0
  40. braindecode/models/eegnet.py +473 -0
  41. braindecode/models/eegnex.py +247 -0
  42. braindecode/models/eegresnet.py +362 -0
  43. braindecode/models/eegsimpleconv.py +199 -0
  44. braindecode/models/eegtcnet.py +335 -0
  45. braindecode/models/fbcnet.py +221 -0
  46. braindecode/models/fblightconvnet.py +313 -0
  47. braindecode/models/fbmsnet.py +325 -0
  48. braindecode/models/hybrid.py +126 -0
  49. braindecode/models/ifnet.py +441 -0
  50. braindecode/models/labram.py +1166 -0
  51. braindecode/models/msvtnet.py +375 -0
  52. braindecode/models/sccnet.py +182 -0
  53. braindecode/models/shallow_fbcsp.py +208 -0
  54. braindecode/models/signal_jepa.py +1012 -0
  55. braindecode/models/sinc_shallow.py +337 -0
  56. braindecode/models/sleep_stager_blanco_2020.py +167 -0
  57. braindecode/models/sleep_stager_chambon_2018.py +157 -0
  58. braindecode/models/sleep_stager_eldele_2021.py +536 -0
  59. braindecode/models/sparcnet.py +378 -0
  60. braindecode/models/summary.csv +41 -0
  61. braindecode/models/syncnet.py +232 -0
  62. braindecode/models/tcn.py +273 -0
  63. braindecode/models/tidnet.py +395 -0
  64. braindecode/models/tsinception.py +258 -0
  65. braindecode/models/usleep.py +340 -0
  66. braindecode/models/util.py +133 -0
  67. braindecode/modules/__init__.py +38 -0
  68. braindecode/modules/activation.py +60 -0
  69. braindecode/modules/attention.py +757 -0
  70. braindecode/modules/blocks.py +108 -0
  71. braindecode/modules/convolution.py +274 -0
  72. braindecode/modules/filter.py +632 -0
  73. braindecode/modules/layers.py +133 -0
  74. braindecode/modules/linear.py +50 -0
  75. braindecode/modules/parametrization.py +38 -0
  76. braindecode/modules/stats.py +77 -0
  77. braindecode/modules/util.py +77 -0
  78. braindecode/modules/wrapper.py +75 -0
  79. braindecode/preprocessing/__init__.py +37 -0
  80. braindecode/preprocessing/mne_preprocess.py +77 -0
  81. braindecode/preprocessing/preprocess.py +478 -0
  82. braindecode/preprocessing/windowers.py +1031 -0
  83. braindecode/regressor.py +23 -12
  84. braindecode/samplers/__init__.py +18 -0
  85. braindecode/samplers/base.py +401 -0
  86. braindecode/samplers/ssl.py +263 -0
  87. braindecode/training/__init__.py +23 -0
  88. braindecode/training/callbacks.py +23 -0
  89. braindecode/training/losses.py +105 -0
  90. braindecode/training/scoring.py +483 -0
  91. braindecode/util.py +55 -59
  92. braindecode/version.py +1 -1
  93. braindecode/visualization/__init__.py +8 -0
  94. braindecode/visualization/confusion_matrices.py +289 -0
  95. braindecode/visualization/gradients.py +57 -0
  96. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  97. braindecode-1.0.0.dist-info/RECORD +101 -0
  98. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  99. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  100. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  101. braindecode-0.8.dist-info/RECORD +0 -11
  102. {braindecode-0.8.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
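
The listing above shows that reusable building blocks now appear in dedicated braindecode.modules and braindecode.functional subpackages; the atcnet.py file shown below, for example, imports its layers from braindecode.modules. A minimal import sketch based only on lines visible in this diff, assuming (as in previous releases) that models are re-exported from braindecode.models:

    from braindecode.models import ATCNet
    from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear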
braindecode/models/atcnet.py
@@ -0,0 +1,652 @@
+# Authors: Cedric Rommel <cedric.rommel@inria.fr>
+#
+# License: BSD (3-clause)
+import math
+
+import torch
+from einops.layers.torch import Rearrange
+from torch import nn
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
+
+
+class ATCNet(EEGModuleMixin, nn.Module):
+ """ATCNet model from Altaheri et al. (2022) [1]_
16
+
17
+ Pytorch implementation based on official tensorflow code [2]_.
18
+
19
+ .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
20
+ :align: center
21
+ :alt: ATCNet Architecture
22
+
23
+ Parameters
24
+ ----------
25
+ input_window_seconds : float, optional
26
+ Time length of inputs, in seconds. Defaults to 4.5 s, as in BCI-IV 2a
27
+ dataset.
28
+ sfreq : int, optional
29
+ Sampling frequency of the inputs, in Hz. Default to 250 Hz, as in
30
+ BCI-IV 2a dataset.
31
+ conv_block_n_filters : int
32
+ Number temporal filters in the first convolutional layer of the
33
+ convolutional block, denoted F1 in figure 2 of the paper [1]_. Defaults
34
+ to 16 as in [1]_.
35
+ conv_block_kernel_length_1 : int
36
+ Length of temporal filters in the first convolutional layer of the
37
+ convolutional block, denoted Kc in table 1 of the paper [1]_. Defaults
38
+ to 64 as in [1]_.
39
+ conv_block_kernel_length_2 : int
40
+ Length of temporal filters in the last convolutional layer of the
41
+ convolutional block. Defaults to 16 as in [1]_.
42
+ conv_block_pool_size_1 : int
43
+ Length of first average pooling kernel in the convolutional block.
44
+ Defaults to 8 as in [1]_.
45
+ conv_block_pool_size_2 : int
46
+ Length of first average pooling kernel in the convolutional block,
47
+ denoted P2 in table 1 of the paper [1]_. Defaults to 7 as in [1]_.
48
+ conv_block_depth_mult : int
49
+ Depth multiplier of depthwise convolution in the convolutional block,
50
+ denoted D in table 1 of the paper [1]_. Defaults to 2 as in [1]_.
51
+ conv_block_dropout : float
52
+ Dropout probability used in the convolution block, denoted pc in
53
+ table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
54
+ n_windows : int
55
+ Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
56
+ att_head_dim : int
57
+ Embedding dimension used in each self-attention head, denoted dh in
58
+ table 1 of the paper [1]_. Defaults to 8 as in [1]_.
59
+ att_num_heads : int
60
+ Number of attention heads, denoted H in table 1 of the paper [1]_.
61
+ Defaults to 2 as in [1]_.
62
+ att_dropout : float
63
+ Dropout probability used in the attention block, denoted pa in table 1
64
+ of the paper [1]_. Defaults to 0.5 as in [1]_.
65
+ tcn_depth : int
66
+ Depth of Temporal Convolutional Network block (i.e. number of TCN
67
+ Residual blocks), denoted L in table 1 of the paper [1]_. Defaults to 2
68
+ as in [1]_.
69
+ tcn_kernel_size : int
70
+ Temporal kernel size used in TCN block, denoted Kt in table 1 of the
71
+ paper [1]_. Defaults to 4 as in [1]_.
72
+ tcn_n_filters : int
73
+ Number of filters used in TCN convolutional layers (Ft). Defaults to
74
+ 32 as in [1]_.
75
+ tcn_dropout : float
76
+ Dropout probability used in the TCN block, denoted pt in table 1
77
+ of the paper [1]_. Defaults to 0.3 as in [1]_.
78
+ tcn_activation : torch.nn.Module
79
+ Nonlinear activation to use. Defaults to nn.ELU().
80
+ concat : bool
81
+ When ``True``, concatenates each slidding window embedding before
82
+ feeding it to a fully-connected layer, as done in [1]_. When ``False``,
83
+ maps each slidding window to `n_outputs` logits and average them.
84
+ Defaults to ``False`` contrary to what is reported in [1]_, but
85
+ matching what the official code does [2]_.
86
+ max_norm_const : float
87
+ Maximum L2-norm constraint imposed on weights of the last
88
+ fully-connected layer. Defaults to 0.25.
89
+
90
+
91
+ References
92
+ ----------
93
+ .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
94
+ Physics-informed attention temporal convolutional network for EEG-based
95
+ motor imagery classification in IEEE Transactions on Industrial Informatics,
96
+ 2022, doi: 10.1109/TII.2022.3197419.
97
+ .. [2] EEE-ATCNet implementation.
98
+ https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
99
+ """
100
+
101
+    def __init__(
+        self,
+        n_chans=None,
+        n_outputs=None,
+        input_window_seconds=None,
+        sfreq=250.0,
+        conv_block_n_filters=16,
+        conv_block_kernel_length_1=64,
+        conv_block_kernel_length_2=16,
+        conv_block_pool_size_1=8,
+        conv_block_pool_size_2=7,
+        conv_block_depth_mult=2,
+        conv_block_dropout=0.3,
+        n_windows=5,
+        att_head_dim=8,
+        att_num_heads=2,
+        att_drop_prob=0.5,
+        tcn_depth=2,
+        tcn_kernel_size=4,
+        tcn_n_filters=32,
+        tcn_drop_prob=0.3,
+        tcn_activation: nn.Module = nn.ELU,
+        concat=False,
+        max_norm_const=0.25,
+        chs_info=None,
+        n_times=None,
+    ):
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        self.conv_block_n_filters = conv_block_n_filters
+        self.conv_block_kernel_length_1 = conv_block_kernel_length_1
+        self.conv_block_kernel_length_2 = conv_block_kernel_length_2
+        self.conv_block_pool_size_1 = conv_block_pool_size_1
+        self.conv_block_pool_size_2 = conv_block_pool_size_2
+        self.conv_block_depth_mult = conv_block_depth_mult
+        self.conv_block_dropout = conv_block_dropout
+        self.n_windows = n_windows
+        self.att_head_dim = att_head_dim
+        self.att_num_heads = att_num_heads
+        self.att_dropout = att_drop_prob
+        self.tcn_depth = tcn_depth
+        self.tcn_kernel_size = tcn_kernel_size
+        self.tcn_n_filters = tcn_n_filters
+        self.tcn_dropout = tcn_drop_prob
+        self.tcn_activation = tcn_activation
+        self.concat = concat
+        self.max_norm_const = max_norm_const
+
+        map = dict()
+        for w in range(self.n_windows):
+            map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
+            map[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
+        self.mapping = map
+
+        # Check later if we want to keep the Ensure4d. Not sure if we can
+        # remove it or replace it with einsum.
+        self.ensuredims = Ensure4d()
+        self.dimshuffle = Rearrange("batch C T 1 -> batch 1 T C")
+
+        self.conv_block = _ConvBlock(
+            n_channels=self.n_chans,  # input shape: (batch_size, 1, T, C)
+            n_filters=conv_block_n_filters,
+            kernel_length_1=conv_block_kernel_length_1,
+            kernel_length_2=conv_block_kernel_length_2,
+            pool_size_1=conv_block_pool_size_1,
+            pool_size_2=conv_block_pool_size_2,
+            depth_mult=conv_block_depth_mult,
+            dropout=conv_block_dropout,
+        )
+
+        self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
+        self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
+        self.Tw = self.Tc - self.n_windows + 1
+
+        self.attention_blocks = nn.ModuleList(
+            [
+                _AttentionBlock(
+                    in_shape=self.F2,
+                    head_dim=self.att_head_dim,
+                    num_heads=att_num_heads,
+                    dropout=att_drop_prob,
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        self.temporal_conv_nets = nn.ModuleList(
+            [
+                nn.Sequential(
+                    *[
+                        _TCNResidualBlock(
+                            in_channels=self.F2,
+                            kernel_size=tcn_kernel_size,
+                            n_filters=tcn_n_filters,
+                            dropout=tcn_drop_prob,
+                            activation=tcn_activation,
+                            dilation=2**i,
+                        )
+                        for i in range(tcn_depth)
+                    ]
+                )
+                for _ in range(self.n_windows)
+            ]
+        )
+
+        if self.concat:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2 * self.n_windows,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                ]
+            )
+        else:
+            self.final_layer = nn.ModuleList(
+                [
+                    MaxNormLinear(
+                        in_features=self.F2,
+                        out_features=self.n_outputs,
+                        max_norm_val=self.max_norm_const,
+                    )
+                    for _ in range(self.n_windows)
+                ]
+            )
+
+        self.out_fun = nn.Identity()
+
+    def forward(self, X):
+        # Dimension: (batch_size, C, T)
+        X = self.ensuredims(X)
+        # Dimension: (batch_size, C, T, 1)
+        X = self.dimshuffle(X)
+        # Dimension: (batch_size, 1, T, C)
+
+        # ----- Convolutional block -----
+        conv_feat = self.conv_block(X)
+        # Dimension: (batch_size, F2, Tc, 1)
+        conv_feat = conv_feat.view(-1, self.F2, self.Tc)
+        # Dimension: (batch_size, F2, Tc)
+
+        # ----- Sliding window -----
+        sw_concat: list[torch.Tensor] = []  # to store sliding window outputs
+        for idx, (attention, tcn_module, final_layer) in enumerate(
+            zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+        ):
+            conv_feat_w = conv_feat[..., idx : idx + self.Tw]
+            # Dimension: (batch_size, F2, Tw)
+
+            # ----- Attention block -----
+            att_feat = attention(conv_feat_w)
+            # Dimension: (batch_size, F2, Tw)
+
+            # ----- Temporal convolutional network (TCN) -----
+            tcn_feat = tcn_module(att_feat)[..., -1]
+            # Dimension: (batch_size, F2)
+
+            # Outputs of the sliding windows can either be averaged after
+            # being mapped by a dense layer, or concatenated and then mapped
+            # by a dense layer
+            if not self.concat:
+                tcn_feat = final_layer(tcn_feat)
+
+            sw_concat.append(tcn_feat)
+
+        # ----- Aggregation and prediction -----
+        if self.concat:
+            sw_concat_agg = torch.cat(sw_concat, dim=1)
+            sw_concat_agg = self.final_layer[0](sw_concat_agg)
+        else:
+            if len(sw_concat) > 1:  # more than one window
+                sw_concat_agg = torch.stack(sw_concat, dim=0)
+                sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
+            else:  # one window (# windows = 1)
+                sw_concat_agg = sw_concat[0]
+
+        return self.out_fun(sw_concat_agg)
+
+
+class _ConvBlock(nn.Module):
+    """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
+    architecture [2]_.
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
+       S. M., Hung, C. P., & Lance, B. J. (2018).
+       EEGNet: A Compact Convolutional Network for EEG-based
+       Brain-Computer Interfaces.
+       arXiv preprint arXiv:1611.08024.
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        n_filters=16,
+        kernel_length_1=64,
+        kernel_length_2=16,
+        pool_size_1=8,
+        pool_size_2=7,
+        depth_mult=2,
+        dropout=0.3,
+    ):
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels=1,
+            out_channels=n_filters,
+            kernel_size=(kernel_length_1, 1),
+            padding="same",
+            bias=False,
+        )
+
+        self.bn1 = nn.BatchNorm2d(num_features=n_filters, eps=1e-4)
+
+        n_depth_kernels = n_filters * depth_mult
+        self.conv2 = nn.Conv2d(
+            in_channels=n_filters,
+            out_channels=n_depth_kernels,
+            groups=n_filters,
+            kernel_size=(1, n_channels),
+            padding="valid",
+            bias=False,
+        )
+
+        self.bn2 = nn.BatchNorm2d(num_features=n_depth_kernels, eps=1e-4)
+
+        self.activation2 = nn.ELU()
+
+        self.pool2 = nn.AvgPool2d(kernel_size=(pool_size_1, 1))
+
+        self.drop2 = nn.Dropout2d(dropout)
+
+        self.conv3 = nn.Conv2d(
+            in_channels=n_depth_kernels,
+            out_channels=n_depth_kernels,
+            kernel_size=(kernel_length_2, 1),
+            padding="same",
+            bias=False,
+        )
+
+        self.bn3 = nn.BatchNorm2d(num_features=n_depth_kernels, eps=1e-4)
+
+        self.activation3 = nn.ELU()
+
+        self.pool3 = nn.AvgPool2d(kernel_size=(pool_size_2, 1))
+
+        self.drop3 = nn.Dropout2d(dropout)
+
+    def forward(self, X):
+        # ----- Temporal convolution -----
+        # Dimension: (batch_size, 1, T, C)
+        X = self.conv1(X)
+        X = self.bn1(X)
+        # Dimension: (batch_size, F1, T, C)
+
+        # ----- Depthwise channels convolution -----
+        X = self.conv2(X)
+        X = self.bn2(X)
+        X = self.activation2(X)
+        # Dimension: (batch_size, F1*D, T, 1)
+        X = self.pool2(X)
+        X = self.drop2(X)
+        # Dimension: (batch_size, F1*D, T/P1, 1)
+
+        # ----- "Spatial" convolution -----
+        X = self.conv3(X)
+        X = self.bn3(X)
+        X = self.activation3(X)
+        # Dimension: (batch_size, F1*D, T/P1, 1)
+        X = self.pool3(X)
+        X = self.drop3(X)
+        # Dimension: (batch_size, F1*D, T/(P1*P2), 1)
+
+        return X
+
+
+class _AttentionBlock(nn.Module):
+    """Multi-head self-attention (MHA) block used in ATCNet [1]_, inspired
+    from [2]_.
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] Vaswani, A. et al., "Attention is all you need",
+       in Advances in neural information processing systems, 2017.
+    """
+
+    def __init__(
+        self,
+        in_shape=32,
+        head_dim=8,
+        num_heads=2,
+        dropout=0.5,
+    ):
+        super().__init__()
+        self.in_shape = in_shape
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+
+        # Puts time dimension at -2 and feature dim at -1
+        self.dimshuffle = Rearrange("batch C T -> batch T C")
+
+        # Layer normalization
+        self.ln = nn.LayerNorm(normalized_shape=in_shape, eps=1e-6)
+
+        # Multi-head self-attention layer
+        # (We had to reimplement it since the original code is in tensorflow,
+        # where it is possible to have an embedding dimension different than
+        # the input and output dimensions, which is not possible in pytorch.)
+        self.mha = _MHA(
+            input_dim=in_shape,
+            head_dim=head_dim,
+            output_dim=in_shape,
+            num_heads=num_heads,
+            dropout=dropout,
+        )
+
+        # XXX: This line in the official code is surprising, as there is
+        # already dropout in the multi-head attention layer, and the paper
+        # does not mention any additional dropout between the attention block
+        # and the TCN. We keep it here to follow the official code.
+        self.drop = nn.Dropout(0.3)
+
+    def forward(self, X):
+        # Dimension: (batch_size, F2, Tw)
+        X = self.dimshuffle(X)
+        # Dimension: (batch_size, Tw, F2)
+
+        # ----- Layer norm -----
+        out = self.ln(X)
+
+        # ----- Self-Attention -----
+        out = self.mha(out, out, out)
+        # Dimension: (batch_size, Tw, F2)
+
+        # XXX In the paper fig. 1, it is drawn that layer normalization is
+        # performed before the skip connection, while it is done afterwards
+        # in the official code. Here we follow the code.
+
+        # ----- Skip connection -----
+        out = X + self.drop(out)
+
+        # Move back to shape (batch_size, F2, Tw) from the beginning
+        return self.dimshuffle(out)
+
+
+class _TCNResidualBlock(nn.Module):
+    """Modified TCN Residual block as proposed in [1]_. Inspired from
+    Temporal Convolutional Networks (TCN) [2]_.
+
+    References
+    ----------
+    .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
+       attention temporal convolutional network for EEG-based motor imagery
+       classification," in IEEE Transactions on Industrial Informatics,
+       2022, doi: 10.1109/TII.2022.3197419.
+    .. [2] Bai, S., Kolter, J. Z., & Koltun, V.
+       "An empirical evaluation of generic convolutional and recurrent
+       networks for sequence modeling", 2018.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        kernel_size=4,
+        n_filters=32,
+        dropout=0.3,
+        activation: nn.Module = nn.ELU,
+        dilation=1,
+    ):
+        super().__init__()
+        self.activation = activation()
+        self.dilation = dilation
+        self.dropout = dropout
+        self.n_filters = n_filters
+        self.kernel_size = kernel_size
+        self.in_channels = in_channels
+
+        self.conv1 = CausalConv1d(
+            in_channels=in_channels,
+            out_channels=n_filters,
+            kernel_size=kernel_size,
+            dilation=dilation,
+        )
+        nn.init.kaiming_uniform_(self.conv1.weight)
+
+        self.bn1 = nn.BatchNorm1d(n_filters)
+
+        self.drop1 = nn.Dropout(dropout)
+
+        self.conv2 = CausalConv1d(
+            in_channels=n_filters,
+            out_channels=n_filters,
+            kernel_size=kernel_size,
+            dilation=dilation,
+        )
+        nn.init.kaiming_uniform_(self.conv2.weight)
+
+        self.bn2 = nn.BatchNorm1d(n_filters)
+
+        self.drop2 = nn.Dropout(dropout)
+
+        # Map the convolution output back to the input dimension for the
+        # residual connection when necessary
+        if in_channels != n_filters:
+            self.reshaping_conv = nn.Conv1d(
+                in_channels=n_filters,
+                out_channels=in_channels,
+                kernel_size=1,
+                padding="same",
+            )
+        else:
+            self.reshaping_conv = nn.Identity()
+
+    def forward(self, X):
+        # Dimension: (batch_size, F2, Tw)
+        # ----- Double dilated convolutions -----
+        out = self.conv1(X)
+        out = self.bn1(out)
+        out = self.activation(out)
+        out = self.drop1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.activation(out)
+        out = self.drop2(out)
+
+        out = self.reshaping_conv(out)
+
+        # ----- Residual connection -----
+        out = X + out
+
+        return self.activation(out)
+
+
+class _MHA(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        head_dim: int,
+        output_dim: int,
+        num_heads: int,
+        dropout: float = 0.0,
+    ):
+        """Multi-head Attention
+
+        The difference between this module and torch.nn.MultiheadAttention is
+        that this module supports embedding dimensions different from the
+        input and output ones. It also does not support sequences of different
+        length.
+
+        Parameters
+        ----------
+        input_dim : int
+            Dimension of query, key and value inputs.
+        head_dim : int
+            Dimension of embedded query, key and value in each head,
+            before computing attention.
+        output_dim : int
+            Output dimension.
+        num_heads : int
+            Number of heads in the multi-head architecture.
+        dropout : float, optional
+            Dropout probability on output weights. Default: 0.0 (no dropout).
+        """
+
+        super(_MHA, self).__init__()
+
+        self.input_dim = input_dim
+        self.head_dim = head_dim
+        # typical choice for the split dimension of the heads
+        self.embed_dim = head_dim * num_heads
+
+        # embeddings for multi-head projections
+        self.fc_q = nn.Linear(input_dim, self.embed_dim)
+        self.fc_k = nn.Linear(input_dim, self.embed_dim)
+        self.fc_v = nn.Linear(input_dim, self.embed_dim)
+
+        # output mapping
+        self.fc_o = nn.Linear(self.embed_dim, output_dim)
+
+        # dropout
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
+    ) -> torch.Tensor:
+        """Compute MHA(Q, K, V)
+
+        Parameters
+        ----------
+        Q: torch.Tensor of size (batch_size, seq_len, input_dim)
+            Input query (Q) sequence.
+        K: torch.Tensor of size (batch_size, seq_len, input_dim)
+            Input key (K) sequence.
+        V: torch.Tensor of size (batch_size, seq_len, input_dim)
+            Input value (V) sequence.
+
+        Returns
+        -------
+        O: torch.Tensor of size (batch_size, seq_len, output_dim)
+            Output MHA(Q, K, V)
+        """
+        assert Q.shape[-1] == K.shape[-1] == V.shape[-1] == self.input_dim
+
+        batch_size, _, _ = Q.shape
+
+        # embedding for multi-head projections (masked or not)
+        Q = self.fc_q(Q)  # (B, S, D)
+        K, V = self.fc_k(K), self.fc_v(V)  # (B, S, D)
+
+        # Split into num_heads vectors of shape (num_heads * batch_size, seq_len, head_dim)
+        Q_ = torch.cat(Q.split(self.head_dim, -1), 0)  # (B', S, D')
+        K_ = torch.cat(K.split(self.head_dim, -1), 0)  # (B', S, D')
+        V_ = torch.cat(V.split(self.head_dim, -1), 0)  # (B', S, D')
+
+        # Attention weights of size (num_heads * batch_size, seq_len, seq_len):
+        # measures how similar each pair of Q and K is.
+        W = torch.softmax(
+            Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),  # (B', S, S)
+            -1,
+        )
+
+        # Multihead output (batch_size, seq_len, embed_dim):
+        # weighted sum of V where a value gets more weight if its corresponding
+        # key has larger dot product with the query.
+        H = torch.cat(
+            (W.bmm(V_)).split(  # (B', S, S) x (B', S, D') -> (B', S, D')
+                batch_size, 0
+            ),  # [(B, S, D')] * num_heads
+            -1,
+        )  # (B, S, D)
+
+        out = self.fc_o(H)
+
+        return self.dropout(out)
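
A minimal usage sketch of the model defined in this file, assuming the braindecode.models.ATCNet entry point and illustrative BCI-IV 2a-like shapes taken from the docstring above (22 channels, 4 classes, 4.5 s windows at 250 Hz, i.e. 1125 samples per window); the batch size of 8 is arbitrary:

    import torch
    from braindecode.models import ATCNet

    # 22 EEG channels, 4 motor-imagery classes, 4.5 s windows sampled at 250 Hz
    model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250)

    x = torch.randn(8, 22, 1125)  # (batch_size, n_chans, n_times)
    logits = model(x)             # (batch_size, n_outputs), here (8, 4)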