braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries, and is provided for informational purposes only.

Potentially problematic release: this version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
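
The detailed diff below covers braindecode/models/atcnet.py (item 26). One repo-wide change visible both in the file list and in the first hunk: the private braindecode/models/modules.py (removed, item 106) was split into a public braindecode.modules package (items 66-77). A hedged migration sketch for downstream imports, assuming only the module path changed:

```python
# 0.8.1: layers lived in a private module (removed in 1.1.0, item 106)
# from braindecode.models.modules import Ensure4d, MaxNormLinear, CausalConv1d

# 1.1.0: the same layers are importable from the public package (items 66-77)
from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear
```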
@@ -1,21 +1,26 @@
  # Authors: Cedric Rommel <cedric.rommel@inria.fr>
  #
  # License: BSD (3-clause)
- import numpy as np
+ import math

  import torch
- from torch import nn
  from einops.layers.torch import Rearrange
+ from mne.utils import warn
+ from torch import nn

- from .modules import Ensure4d, MaxNormLinear, CausalConv1d
- from .base import EEGModuleMixin, deprecated_args
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear


  class ATCNet(EEGModuleMixin, nn.Module):
- """ATCNet model from [1]_
+ """ATCNet model from Altaheri et al. (2022) [1]_

  Pytorch implementation based on official tensorflow code [2]_.

+ .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
+ :align: center
+ :alt: ATCNet Architecture
+
  Parameters
  ----------
  input_window_seconds : float, optional
@@ -54,7 +59,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
  table 1 of the paper [1]_. Defaults to 8 as in [1]_.
  att_num_heads : int
  Number of attention heads, denoted H in table 1 of the paper [1]_.
- Defaults to 2 as in [1_.
+ Defaults to 2 as in [1]_.
  att_dropout : float
  Dropout probability used in the attention block, denoted pa in table 1
  of the paper [1]_. Defaults to 0.5 as in [1]_.
@@ -65,9 +70,6 @@
  tcn_kernel_size : int
  Temporal kernel size used in TCN block, denoted Kt in table 1 of the
  paper [1]_. Defaults to 4 as in [1]_.
- tcn_n_filters : int
- Number of filters used in TCN convolutional layers (Ft). Defaults to
- 32 as in [1]_.
  tcn_dropout : float
  Dropout probability used in the TCN block, denoted pt in table 1
  of the paper [1]_. Defaults to 0.3 as in [1]_.
@@ -82,59 +84,44 @@ class ATCNet(EEGModuleMixin, nn.Module):
  max_norm_const : float
  Maximum L2-norm constraint imposed on weights of the last
  fully-connected layer. Defaults to 0.25.
- n_channels:
- Alias for n_chans.
- n_classes:
- Alias for n_outputs.
- input_size_s:
- Alias for input_window_seconds.
+

  References
  ----------
- .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
- attention temporal convolutional network for EEG-based motor imagery
- classification," in IEEE Transactions on Industrial Informatics,
- 2022, doi: 10.1109/TII.2022.3197419.
- .. [2] https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
+ .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
+ Physics-informed attention temporal convolutional network for EEG-based
+ motor imagery classification in IEEE Transactions on Industrial Informatics,
+ 2022, doi: 10.1109/TII.2022.3197419.
+ .. [2] EEG-ATCNet implementation.
+ https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
  """

  def __init__(
- self,
- n_chans=None,
- n_outputs=None,
- input_window_seconds=4.5,
- sfreq=250.,
- conv_block_n_filters=16,
- conv_block_kernel_length_1=64,
- conv_block_kernel_length_2=16,
- conv_block_pool_size_1=8,
- conv_block_pool_size_2=7,
- conv_block_depth_mult=2,
- conv_block_dropout=0.3,
- n_windows=5,
- att_head_dim=8,
- att_num_heads=2,
- att_dropout=0.5,
- tcn_depth=2,
- tcn_kernel_size=4,
- tcn_n_filters=32,
- tcn_dropout=0.3,
- tcn_activation=nn.ELU(),
- concat=False,
- max_norm_const=0.25,
- chs_info=None,
- n_times=None,
- n_channels=None,
- n_classes=None,
- input_size_s=None,
- add_log_softmax=True,
+ self,
+ n_chans=None,
+ n_outputs=None,
+ input_window_seconds=None,
+ sfreq=250.0,
+ conv_block_n_filters=16,
+ conv_block_kernel_length_1=64,
+ conv_block_kernel_length_2=16,
+ conv_block_pool_size_1=8,
+ conv_block_pool_size_2=7,
+ conv_block_depth_mult=2,
+ conv_block_dropout=0.3,
+ n_windows=5,
+ att_head_dim=8,
+ att_num_heads=2,
+ att_drop_prob=0.5,
+ tcn_depth=2,
+ tcn_kernel_size=4,
+ tcn_drop_prob=0.3,
+ tcn_activation: nn.Module = nn.ELU,
+ concat=False,
+ max_norm_const=0.25,
+ chs_info=None,
+ n_times=None,
  ):
- n_chans, n_outputs, input_window_seconds = deprecated_args(
- self,
- ('n_channels', 'n_chans', n_channels, n_chans),
- ('n_classes', 'n_outputs', n_classes, n_outputs),
- ('input_size_s', 'input_window_seconds', input_size_s, input_window_seconds),
- )
  super().__init__(
  n_outputs=n_outputs,
  n_chans=n_chans,
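
For callers, this hunk removes the deprecated aliases (n_channels, n_classes, input_size_s) and add_log_softmax, renames att_dropout and tcn_dropout to att_drop_prob and tcn_drop_prob, drops tcn_n_filters (now derived, see below), and expects an activation class rather than an instance. A hedged before/after sketch of caller code, with illustrative values:

```python
from torch import nn
from braindecode.models import ATCNet

# 0.8.1-style call (no longer accepted in 1.1.0):
# model = ATCNet(n_channels=22, n_classes=4, input_size_s=4.5,
#                att_dropout=0.5, tcn_dropout=0.3, tcn_activation=nn.ELU())

# 1.1.0-style call:
model = ATCNet(
    n_chans=22,                # was n_channels
    n_outputs=4,               # was n_classes
    input_window_seconds=4.5,  # was input_size_s
    sfreq=250.0,
    att_drop_prob=0.5,         # was att_dropout
    tcn_drop_prob=0.3,         # was tcn_dropout
    tcn_activation=nn.ELU,     # class, no longer an instance
)
```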
@@ -142,10 +129,47 @@ class ATCNet(EEGModuleMixin, nn.Module):
  n_times=n_times,
  input_window_seconds=input_window_seconds,
  sfreq=sfreq,
- add_log_softmax=add_log_softmax,
  )
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
- del n_channels, n_classes, input_size_s
+
+ # Validate and adjust parameters based on input size
+
+ min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1
+ # Minimum length required to get at least one sliding window
+ min_len_sliding = n_windows + min_len_tcn - 1
+ # Minimum input size that produces the required feature map length
+ min_n_times = min_len_sliding * conv_block_pool_size_1 * conv_block_pool_size_2
+
+ # 2. If the input is shorter, calculate a scaling factor
+ if self.n_times < min_n_times:
+ scaling_factor = self.n_times / min_n_times
+ warn(
+ f"n_times ({self.n_times}) is smaller than the minimum required "
+ f"({min_n_times}) for the current model parameters configuration. "
+ "Adjusting parameters to ensure compatibility."
+ "Reducing the kernel, pooling, and stride sizes accordingly."
+ "Scaling factor: {:.2f}".format(scaling_factor),
+ UserWarning,
+ )
+ conv_block_kernel_length_1 = max(
+ 1, int(conv_block_kernel_length_1 * scaling_factor)
+ )
+ conv_block_kernel_length_2 = max(
+ 1, int(conv_block_kernel_length_2 * scaling_factor)
+ )
+ conv_block_pool_size_1 = max(
+ 1, int(conv_block_pool_size_1 * scaling_factor)
+ )
+ conv_block_pool_size_2 = max(
+ 1, int(conv_block_pool_size_2 * scaling_factor)
+ )
+
+ # n_windows should be at least 1
+ n_windows = max(1, int(n_windows * scaling_factor))
+
+ # tcn_kernel_size must be at least 2 for dilation to work
+ tcn_kernel_size = max(2, int(tcn_kernel_size * scaling_factor))
+
  self.conv_block_n_filters = conv_block_n_filters
  self.conv_block_kernel_length_1 = conv_block_kernel_length_1
  self.conv_block_kernel_length_2 = conv_block_kernel_length_2
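
The minimum-length check added above is easy to verify by hand; with the default hyperparameters it works out as follows (a quick sanity check, not part of the package):

```python
tcn_kernel_size, tcn_depth, n_windows = 4, 2, 5   # defaults
pool_size_1, pool_size_2 = 8, 7                   # defaults

min_len_tcn = (tcn_kernel_size - 1) * (2 ** (tcn_depth - 1)) + 1  # 3 * 2 + 1 = 7
min_len_sliding = n_windows + min_len_tcn - 1                     # 5 + 7 - 1 = 11
min_n_times = min_len_sliding * pool_size_1 * pool_size_2         # 11 * 56 = 616

print(min_n_times)  # 616 samples, i.e. about 2.46 s at sfreq=250
```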
@@ -156,19 +180,18 @@
  self.n_windows = n_windows
  self.att_head_dim = att_head_dim
  self.att_num_heads = att_num_heads
- self.att_dropout = att_dropout
+ self.att_dropout = att_drop_prob
  self.tcn_depth = tcn_depth
  self.tcn_kernel_size = tcn_kernel_size
- self.tcn_n_filters = tcn_n_filters
- self.tcn_dropout = tcn_dropout
+ self.tcn_dropout = tcn_drop_prob
  self.tcn_activation = tcn_activation
  self.concat = concat
  self.max_norm_const = max_norm_const
-
+ self.tcn_n_filters = int(self.conv_block_depth_mult * self.conv_block_n_filters)
  map = dict()
  for w in range(self.n_windows):
- map[f'max_norm_linears.[{w}].weight'] = f'final_layer.[{w}].weight'
- map[f'max_norm_linears.[{w}].bias'] = f'final_layer.[{w}].bias'
+ map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
+ map[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
  self.mapping = map

  # Check later if we want to keep the Ensure4d. Not sure if we can
@@ -184,57 +207,67 @@ class ATCNet(EEGModuleMixin, nn.Module):
  pool_size_1=conv_block_pool_size_1,
  pool_size_2=conv_block_pool_size_2,
  depth_mult=conv_block_depth_mult,
- dropout=conv_block_dropout
+ dropout=conv_block_dropout,
  )

  self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
- self.Tc = int(self.input_window_seconds * self.sfreq / (
- conv_block_pool_size_1 * conv_block_pool_size_2))
+ self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
  self.Tw = self.Tc - self.n_windows + 1

- self.attention_blocks = nn.ModuleList([
- _AttentionBlock(
- in_shape=self.F2,
- head_dim=self.att_head_dim,
- num_heads=att_num_heads,
- dropout=att_dropout,
- ) for _ in range(self.n_windows)
- ])
-
- self.temporal_conv_nets = nn.ModuleList([
- nn.Sequential(
- *[_TCNResidualBlock(
- in_channels=self.F2,
- kernel_size=tcn_kernel_size,
- n_filters=tcn_n_filters,
- dropout=tcn_dropout,
- activation=tcn_activation,
- dilation=2 ** i
- ) for i in range(tcn_depth)]
- ) for _ in range(self.n_windows)
- ])
+ self.attention_blocks = nn.ModuleList(
+ [
+ _AttentionBlock(
+ in_shape=self.F2,
+ head_dim=self.att_head_dim,
+ num_heads=att_num_heads,
+ dropout=att_drop_prob,
+ )
+ for _ in range(self.n_windows)
+ ]
+ )

- if self.concat:
- self.final_layer = nn.ModuleList([
- MaxNormLinear(
- in_features=self.F2 * self.n_windows,
- out_features=self.n_outputs,
- max_norm_val=self.max_norm_const
+ self.temporal_conv_nets = nn.ModuleList(
+ [
+ nn.Sequential(
+ *[
+ _TCNResidualBlock(
+ in_channels=self.F2,
+ kernel_size=self.tcn_kernel_size,
+ n_filters=self.tcn_n_filters,
+ dropout=self.tcn_dropout,
+ activation=self.tcn_activation,
+ dilation=2**i,
+ )
+ for i in range(self.tcn_depth)
+ ]
  )
- ])
- else:
- self.final_layer = nn.ModuleList([
- MaxNormLinear(
- in_features=self.F2,
- out_features=self.n_outputs,
- max_norm_val=self.max_norm_const
- ) for _ in range(self.n_windows)
- ])
-
- if self.add_log_softmax:
- self.out_fun = nn.LogSoftmax(dim=1)
+ for _ in range(self.n_windows)
+ ]
+ )
+
+ if self.concat:
+ self.final_layer = nn.ModuleList(
+ [
+ MaxNormLinear(
+ in_features=self.F2 * self.n_windows,
+ out_features=self.n_outputs,
+ max_norm_val=self.max_norm_const,
+ )
+ ]
+ )
  else:
- self.out_fun = nn.Identity()
+ self.final_layer = nn.ModuleList(
+ [
+ MaxNormLinear(
+ in_features=self.F2,
+ out_features=self.n_outputs,
+ max_norm_val=self.max_norm_const,
+ )
+ for _ in range(self.n_windows)
+ ]
+ )
+
+ self.out_fun = nn.Identity()

  def forward(self, X):
  # Dimension: (batch_size, C, T)
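
Two derived sizes anchor the rest of the constructor: F2 (also reused as the derived tcn_n_filters above, replacing the old default of 32) and the sliding-window lengths Tc and Tw, now computed from n_times instead of input_window_seconds * sfreq. A quick check with the defaults and a 4.5 s window at 250 Hz:

```python
n_times = int(4.5 * 250)     # 1125 samples
F2 = int(2 * 16)             # depth_mult * n_filters = 32 (= derived tcn_n_filters)
Tc = int(n_times / (8 * 7))  # 1125 / 56 -> 20
Tw = Tc - 5 + 1              # Tc - n_windows + 1 = 16

print(F2, Tc, Tw)  # 32 20 16
```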
@@ -250,43 +283,46 @@ class ATCNet(EEGModuleMixin, nn.Module):
  # Dimension: (batch_size, F2, Tc)

  # ----- Sliding window -----
- sw_concat = [] # to store sliding window outputs
- for w in range(self.n_windows):
- conv_feat_w = conv_feat[..., w:w + self.Tw]
+ sw_concat: list[torch.Tensor] = [] # to store sliding window outputs
+ # for w in range(self.n_windows):
+ for idx, (attention, tcn_module, final_layer) in enumerate(
+ zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+ ):
+ conv_feat_w = conv_feat[..., idx : idx + self.Tw]
  # Dimension: (batch_size, F2, Tw)

  # ----- Attention block -----
- att_feat = self.attention_blocks[w](conv_feat_w)
+ att_feat = attention(conv_feat_w)
  # Dimension: (batch_size, F2, Tw)

  # ----- Temporal convolutional network (TCN) -----
- tcn_feat = self.temporal_conv_nets[w](att_feat)[..., -1]
+ tcn_feat = tcn_module(att_feat)[..., -1]
  # Dimension: (batch_size, F2)

  # Outputs of sliding window can be either averaged after being
  # mapped by dense layer or concatenated then mapped by a dense
  # layer
  if not self.concat:
- tcn_feat = self.final_layer[w](tcn_feat)
+ tcn_feat = final_layer(tcn_feat)

  sw_concat.append(tcn_feat)

  # ----- Aggregation and prediction -----
  if self.concat:
- sw_concat = torch.cat(sw_concat, dim=1)
- sw_concat = self.final_layer[0](sw_concat)
+ sw_concat_agg = torch.cat(sw_concat, dim=1)
+ sw_concat_agg = self.final_layer[0](sw_concat_agg)
  else:
  if len(sw_concat) > 1: # more than one window
- sw_concat = torch.stack(sw_concat, dim=0)
- sw_concat = torch.mean(sw_concat, dim=0)
+ sw_concat_agg = torch.stack(sw_concat, dim=0)
+ sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
  else: # one window (# windows = 1)
- sw_concat = sw_concat[0]
+ sw_concat_agg = sw_concat[0]

- return self.out_fun(sw_concat)
+ return self.out_fun(sw_concat_agg)


  class _ConvBlock(nn.Module):
- """ Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
+ """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
  architecture [2]_.

  References
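
Since out_fun is now always nn.Identity(), forward() returns raw logits, averaged over the n_windows sliding windows (or concatenated then mapped when concat=True). A hypothetical smoke test of the resulting shape, assuming the constructor arguments shown earlier:

```python
import torch
from braindecode.models import ATCNet

model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250.0)
X = torch.randn(8, 22, 1125)  # (batch_size, n_chans, n_times)
logits = model(X)
print(logits.shape)  # torch.Size([8, 4]); apply log-softmax downstream if needed
```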
@@ -303,15 +339,15 @@ class _ConvBlock(nn.Module):
  """

  def __init__(
- self,
- n_channels,
- n_filters=16,
- kernel_length_1=64,
- kernel_length_2=16,
- pool_size_1=8,
- pool_size_2=7,
- depth_mult=2,
- dropout=0.3,
+ self,
+ n_channels,
+ n_filters=16,
+ kernel_length_1=64,
+ kernel_length_2=16,
+ pool_size_1=8,
+ pool_size_2=7,
+ depth_mult=2,
+ dropout=0.3,
  ):
  super().__init__()

@@ -402,11 +438,11 @@ class _AttentionBlock(nn.Module):
  """

  def __init__(
- self,
- in_shape=32,
- head_dim=8,
- num_heads=2,
- dropout=0.5,
+ self,
+ in_shape=32,
+ head_dim=8,
+ num_heads=2,
+ dropout=0.5,
  ):
  super().__init__()
  self.in_shape = in_shape
@@ -462,7 +498,7 @@ class _AttentionBlock(nn.Module):


  class _TCNResidualBlock(nn.Module):
- """ Modified TCN Residual block as proposed in [1]_. Inspired from
+ """Modified TCN Residual block as proposed in [1]_. Inspired from
  Temporal Convolutional Networks (TCN) [2]_.

  References
@@ -477,16 +513,16 @@ class _TCNResidualBlock(nn.Module):
  """

  def __init__(
- self,
- in_channels,
- kernel_size=4,
- n_filters=32,
- dropout=0.3,
- activation=nn.ELU(),
- dilation=1
+ self,
+ in_channels,
+ kernel_size=4,
+ n_filters=32,
+ dropout=0.3,
+ activation: nn.Module = nn.ELU,
+ dilation=1,
  ):
  super().__init__()
- self.activation = activation
+ self.activation = activation()
  self.dilation = dilation
  self.dropout = dropout
  self.n_filters = n_filters
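
The activation is now passed as a class and instantiated inside each block, so every _TCNResidualBlock owns its own module rather than sharing one nn.ELU() instance across blocks; for a stateless activation this is mostly module-tree hygiene. A minimal sketch of the difference:

```python
from torch import nn

shared = nn.ELU()                  # 0.8.1 style: one instance passed around
blocks_shared = [shared, shared]
print(blocks_shared[0] is blocks_shared[1])  # True: same module object

blocks_owned = [nn.ELU() for _ in range(2)]  # 1.1.0 style: instantiated per block
print(blocks_owned[0] is blocks_owned[1])    # False: independent modules
```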
@@ -522,7 +558,7 @@ class _TCNResidualBlock(nn.Module):
  self.reshaping_conv = nn.Conv1d(
  n_filters,
  kernel_size=1,
- padding='same',
+ padding="same",
  )
  else:
  self.reshaping_conv = nn.Identity()
@@ -550,12 +586,12 @@

  class _MHA(nn.Module):
  def __init__(
- self,
- input_dim: int,
- head_dim: int,
- output_dim: int,
- num_heads: int,
- dropout: float = 0.,
+ self,
+ input_dim: int,
+ head_dim: int,
+ output_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
  ):
  """Multi-head Attention

@@ -598,12 +634,9 @@
  self.dropout = nn.Dropout(dropout)

  def forward(
- self,
- Q: torch.Tensor,
- K: torch.Tensor,
- V: torch.Tensor
+ self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
  ) -> torch.Tensor:
- """ Compute MHA(Q, K, V)
+ """Compute MHA(Q, K, V)

  Parameters
  ----------
@@ -635,22 +668,18 @@
  # Attention weights of size (num_heads * batch_size, n, m):
  # measures how similar each pair of Q and K is.
  W = torch.softmax(
- Q_.bmm(
- K_.transpose(-2, -1) # (B', D', S)
- )
- / np.sqrt(self.head_dim),
- -1
+ Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
+ -1, # (B', D', S)
  ) # (B', N, M)

  # Multihead output (batch_size, seq_len, dim):
  # weighted sum of V where a value gets more weight if its corresponding
  # key has larger dot product with the query.
  H = torch.cat(
- (
- W # (B', S, S)
- .bmm(V_) # (B', S, D')
- ).split(batch_size, 0), # [(B, S, D')] * num_heads
- -1
+ (W.bmm(V_)).split( # (B', S, S) # (B', S, D')
+ batch_size, 0
+ ), # [(B, S, D')] * num_heads
+ -1,
  ) # (B, S, D)

  out = self.fc_o(H)
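
The rewritten expression is standard scaled dot-product attention, W = softmax(Q'K'^T / sqrt(head_dim)), with the heads folded into the batch dimension and concatenated back onto the feature dimension afterwards. A standalone shape check (B' = num_heads * batch_size, S = seq_len, D' = head_dim, matching the comments; the sizes are illustrative assumptions):

```python
import math
import torch

num_heads, batch_size, seq_len, head_dim = 2, 8, 16, 8
Q_ = torch.randn(num_heads * batch_size, seq_len, head_dim)  # (B', S, D')
K_ = torch.randn_like(Q_)
V_ = torch.randn_like(Q_)

W = torch.softmax(Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(head_dim), -1)
H = torch.cat(W.bmm(V_).split(batch_size, 0), -1)  # heads back on last dim
print(W.shape, H.shape)  # torch.Size([16, 16, 16]) torch.Size([8, 16, 16])
```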