braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the changes between package versions as they were released to their public registries. It is provided for informational purposes only.

Potentially problematic release: this version of braindecode has been flagged as possibly problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
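The most disruptive change for downstream code is the module reorganization visible in the list above: braindecode/models/modules.py and braindecode/models/functions.py are deleted, replaced by the new braindecode/modules/ and braindecode/functional/ packages. A minimal sketch of the import migration, based on the atcnet.py diff below (the exact set of names re-exported by braindecode.modules in 1.0.0, beyond the three shown, is an assumption):

    # braindecode 0.8.1 (module removed in 1.0.0):
    # from braindecode.models.modules import Ensure4d, MaxNormLinear, CausalConv1d

    # braindecode 1.0.0, matching the imports in the atcnet.py diff below:
    from braindecode.models.base import EEGModuleMixin
    from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear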
@@ -1,21 +1,25 @@
  # Authors: Cedric Rommel <cedric.rommel@inria.fr>
  #
  # License: BSD (3-clause)
- import numpy as np
+ import math

  import torch
- from torch import nn
  from einops.layers.torch import Rearrange
+ from torch import nn

- from .modules import Ensure4d, MaxNormLinear, CausalConv1d
- from .base import EEGModuleMixin, deprecated_args
+ from braindecode.models.base import EEGModuleMixin
+ from braindecode.modules import CausalConv1d, Ensure4d, MaxNormLinear


  class ATCNet(EEGModuleMixin, nn.Module):
-     """ATCNet model from [1]_
+     """ATCNet model from Altaheri et al. (2022) [1]_

      Pytorch implementation based on official tensorflow code [2]_.

+     .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
+        :align: center
+        :alt: ATCNet Architecture
+
      Parameters
      ----------
      input_window_seconds : float, optional
@@ -54,7 +58,7 @@ class ATCNet(EEGModuleMixin, nn.Module):
          table 1 of the paper [1]_. Defaults to 8 as in [1]_.
      att_num_heads : int
          Number of attention heads, denoted H in table 1 of the paper [1]_.
-         Defaults to 2 as in [1_.
+         Defaults to 2 as in [1]_.
      att_dropout : float
          Dropout probability used in the attention block, denoted pa in table 1
          of the paper [1]_. Defaults to 0.5 as in [1]_.
@@ -82,59 +86,45 @@ class ATCNet(EEGModuleMixin, nn.Module):
      max_norm_const : float
          Maximum L2-norm constraint imposed on weights of the last
          fully-connected layer. Defaults to 0.25.
-     n_channels:
-         Alias for n_chans.
-     n_classes:
-         Alias for n_outputs.
-     input_size_s:
-         Alias for input_window_seconds.
+

      References
      ----------
-     .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman, "Physics-informed
-        attention temporal convolutional network for EEG-based motor imagery
-        classification," in IEEE Transactions on Industrial Informatics,
-        2022, doi: 10.1109/TII.2022.3197419.
-     .. [2] https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
+     .. [1] H. Altaheri, G. Muhammad and M. Alsulaiman,
+        Physics-informed attention temporal convolutional network for EEG-based
+        motor imagery classification in IEEE Transactions on Industrial Informatics,
+        2022, doi: 10.1109/TII.2022.3197419.
+     .. [2] EEE-ATCNet implementation.
+        https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
      """

      def __init__(
-             self,
-             n_chans=None,
-             n_outputs=None,
-             input_window_seconds=4.5,
-             sfreq=250.,
-             conv_block_n_filters=16,
-             conv_block_kernel_length_1=64,
-             conv_block_kernel_length_2=16,
-             conv_block_pool_size_1=8,
-             conv_block_pool_size_2=7,
-             conv_block_depth_mult=2,
-             conv_block_dropout=0.3,
-             n_windows=5,
-             att_head_dim=8,
-             att_num_heads=2,
-             att_dropout=0.5,
-             tcn_depth=2,
-             tcn_kernel_size=4,
-             tcn_n_filters=32,
-             tcn_dropout=0.3,
-             tcn_activation=nn.ELU(),
-             concat=False,
-             max_norm_const=0.25,
-             chs_info=None,
-             n_times=None,
-             n_channels=None,
-             n_classes=None,
-             input_size_s=None,
-             add_log_softmax=True,
+         self,
+         n_chans=None,
+         n_outputs=None,
+         input_window_seconds=None,
+         sfreq=250.0,
+         conv_block_n_filters=16,
+         conv_block_kernel_length_1=64,
+         conv_block_kernel_length_2=16,
+         conv_block_pool_size_1=8,
+         conv_block_pool_size_2=7,
+         conv_block_depth_mult=2,
+         conv_block_dropout=0.3,
+         n_windows=5,
+         att_head_dim=8,
+         att_num_heads=2,
+         att_drop_prob=0.5,
+         tcn_depth=2,
+         tcn_kernel_size=4,
+         tcn_n_filters=32,
+         tcn_drop_prob=0.3,
+         tcn_activation: nn.Module = nn.ELU,
+         concat=False,
+         max_norm_const=0.25,
+         chs_info=None,
+         n_times=None,
      ):
-         n_chans, n_outputs, input_window_seconds = deprecated_args(
-             self,
-             ('n_channels', 'n_chans', n_channels, n_chans),
-             ('n_classes', 'n_outputs', n_classes, n_outputs),
-             ('input_size_s', 'input_window_seconds', input_size_s, input_window_seconds),
-         )
          super().__init__(
              n_outputs=n_outputs,
              n_chans=n_chans,
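Taken together, this hunk removes the 0.8.1 deprecation shims: the n_channels/n_classes/input_size_s aliases and the deprecated_args helper are gone, att_dropout and tcn_dropout are renamed to att_drop_prob and tcn_drop_prob, tcn_activation now expects an activation class rather than an instance, add_log_softmax is dropped, and input_window_seconds no longer defaults to 4.5. A hedged migration sketch (argument values are illustrative, chosen to match the paper's defaults):

    from torch import nn
    from braindecode.models import ATCNet

    # 0.8.1 call style -- these keywords now raise TypeError in 1.0.0:
    # model = ATCNet(n_channels=22, n_classes=4, input_size_s=4.5, att_dropout=0.5)

    # 1.0.0 equivalent:
    model = ATCNet(
        n_chans=22,                  # was n_channels
        n_outputs=4,                 # was n_classes
        input_window_seconds=4.5,    # was input_size_s; no longer defaults to 4.5
        sfreq=250.0,
        att_drop_prob=0.5,           # was att_dropout
        tcn_drop_prob=0.3,           # was tcn_dropout
        tcn_activation=nn.ELU,       # class, not an nn.ELU() instance
    )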
@@ -142,10 +132,8 @@ class ATCNet(EEGModuleMixin, nn.Module):
              n_times=n_times,
              input_window_seconds=input_window_seconds,
              sfreq=sfreq,
-             add_log_softmax=add_log_softmax,
          )
          del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-         del n_channels, n_classes, input_size_s
          self.conv_block_n_filters = conv_block_n_filters
          self.conv_block_kernel_length_1 = conv_block_kernel_length_1
          self.conv_block_kernel_length_2 = conv_block_kernel_length_2
@@ -156,19 +144,19 @@ class ATCNet(EEGModuleMixin, nn.Module):
          self.n_windows = n_windows
          self.att_head_dim = att_head_dim
          self.att_num_heads = att_num_heads
-         self.att_dropout = att_dropout
+         self.att_dropout = att_drop_prob
          self.tcn_depth = tcn_depth
          self.tcn_kernel_size = tcn_kernel_size
          self.tcn_n_filters = tcn_n_filters
-         self.tcn_dropout = tcn_dropout
+         self.tcn_dropout = tcn_drop_prob
          self.tcn_activation = tcn_activation
          self.concat = concat
          self.max_norm_const = max_norm_const

          map = dict()
          for w in range(self.n_windows):
-             map[f'max_norm_linears.[{w}].weight'] = f'final_layer.[{w}].weight'
-             map[f'max_norm_linears.[{w}].bias'] = f'final_layer.[{w}].bias'
+             map[f"max_norm_linears.[{w}].weight"] = f"final_layer.[{w}].weight"
+             map[f"max_norm_linears.[{w}].bias"] = f"final_layer.[{w}].bias"
          self.mapping = map

          # Check later if we want to keep the Ensure4d. Not sure if we can
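The self.mapping dict gives the mixin a rename table from the 0.8.x parameter names (max_norm_linears.[w].*) to the 1.0.0 names (final_layer.[w].*), presumably so checkpoints saved with 0.8.x can still be loaded. Assuming the mixin applies the mapping key-for-key, a manual equivalent for a raw state dict would be (hypothetical helper, not braindecode API):

    def rename_state_dict_keys(state_dict, mapping):
        # Keys found in the mapping are renamed; everything else passes through.
        return {mapping.get(key, key): value for key, value in state_dict.items()}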
@@ -184,57 +172,67 @@ class ATCNet(EEGModuleMixin, nn.Module):
              pool_size_1=conv_block_pool_size_1,
              pool_size_2=conv_block_pool_size_2,
              depth_mult=conv_block_depth_mult,
-             dropout=conv_block_dropout
+             dropout=conv_block_dropout,
          )

          self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
-         self.Tc = int(self.input_window_seconds * self.sfreq / (
-             conv_block_pool_size_1 * conv_block_pool_size_2))
+         self.Tc = int(self.n_times / (conv_block_pool_size_1 * conv_block_pool_size_2))
          self.Tw = self.Tc - self.n_windows + 1

-         self.attention_blocks = nn.ModuleList([
-             _AttentionBlock(
-                 in_shape=self.F2,
-                 head_dim=self.att_head_dim,
-                 num_heads=att_num_heads,
-                 dropout=att_dropout,
-             ) for _ in range(self.n_windows)
-         ])
-
-         self.temporal_conv_nets = nn.ModuleList([
-             nn.Sequential(
-                 *[_TCNResidualBlock(
-                     in_channels=self.F2,
-                     kernel_size=tcn_kernel_size,
-                     n_filters=tcn_n_filters,
-                     dropout=tcn_dropout,
-                     activation=tcn_activation,
-                     dilation=2 ** i
-                 ) for i in range(tcn_depth)]
-             ) for _ in range(self.n_windows)
-         ])
+         self.attention_blocks = nn.ModuleList(
+             [
+                 _AttentionBlock(
+                     in_shape=self.F2,
+                     head_dim=self.att_head_dim,
+                     num_heads=att_num_heads,
+                     dropout=att_drop_prob,
+                 )
+                 for _ in range(self.n_windows)
+             ]
+         )

-         if self.concat:
-             self.final_layer = nn.ModuleList([
-                 MaxNormLinear(
-                     in_features=self.F2 * self.n_windows,
-                     out_features=self.n_outputs,
-                     max_norm_val=self.max_norm_const
+         self.temporal_conv_nets = nn.ModuleList(
+             [
+                 nn.Sequential(
+                     *[
+                         _TCNResidualBlock(
+                             in_channels=self.F2,
+                             kernel_size=tcn_kernel_size,
+                             n_filters=tcn_n_filters,
+                             dropout=tcn_drop_prob,
+                             activation=tcn_activation,
+                             dilation=2**i,
+                         )
+                         for i in range(tcn_depth)
+                     ]
                  )
-             ])
-         else:
-             self.final_layer = nn.ModuleList([
-                 MaxNormLinear(
-                     in_features=self.F2,
-                     out_features=self.n_outputs,
-                     max_norm_val=self.max_norm_const
-                 ) for _ in range(self.n_windows)
-             ])
-
-         if self.add_log_softmax:
-             self.out_fun = nn.LogSoftmax(dim=1)
+                 for _ in range(self.n_windows)
+             ]
+         )
+
+         if self.concat:
+             self.final_layer = nn.ModuleList(
+                 [
+                     MaxNormLinear(
+                         in_features=self.F2 * self.n_windows,
+                         out_features=self.n_outputs,
+                         max_norm_val=self.max_norm_const,
+                     )
+                 ]
+             )
          else:
-             self.out_fun = nn.Identity()
+             self.final_layer = nn.ModuleList(
+                 [
+                     MaxNormLinear(
+                         in_features=self.F2,
+                         out_features=self.n_outputs,
+                         max_norm_val=self.max_norm_const,
+                     )
+                     for _ in range(self.n_windows)
+                 ]
+             )
+
+         self.out_fun = nn.Identity()

      def forward(self, X):
          # Dimension: (batch_size, C, T)
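Two behavior changes land in this hunk. First, Tc is now computed from n_times instead of input_window_seconds * sfreq; the two agree whenever those quantities are consistent, e.g. with the paper's defaults:

    n_times = int(4.5 * 250)   # 1125 samples
    Tc = n_times // (8 * 7)    # conv_block_pool_size_1 * conv_block_pool_size_2 -> 20
    Tw = Tc - 5 + 1            # n_windows = 5 -> 16 time steps per sliding window

Second, out_fun is now unconditionally nn.Identity(), so with add_log_softmax gone the model always returns raw logits; training loops that paired the old log-softmax output with nn.NLLLoss should switch to something like nn.CrossEntropyLoss applied to the logits.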
@@ -250,43 +248,46 @@ class ATCNet(EEGModuleMixin, nn.Module):
          # Dimension: (batch_size, F2, Tc)

          # ----- Sliding window -----
-         sw_concat = []  # to store sliding window outputs
-         for w in range(self.n_windows):
-             conv_feat_w = conv_feat[..., w:w + self.Tw]
+         sw_concat: list[torch.Tensor] = []  # to store sliding window outputs
+         # for w in range(self.n_windows):
+         for idx, (attention, tcn_module, final_layer) in enumerate(
+             zip(self.attention_blocks, self.temporal_conv_nets, self.final_layer)
+         ):
+             conv_feat_w = conv_feat[..., idx : idx + self.Tw]
              # Dimension: (batch_size, F2, Tw)

              # ----- Attention block -----
-             att_feat = self.attention_blocks[w](conv_feat_w)
+             att_feat = attention(conv_feat_w)
              # Dimension: (batch_size, F2, Tw)

              # ----- Temporal convolutional network (TCN) -----
-             tcn_feat = self.temporal_conv_nets[w](att_feat)[..., -1]
+             tcn_feat = tcn_module(att_feat)[..., -1]
              # Dimension: (batch_size, F2)

              # Outputs of sliding window can be either averaged after being
              # mapped by dense layer or concatenated then mapped by a dense
              # layer
              if not self.concat:
-                 tcn_feat = self.final_layer[w](tcn_feat)
+                 tcn_feat = final_layer(tcn_feat)

              sw_concat.append(tcn_feat)

          # ----- Aggregation and prediction -----
          if self.concat:
-             sw_concat = torch.cat(sw_concat, dim=1)
-             sw_concat = self.final_layer[0](sw_concat)
+             sw_concat_agg = torch.cat(sw_concat, dim=1)
+             sw_concat_agg = self.final_layer[0](sw_concat_agg)
          else:
              if len(sw_concat) > 1:  # more than one window
-                 sw_concat = torch.stack(sw_concat, dim=0)
-                 sw_concat = torch.mean(sw_concat, dim=0)
+                 sw_concat_agg = torch.stack(sw_concat, dim=0)
+                 sw_concat_agg = torch.mean(sw_concat_agg, dim=0)
              else:  # one window (# windows = 1)
-                 sw_concat = sw_concat[0]
+                 sw_concat_agg = sw_concat[0]

-         return self.out_fun(sw_concat)
+         return self.out_fun(sw_concat_agg)


  class _ConvBlock(nn.Module):
-     """ Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
+     """Convolutional block proposed in ATCNet [1]_, inspired by the EEGNet
      architecture [2]_.

      References
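The rewritten loop zips the per-window attention_blocks, temporal_conv_nets, and final_layer modules instead of indexing them by w, and the renamed sw_concat_agg avoids rebinding the list variable; input and output shapes are unchanged. A smoke-test sketch (assuming ATCNet is still re-exported from braindecode.models as in 0.8.1):

    import torch
    from braindecode.models import ATCNet

    model = ATCNet(n_chans=22, n_outputs=4, input_window_seconds=4.5, sfreq=250.0)
    X = torch.randn(8, 22, 1125)  # (batch_size, C, T), T = 4.5 s * 250 Hz
    logits = model(X)
    print(logits.shape)           # expected: torch.Size([8, 4])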
@@ -303,15 +304,15 @@ class _ConvBlock(nn.Module):
      """

      def __init__(
-             self,
-             n_channels,
-             n_filters=16,
-             kernel_length_1=64,
-             kernel_length_2=16,
-             pool_size_1=8,
-             pool_size_2=7,
-             depth_mult=2,
-             dropout=0.3,
+         self,
+         n_channels,
+         n_filters=16,
+         kernel_length_1=64,
+         kernel_length_2=16,
+         pool_size_1=8,
+         pool_size_2=7,
+         depth_mult=2,
+         dropout=0.3,
      ):
          super().__init__()

@@ -402,11 +403,11 @@ class _AttentionBlock(nn.Module):
      """

      def __init__(
-             self,
-             in_shape=32,
-             head_dim=8,
-             num_heads=2,
-             dropout=0.5,
+         self,
+         in_shape=32,
+         head_dim=8,
+         num_heads=2,
+         dropout=0.5,
      ):
          super().__init__()
          self.in_shape = in_shape
@@ -462,7 +463,7 @@ class _AttentionBlock(nn.Module):


  class _TCNResidualBlock(nn.Module):
-     """ Modified TCN Residual block as proposed in [1]_. Inspired from
+     """Modified TCN Residual block as proposed in [1]_. Inspired from
      Temporal Convolutional Networks (TCN) [2]_.

      References
@@ -477,16 +478,16 @@ class _TCNResidualBlock(nn.Module):
      """

      def __init__(
-             self,
-             in_channels,
-             kernel_size=4,
-             n_filters=32,
-             dropout=0.3,
-             activation=nn.ELU(),
-             dilation=1
+         self,
+         in_channels,
+         kernel_size=4,
+         n_filters=32,
+         dropout=0.3,
+         activation: nn.Module = nn.ELU,
+         dilation=1,
      ):
          super().__init__()
-         self.activation = activation
+         self.activation = activation()
          self.dilation = dilation
          self.dropout = dropout
          self.n_filters = n_filters
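As with the public tcn_activation parameter, _TCNResidualBlock now takes the activation class and instantiates it itself (self.activation = activation()), so every block gets its own module instance rather than sharing one nn.ELU() object. An illustrative call within atcnet.py (the class is private; shown only to make the convention concrete):

    # 0.8.1: a single shared instance was passed in
    # block = _TCNResidualBlock(in_channels=32, activation=nn.ELU())

    # 1.0.0: pass the class; the block calls activation() internally
    block = _TCNResidualBlock(in_channels=32, activation=nn.GELU)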
@@ -522,7 +523,7 @@ class _TCNResidualBlock(nn.Module):
              self.reshaping_conv = nn.Conv1d(
                  n_filters,
                  kernel_size=1,
-                 padding='same',
+                 padding="same",
              )
          else:
              self.reshaping_conv = nn.Identity()
@@ -550,12 +551,12 @@ class _TCNResidualBlock(nn.Module):

  class _MHA(nn.Module):
      def __init__(
-             self,
-             input_dim: int,
-             head_dim: int,
-             output_dim: int,
-             num_heads: int,
-             dropout: float = 0.,
+         self,
+         input_dim: int,
+         head_dim: int,
+         output_dim: int,
+         num_heads: int,
+         dropout: float = 0.0,
      ):
          """Multi-head Attention

@@ -598,12 +599,9 @@ class _MHA(nn.Module):
          self.dropout = nn.Dropout(dropout)

      def forward(
-             self,
-             Q: torch.Tensor,
-             K: torch.Tensor,
-             V: torch.Tensor
+         self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor
      ) -> torch.Tensor:
-         """ Compute MHA(Q, K, V)
+         """Compute MHA(Q, K, V)

          Parameters
          ----------
@@ -635,22 +633,18 @@ class _MHA(nn.Module):
          # Attention weights of size (num_heads * batch_size, n, m):
          # measures how similar each pair of Q and K is.
          W = torch.softmax(
-             Q_.bmm(
-                 K_.transpose(-2, -1)  # (B', D', S)
-             )
-             / np.sqrt(self.head_dim),
-             -1
+             Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(self.head_dim),
+             -1,  # (B', D', S)
          )  # (B', N, M)

          # Multihead output (batch_size, seq_len, dim):
          # weighted sum of V where a value gets more weight if its corresponding
          # key has larger dot product with the query.
          H = torch.cat(
-             (
-                 W  # (B', S, S)
-                 .bmm(V_)  # (B', S, D')
-             ).split(batch_size, 0),  # [(B, S, D')] * num_heads
-             -1
+             (W.bmm(V_)).split(  # (B', S, S)  # (B', S, D')
+                 batch_size, 0
+             ),  # [(B, S, D')] * num_heads
+             -1,
          )  # (B, S, D)

          out = self.fc_o(H)
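Beyond the reformatting, the only substantive change here is np.sqrt -> math.sqrt, which removes this module's NumPy dependency; the computation is still standard scaled dot-product attention, softmax(Q'K'^T / sqrt(head_dim)) V' per head, with the heads folded into the batch dimension and concatenated back at the end. A self-contained sketch of the same tensor manipulation (names illustrative):

    import math
    import torch

    batch_size, num_heads, seq_len, head_dim = 2, 2, 16, 8
    # Heads folded into the batch dimension, as in _MHA: B' = num_heads * batch_size
    Q_ = torch.randn(num_heads * batch_size, seq_len, head_dim)
    K_ = torch.randn(num_heads * batch_size, seq_len, head_dim)
    V_ = torch.randn(num_heads * batch_size, seq_len, head_dim)

    W = torch.softmax(Q_.bmm(K_.transpose(-2, -1)) / math.sqrt(head_dim), -1)  # (B', S, S)
    H = torch.cat(W.bmm(V_).split(batch_size, 0), -1)  # back to (B, S, num_heads * head_dim)
    print(H.shape)  # torch.Size([2, 16, 16])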