braindecode 0.8.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.

This version of braindecode might be problematic.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +325 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +37 -18
  20. braindecode/datautil/serialization.py +110 -72
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +22 -0
  23. braindecode/functional/functions.py +250 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +84 -14
  26. braindecode/models/atcnet.py +193 -164
  27. braindecode/models/attentionbasenet.py +599 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +504 -0
  30. braindecode/models/contrawr.py +317 -0
  31. braindecode/models/ctnet.py +536 -0
  32. braindecode/models/deep4.py +116 -77
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +112 -173
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +161 -97
  37. braindecode/models/eegitnet.py +215 -152
  38. braindecode/models/eegminer.py +254 -0
  39. braindecode/models/eegnet.py +228 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +324 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1186 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +207 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1011 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +247 -141
  58. braindecode/models/sparcnet.py +424 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +283 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -145
  66. braindecode/modules/__init__.py +84 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +628 -0
  72. braindecode/modules/layers.py +131 -0
  73. braindecode/modules/linear.py +49 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +76 -0
  77. braindecode/modules/wrapper.py +73 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +146 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/METADATA +42 -58
  96. braindecode-1.1.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.1.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.1.0.dist-info}/top_level.txt +0 -0
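
The remainder of this diff renders the changes to braindecode/models/sleep_stager_eldele_2021.py (entry 57 above). The user-visible signature changes are: the deprecated n_classes and input_size_s aliases are removed together with the deprecated_args helper, dropout is renamed to drop_prob, and two activation-class arguments (activation, activation_mrcnn) are added. Below is a minimal usage sketch against the 1.1.0 signature; it assumes the configuration the model itself expects (30 s windows at 100 Hz with the default d_model=80, single-channel input) and that the class is exported from braindecode.models:

    import torch
    from braindecode.models import SleepStagerEldele2021

    model = SleepStagerEldele2021(
        n_outputs=5,              # 0.8.1 accepted this as n_classes
        n_chans=1,
        sfreq=100,
        input_window_seconds=30,  # 0.8.1 accepted this as input_size_s
        drop_prob=0.1,            # 0.8.1 called this dropout
    )
    x = torch.randn(8, 1, 3000)   # (batch, n_chans, n_times), 30 s at 100 Hz
    logits = model(x)             # (batch, n_outputs)
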
@@ -3,18 +3,23 @@
 # License: BSD (3-clause)
 
 import math
-import copy
-from copy import deepcopy
 import warnings
+from copy import deepcopy
 
 import torch
-from torch import nn
 import torch.nn.functional as F
-from .base import EEGModuleMixin, deprecated_args
+from torch import nn
+
+from braindecode.models.base import EEGModuleMixin
+from braindecode.modules import CausalConv1d
 
 
 class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
-    """Sleep Staging Architecture from Eldele et al 2021.
+    """Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
+
+    .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
+       :align: center
+       :alt: SleepStagerEldele2021 Architecture
 
     Attention based Neural Net for sleep staging as described in [Eldele2021]_.
     The code for the paper and this model is also available at [1]_.
@@ -43,7 +48,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         input dimension of the second FC layer in the same.
     n_attn_heads : int
         Number of attention heads. It should be a factor of d_model
-    dropout : float
+    drop_prob : float
         Dropout rate in the PositionWiseFeedforward layer and the TCE layers.
     after_reduced_cnn_size : int
         Number of output channels produced by the convolution in the AFR module.
@@ -55,6 +60,13 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         Alias for `n_outputs`.
     input_size_s : float
         Alias for `input_window_seconds`.
+    activation : nn.Module, default=nn.ReLU
+        Activation function class to apply. Should be a PyTorch activation
+        module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
+    activation_mrcnn : nn.Module, default=nn.GELU
+        Activation function class to apply in the multi-resolution CNN
+        (MRCNN) block. Should be a PyTorch activation module class like
+        ``nn.ReLU`` or ``nn.GELU``. Default is ``nn.GELU``.
 
     References
     ----------
@@ -68,28 +80,23 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
     """
 
     def __init__(
-            self,
-            sfreq=None,
-            n_tce=2,
-            d_model=80,
-            d_ff=120,
-            n_attn_heads=5,
-            dropout=0.1,
-            input_window_seconds=30,
-            n_outputs=5,
-            after_reduced_cnn_size=30,
-            return_feats=False,
-            chs_info=None,
-            n_chans=None,
-            n_times=None,
-            n_classes=None,
-            input_size_s=None,
+        self,
+        sfreq=None,
+        n_tce=2,
+        d_model=80,
+        d_ff=120,
+        n_attn_heads=5,
+        drop_prob=0.1,
+        activation_mrcnn: nn.Module = nn.GELU,
+        activation: nn.Module = nn.ReLU,
+        input_window_seconds=None,
+        n_outputs=None,
+        after_reduced_cnn_size=30,
+        return_feats=False,
+        chs_info=None,
+        n_chans=None,
+        n_times=None,
     ):
-        n_outputs, input_window_seconds, = deprecated_args(
-            self,
-            ("n_classes", "n_outputs", n_classes, n_outputs),
-            ("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
-        )
         super().__init__(
             n_outputs=n_outputs,
             n_chans=n_chans,
@@ -99,19 +106,25 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
             sfreq=sfreq,
         )
         del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
-        del n_classes, input_size_s
 
         self.mapping = {
             "fc.weight": "final_layer.weight",
-            "fc.bias": "final_layer.bias"
+            "fc.bias": "final_layer.bias",
         }
 
-        if not ((self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80) or
-                (self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100)):
-            warnings.warn("This model was designed originally for input windows of 30sec at 100Hz, "
-                          "with d_model at 80 or at 125Hz, with d_model at 100, to use anything "
-                          "other than this may cause errors or cause the model to perform in "
-                          "other ways than intended", UserWarning)
+        if not (
+            (self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
+            or (
+                self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100
+            )
+        ):
+            warnings.warn(
+                "This model was originally designed for 30 s input windows, "
+                "either at 100 Hz with d_model=80 or at 125 Hz with "
+                "d_model=100; other configurations may cause errors or make "
+                "the model behave in unintended ways.",
+                UserWarning,
+            )
 
         # the usual kernel size for the mrcnn, for sfreq 100
         kernel_size = 7
@@ -119,11 +132,20 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         if self.sfreq == 125:
             kernel_size = 6
 
-        mrcnn = _MRCNN(after_reduced_cnn_size, kernel_size)
+        mrcnn = _MRCNN(
+            after_reduced_cnn_size,
+            kernel_size,
+            activation=activation_mrcnn,
+            activation_se=activation,
+        )
         attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
-        ff = _PositionwiseFeedForward(d_model, d_ff, dropout)
-        tce = _TCE(_EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size,
-                                 dropout), n_tce)
+        ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
+        tce = _TCE(
+            _EncoderLayer(
+                d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
+            ),
+            n_tce,
+        )
 
         self.feature_extractor = nn.Sequential(mrcnn, tce)
         self.len_last_layer = self._len_last_layer(self.n_times)
@@ -133,7 +155,9 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         """if return_feats:
             raise ValueError("return_feat == True is not accepted anymore")"""
         if not return_feats:
-            self.final_layer = nn.Linear(d_model * after_reduced_cnn_size, self.n_outputs)
+            self.final_layer = nn.Linear(
+                d_model * after_reduced_cnn_size, self.n_outputs
+            )
 
     def _len_last_layer(self, input_size):
         self.feature_extractor.eval()
@@ -142,7 +166,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         self.feature_extractor.train()
         return len(out.flatten())
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
 
@@ -153,27 +177,41 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
         """
 
         encoded_features = self.feature_extractor(x)
-        encoded_features = encoded_features.contiguous().view(encoded_features.shape[0], -1)
+        encoded_features = encoded_features.contiguous().view(
+            encoded_features.shape[0], -1
+        )
 
         if self.return_feats:
             return encoded_features
-        else:
-            final_output = self.final_layer(encoded_features)
-            return final_output
+
+        return self.final_layer(encoded_features)
 
 
 class _SELayer(nn.Module):
-    def __init__(self, channel, reduction=16):
+    def __init__(self, channel, reduction=16, activation=nn.ReLU):
         super(_SELayer, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool1d(1)
         self.fc = nn.Sequential(
             nn.Linear(channel, channel // reduction, bias=False),
-            nn.ReLU(inplace=True),
+            activation(inplace=True),
             nn.Linear(channel // reduction, channel, bias=False),
-            nn.Sigmoid()
+            nn.Sigmoid(),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the SE layer.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, channel, length).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after applying the SE recalibration.
+        """
         b, c, _ = x.size()
         y = self.avg_pool(x).view(b, c)
         y = self.fc(y).view(b, c, 1)
@@ -183,22 +221,43 @@ class _SELayer(nn.Module):
 class _SEBasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
-                 base_width=64, dilation=1, norm_layer=None,
-                 *, reduction=16):
+    def __init__(
+        self,
+        inplanes,
+        planes,
+        stride=1,
+        downsample=None,
+        activation: nn.Module = nn.ReLU,
+        *,
+        reduction=16,
+    ):
         super(_SEBasicBlock, self).__init__()
         self.conv1 = nn.Conv1d(inplanes, planes, stride)
         self.bn1 = nn.BatchNorm1d(planes)
-        self.relu = nn.ReLU(inplace=True)
+        self.relu = activation(inplace=True)
         self.conv2 = nn.Conv1d(planes, planes, 1)
         self.bn2 = nn.BatchNorm1d(planes)
         self.se = _SELayer(planes, reduction)
         self.downsample = downsample
         self.stride = stride
-        self.features = nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2,
-                                      self.se)
+        self.features = nn.Sequential(
+            self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
+        )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the SE basic block.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor of shape (batch_size, n_chans, n_times).
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor after applying the SE recalibration.
+        """
         residual = x
         out = self.features(x)
 
@@ -212,26 +271,29 @@ class _SEBasicBlock(nn.Module):
 
 
 class _MRCNN(nn.Module):
-    def __init__(self, after_reduced_cnn_size, kernel_size=7):
+    def __init__(
+        self,
+        after_reduced_cnn_size,
+        kernel_size=7,
+        activation: nn.Module = nn.GELU,
+        activation_se: nn.Module = nn.ReLU,
+    ):
         super(_MRCNN, self).__init__()
         drate = 0.5
-        self.GELU = nn.GELU()
+        self.GELU = activation()
         self.features1 = nn.Sequential(
             nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
             nn.BatchNorm1d(64),
             self.GELU,
             nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
             nn.Dropout(drate),
-
             nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
             nn.BatchNorm1d(128),
             self.GELU,
-
             nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
             nn.BatchNorm1d(128),
             self.GELU,
-
-            nn.MaxPool1d(kernel_size=4, stride=4, padding=2)
+            nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
         )
 
         self.features2 = nn.Sequential(
@@ -240,28 +302,38 @@ class _MRCNN(nn.Module):
             self.GELU,
             nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
             nn.Dropout(drate),
-
-            nn.Conv1d(64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3),
+            nn.Conv1d(
+                64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+            ),
             nn.BatchNorm1d(128),
             self.GELU,
-
-            nn.Conv1d(128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3),
+            nn.Conv1d(
+                128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
+            ),
             nn.BatchNorm1d(128),
             self.GELU,
-
-            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
+            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
         )
 
         self.dropout = nn.Dropout(drate)
         self.inplanes = 128
-        self.AFR = self._make_layer(_SEBasicBlock, after_reduced_cnn_size, 1)
+        self.AFR = self._make_layer(
+            _SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
+        )
 
-    def _make_layer(self, block, planes, blocks, stride=1):  # makes residual SE block
+    def _make_layer(
+        self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
+    ):  # makes residual SE block
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
-                nn.Conv1d(self.inplanes, planes * block.expansion,
-                          kernel_size=1, stride=stride, bias=False),
+                nn.Conv1d(
+                    self.inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
                 nn.BatchNorm1d(planes * block.expansion),
             )
 
@@ -269,11 +341,11 @@ class _MRCNN(nn.Module):
         layers.append(block(self.inplanes, planes, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
+            layers.append(block(self.inplanes, planes, activation=activate))
 
         return nn.Sequential(*layers)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x1 = self.features1(x)
         x2 = self.features2(x)
         x_concat = torch.cat((x1, x2), dim=2)
@@ -285,93 +357,107 @@ class _MRCNN(nn.Module):
 ##########################################################################################
 
 
-def _attention(query, key, value, dropout=None):
+def _attention(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Implementation of Scaled dot product attention"""
     # d_k - dimension of the query and key vectors
     d_k = query.size(-1)
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
-
-    p_attn = F.softmax(scores, dim=-1)
-    if dropout is not None:
-        p_attn = dropout(p_attn)
-    return torch.matmul(p_attn, value), p_attn
-
-
-class _CausalConv1d(torch.nn.Conv1d):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
-        self.__padding = (kernel_size - 1) * dilation
-
-        super(_CausalConv1d, self).__init__(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=self.__padding,
-            dilation=dilation,
-            groups=groups,
-            bias=bias)
-
-    def forward(self, input):
-        result = super(_CausalConv1d, self).forward(input)
-        if self.__padding != 0:
-            return result[:, :, :-self.__padding]
-        return result
+    p_attn = F.softmax(scores, dim=-1)  # attention weights
+    output = torch.matmul(p_attn, value)  # (B, h, T, d_k)
+    return output, p_attn
 
 
 class _MultiHeadedAttention(nn.Module):
     def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
         """Take in model size and number of heads."""
-        super(_MultiHeadedAttention, self).__init__()
+        super().__init__()
         assert d_model % h == 0
         self.d_per_head = d_model // h
         self.h = h
 
-        self.convs = _clones(_CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size,
-                                           kernel_size=7, stride=1), 3)
+        base_conv = CausalConv1d(
+            in_channels=after_reduced_cnn_size,
+            out_channels=after_reduced_cnn_size,
+            kernel_size=7,
+            stride=1,
+        )
+        self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])
+
         self.linear = nn.Linear(d_model, d_model)
         self.dropout = nn.Dropout(p=dropout)
 
-    def forward(self, query, key, value):
+    def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
         """Implements Multi-head attention"""
         nbatches = query.size(0)
 
         query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
-        key = self.convs[1](key).view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
-        value = self.convs[2](value).view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
+        key = (
+            self.convs[1](key)
+            .view(nbatches, -1, self.h, self.d_per_head)
+            .transpose(1, 2)
+        )
+        value = (
+            self.convs[2](value)
+            .view(nbatches, -1, self.h, self.d_per_head)
+            .transpose(1, 2)
+        )
+
+        x_raw, attn_weights = _attention(query, key, value)
+        # apply dropout to the *weights*
+        attn = self.dropout(attn_weights)
+        # recompute the weighted sum with dropped weights
+        x = torch.matmul(attn, value)
 
-        x, self.attn = _attention(query, key, value, dropout=self.dropout)
+        # stash the pre-dropout weights if you need them
+        self.attn = attn_weights
 
-        x = x.transpose(1, 2).contiguous() \
-            .view(nbatches, -1, self.h * self.d_per_head)
+        # merge heads and project
+        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)
 
         return self.linear(x)
 
 
-class _SublayerOutput(nn.Module):
+class _ResidualLayerNormAttn(nn.Module):
     """
     A residual connection followed by a layer norm.
     """
 
-    def __init__(self, size, dropout):
-        super(_SublayerOutput, self).__init__()
+    def __init__(self, size, dropout, fn_attn):
+        super().__init__()
+        self.norm = nn.LayerNorm(size, eps=1e-6)
+        self.dropout = nn.Dropout(dropout)
+        self.fn_attn = fn_attn
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        """Apply residual connection to any sublayer with the same size."""
+        x_norm = self.norm(x)
+
+        out = self.fn_attn(x_norm, key, value)
+
+        return x + self.dropout(out)
+
+
+class _ResidualLayerNormFF(nn.Module):
+    def __init__(self, size, dropout, fn_ff):
+        super().__init__()
         self.norm = nn.LayerNorm(size, eps=1e-6)
         self.dropout = nn.Dropout(dropout)
+        self.fn_ff = fn_ff
 
-    def forward(self, x, sublayer):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply residual connection to any sublayer with the same size."""
-        return x + self.dropout(sublayer(self.norm(x)))
+        x_norm = self.norm(x)
 
+        out = self.fn_ff(x_norm)
 
-def _clones(module, n):
-    """Produce n identical layers."""
-    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
+        return x + self.dropout(out)
 
 
 class _TCE(nn.Module):
@@ -381,11 +467,13 @@ class _TCE(nn.Module):
     """
 
     def __init__(self, layer, n):
-        super(_TCE, self).__init__()
-        self.layers = _clones(layer, n)
+        super().__init__()
+
+        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
+
         self.norm = nn.LayerNorm(layer.size, eps=1e-6)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)
@@ -395,35 +483,53 @@ class _EncoderLayer(nn.Module):
     """
     An encoder layer
     Made up of self-attention and a feed forward layer.
-    Each of these sublayers have residual and layer norm, implemented by _SublayerOutput.
+    Each sublayer has a residual connection and layer norm, via _ResidualLayerNormAttn/_ResidualLayerNormFF.
     """
 
     def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
-        super(_EncoderLayer, self).__init__()
+        super().__init__()
+        self.size = size
         self.self_attn = self_attn
         self.feed_forward = feed_forward
-        self.sublayer_output = _clones(_SublayerOutput(size, dropout), 2)
-        self.size = size
-        self.conv = _CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size, kernel_size=7,
-                                  stride=1, dilation=1)
 
-    def forward(self, x_in):
+        self.residual_self_attn = _ResidualLayerNormAttn(
+            size=size,
+            dropout=dropout,
+            fn_attn=self_attn,
+        )
+        self.residual_ff = _ResidualLayerNormFF(
+            size=size,
+            dropout=dropout,
+            fn_ff=feed_forward,
+        )
+
+        self.conv = CausalConv1d(
+            in_channels=after_reduced_cnn_size,
+            out_channels=after_reduced_cnn_size,
+            kernel_size=7,
+            stride=1,
+            dilation=1,
+        )
+
+    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
         """Transformer Encoder"""
         query = self.conv(x_in)
         # Encoder self-attention
-        x = self.sublayer_output[0](query, lambda x: self.self_attn(query, x_in, x_in))
-        return self.sublayer_output[1](x, self.feed_forward)
+        x = self.residual_self_attn(query, x_in, x_in)
+        x_ff = self.residual_ff(x)
+        return x_ff
 
 
 class _PositionwiseFeedForward(nn.Module):
     """Positionwise feed-forward network."""
 
-    def __init__(self, d_model, d_ff, dropout=0.1):
-        super(_PositionwiseFeedForward, self).__init__()
+    def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
+        super().__init__()
         self.w_1 = nn.Linear(d_model, d_ff)
         self.w_2 = nn.Linear(d_ff, d_model)
         self.dropout = nn.Dropout(dropout)
+        self.activate = activation()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Implements FFN equation."""
-        return self.w_2(self.dropout(F.relu(self.w_1(x))))
+        return self.w_2(self.dropout(self.activate(self.w_1(x))))
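
The private _CausalConv1d deleted above is superseded by the public CausalConv1d imported from braindecode.modules (see the new braindecode/modules/ package in the file list). The mechanism is visible in the removed lines: pad by (kernel_size - 1) * dilation, then trim that many samples from the right so no output step depends on future inputs. A self-contained sketch of the same idea, using a hypothetical TinyCausalConv1d name for illustration (the shipped class may differ in details):

    import torch
    from torch import nn

    class TinyCausalConv1d(nn.Conv1d):
        """Conv1d that cannot see future samples: left context only."""

        def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs):
            self.trim = (kernel_size - 1) * dilation
            super().__init__(
                in_channels,
                out_channels,
                kernel_size,
                padding=self.trim,  # symmetric padding; the right-hand excess is cut below
                dilation=dilation,
                **kwargs,
            )

        def forward(self, x):
            out = super().forward(x)
            # drop the rightmost `trim` steps so output length equals input length
            return out[:, :, : -self.trim] if self.trim else out

    conv = TinyCausalConv1d(30, 30, kernel_size=7)  # mirrors the 7-tap convs in the TCE
    x = torch.randn(4, 30, 80)  # (batch, after_reduced_cnn_size, d_model)
    assert conv(x).shape == x.shape  # length preserved, no future leakage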