braindecode 0.8.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of braindecode might be problematic; see the registry's release page for more details.

Files changed (108)
  1. braindecode/__init__.py +1 -2
  2. braindecode/augmentation/__init__.py +39 -19
  3. braindecode/augmentation/base.py +25 -28
  4. braindecode/augmentation/functional.py +237 -100
  5. braindecode/augmentation/transforms.py +326 -158
  6. braindecode/classifier.py +26 -24
  7. braindecode/datasets/__init__.py +28 -10
  8. braindecode/datasets/base.py +220 -134
  9. braindecode/datasets/bbci.py +43 -52
  10. braindecode/datasets/bcicomp.py +47 -32
  11. braindecode/datasets/bids.py +245 -0
  12. braindecode/datasets/mne.py +45 -24
  13. braindecode/datasets/moabb.py +87 -27
  14. braindecode/datasets/nmt.py +311 -0
  15. braindecode/datasets/sleep_physio_challe_18.py +412 -0
  16. braindecode/datasets/sleep_physionet.py +43 -26
  17. braindecode/datasets/tuh.py +324 -140
  18. braindecode/datasets/xy.py +27 -12
  19. braindecode/datautil/__init__.py +34 -18
  20. braindecode/datautil/serialization.py +98 -71
  21. braindecode/eegneuralnet.py +63 -47
  22. braindecode/functional/__init__.py +10 -0
  23. braindecode/functional/functions.py +251 -0
  24. braindecode/functional/initialization.py +47 -0
  25. braindecode/models/__init__.py +36 -14
  26. braindecode/models/atcnet.py +153 -159
  27. braindecode/models/attentionbasenet.py +550 -0
  28. braindecode/models/base.py +86 -102
  29. braindecode/models/biot.py +483 -0
  30. braindecode/models/contrawr.py +296 -0
  31. braindecode/models/ctnet.py +450 -0
  32. braindecode/models/deep4.py +64 -75
  33. braindecode/models/deepsleepnet.py +149 -119
  34. braindecode/models/eegconformer.py +111 -171
  35. braindecode/models/eeginception_erp.py +109 -118
  36. braindecode/models/eeginception_mi.py +155 -97
  37. braindecode/models/eegitnet.py +215 -151
  38. braindecode/models/eegminer.py +255 -0
  39. braindecode/models/eegnet.py +229 -161
  40. braindecode/models/eegnex.py +247 -0
  41. braindecode/models/eegresnet.py +234 -152
  42. braindecode/models/eegsimpleconv.py +199 -0
  43. braindecode/models/eegtcnet.py +335 -0
  44. braindecode/models/fbcnet.py +221 -0
  45. braindecode/models/fblightconvnet.py +313 -0
  46. braindecode/models/fbmsnet.py +325 -0
  47. braindecode/models/hybrid.py +52 -71
  48. braindecode/models/ifnet.py +441 -0
  49. braindecode/models/labram.py +1166 -0
  50. braindecode/models/msvtnet.py +375 -0
  51. braindecode/models/sccnet.py +182 -0
  52. braindecode/models/shallow_fbcsp.py +50 -56
  53. braindecode/models/signal_jepa.py +1012 -0
  54. braindecode/models/sinc_shallow.py +337 -0
  55. braindecode/models/sleep_stager_blanco_2020.py +55 -46
  56. braindecode/models/sleep_stager_chambon_2018.py +54 -53
  57. braindecode/models/sleep_stager_eldele_2021.py +248 -141
  58. braindecode/models/sparcnet.py +378 -0
  59. braindecode/models/summary.csv +41 -0
  60. braindecode/models/syncnet.py +232 -0
  61. braindecode/models/tcn.py +158 -88
  62. braindecode/models/tidnet.py +280 -167
  63. braindecode/models/tsinception.py +258 -0
  64. braindecode/models/usleep.py +190 -177
  65. braindecode/models/util.py +109 -141
  66. braindecode/modules/__init__.py +38 -0
  67. braindecode/modules/activation.py +60 -0
  68. braindecode/modules/attention.py +757 -0
  69. braindecode/modules/blocks.py +108 -0
  70. braindecode/modules/convolution.py +274 -0
  71. braindecode/modules/filter.py +632 -0
  72. braindecode/modules/layers.py +133 -0
  73. braindecode/modules/linear.py +50 -0
  74. braindecode/modules/parametrization.py +38 -0
  75. braindecode/modules/stats.py +77 -0
  76. braindecode/modules/util.py +77 -0
  77. braindecode/modules/wrapper.py +75 -0
  78. braindecode/preprocessing/__init__.py +36 -11
  79. braindecode/preprocessing/mne_preprocess.py +13 -7
  80. braindecode/preprocessing/preprocess.py +139 -75
  81. braindecode/preprocessing/windowers.py +576 -187
  82. braindecode/regressor.py +23 -12
  83. braindecode/samplers/__init__.py +16 -8
  84. braindecode/samplers/base.py +148 -32
  85. braindecode/samplers/ssl.py +162 -17
  86. braindecode/training/__init__.py +18 -10
  87. braindecode/training/callbacks.py +2 -4
  88. braindecode/training/losses.py +3 -8
  89. braindecode/training/scoring.py +76 -68
  90. braindecode/util.py +55 -59
  91. braindecode/version.py +1 -1
  92. braindecode/visualization/__init__.py +2 -3
  93. braindecode/visualization/confusion_matrices.py +117 -73
  94. braindecode/visualization/gradients.py +14 -10
  95. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/METADATA +39 -55
  96. braindecode-1.0.0.dist-info/RECORD +101 -0
  97. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/WHEEL +1 -1
  98. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info/licenses}/LICENSE.txt +1 -1
  99. braindecode-1.0.0.dist-info/licenses/NOTICE.txt +20 -0
  100. braindecode/datautil/mne.py +0 -9
  101. braindecode/datautil/preprocess.py +0 -12
  102. braindecode/datautil/windowers.py +0 -6
  103. braindecode/datautil/xy.py +0 -9
  104. braindecode/models/eeginception.py +0 -317
  105. braindecode/models/functions.py +0 -47
  106. braindecode/models/modules.py +0 -358
  107. braindecode-0.8.1.dist-info/RECORD +0 -68
  108. {braindecode-0.8.1.dist-info → braindecode-1.0.0.dist-info}/top_level.txt +0 -0
@@ -3,18 +3,24 @@
3
3
  # License: BSD (3-clause)
4
4
 
5
5
  import math
6
- import copy
7
- from copy import deepcopy
8
6
  import warnings
7
+ from copy import deepcopy
8
+ from typing import Callable, Optional
9
9
 
10
10
  import torch
11
- from torch import nn
12
11
  import torch.nn.functional as F
13
- from .base import EEGModuleMixin, deprecated_args
12
+ from torch import nn
13
+
14
+ from braindecode.models.base import EEGModuleMixin
15
+ from braindecode.modules import CausalConv1d
14
16
 
15
17
 
16
18
  class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
17
- """Sleep Staging Architecture from Eldele et al 2021.
19
+ """Sleep Staging Architecture from Eldele et al. (2021) [Eldele2021]_.
20
+
21
+ .. figure:: https://raw.githubusercontent.com/emadeldeen24/AttnSleep/refs/heads/main/imgs/AttnSleep.png
22
+ :align: center
23
+ :alt: SleepStagerEldele2021 Architecture
18
24
 
19
25
  Attention based Neural Net for sleep staging as described in [Eldele2021]_.
20
26
  The code for the paper and this model is also available at [1]_.
@@ -43,7 +49,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
43
49
  input dimension of the second FC layer in the same.
44
50
  n_attn_heads : int
45
51
  Number of attention heads. It should be a factor of d_model
46
- dropout : float
52
+ drop_prob : float
47
53
  Dropout rate in the PositionWiseFeedforward layer and the TCE layers.
48
54
  after_reduced_cnn_size : int
49
55
  Number of output channels produced by the convolution in the AFR module.
@@ -55,6 +61,13 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
55
61
  Alias for `n_outputs`.
56
62
  input_size_s : float
57
63
  Alias for `input_window_seconds`.
64
+ activation: nn.Module, default=nn.ReLU
65
+ Activation function class to apply. Should be a PyTorch activation
66
+ module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ReLU``.
67
+ activation_mrcnn: nn.Module, default=nn.ReLU
68
+ Activation function class to apply in the Mask R-CNN layer.
69
+ Should be a PyTorch activation module class like ``nn.ReLU`` or
70
+ ``nn.GELU``. Default is ``nn.GELU``.
58
71
 
59
72
  References
60
73
  ----------
@@ -68,28 +81,23 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
68
81
  """
69
82
 
70
83
  def __init__(
71
- self,
72
- sfreq=None,
73
- n_tce=2,
74
- d_model=80,
75
- d_ff=120,
76
- n_attn_heads=5,
77
- dropout=0.1,
78
- input_window_seconds=30,
79
- n_outputs=5,
80
- after_reduced_cnn_size=30,
81
- return_feats=False,
82
- chs_info=None,
83
- n_chans=None,
84
- n_times=None,
85
- n_classes=None,
86
- input_size_s=None,
84
+ self,
85
+ sfreq=None,
86
+ n_tce=2,
87
+ d_model=80,
88
+ d_ff=120,
89
+ n_attn_heads=5,
90
+ drop_prob=0.1,
91
+ activation_mrcnn: nn.Module = nn.GELU,
92
+ activation: nn.Module = nn.ReLU,
93
+ input_window_seconds=None,
94
+ n_outputs=None,
95
+ after_reduced_cnn_size=30,
96
+ return_feats=False,
97
+ chs_info=None,
98
+ n_chans=None,
99
+ n_times=None,
87
100
  ):
88
- n_outputs, input_window_seconds, = deprecated_args(
89
- self,
90
- ("n_classes", "n_outputs", n_classes, n_outputs),
91
- ("input_size_s", "input_window_seconds", input_size_s, input_window_seconds),
92
- )
93
101
  super().__init__(
94
102
  n_outputs=n_outputs,
95
103
  n_chans=n_chans,
@@ -99,19 +107,25 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
99
107
  sfreq=sfreq,
100
108
  )
101
109
  del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
102
- del n_classes, input_size_s
103
110
 
104
111
  self.mapping = {
105
112
  "fc.weight": "final_layer.weight",
106
- "fc.bias": "final_layer.bias"
113
+ "fc.bias": "final_layer.bias",
107
114
  }
108
115
 
109
- if not ((self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80) or
110
- (self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100)):
111
- warnings.warn("This model was designed originally for input windows of 30sec at 100Hz, "
112
- "with d_model at 80 or at 125Hz, with d_model at 100, to use anything "
113
- "other than this may cause errors or cause the model to perform in "
114
- "other ways than intended", UserWarning)
116
+ if not (
117
+ (self.input_window_seconds == 30 and self.sfreq == 100 and d_model == 80)
118
+ or (
119
+ self.input_window_seconds == 30 and self.sfreq == 125 and d_model == 100
120
+ )
121
+ ):
122
+ warnings.warn(
123
+ "This model was designed originally for input windows of 30sec at 100Hz, "
124
+ "with d_model at 80 or at 125Hz, with d_model at 100, to use anything "
125
+ "other than this may cause errors or cause the model to perform in "
126
+ "other ways than intended",
127
+ UserWarning,
128
+ )
115
129
 
116
130
  # the usual kernel size for the mrcnn, for sfreq 100
117
131
  kernel_size = 7
@@ -119,11 +133,20 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
119
133
  if self.sfreq == 125:
120
134
  kernel_size = 6
121
135
 
122
- mrcnn = _MRCNN(after_reduced_cnn_size, kernel_size)
136
+ mrcnn = _MRCNN(
137
+ after_reduced_cnn_size,
138
+ kernel_size,
139
+ activation=activation_mrcnn,
140
+ activation_se=activation,
141
+ )
123
142
  attn = _MultiHeadedAttention(n_attn_heads, d_model, after_reduced_cnn_size)
124
- ff = _PositionwiseFeedForward(d_model, d_ff, dropout)
125
- tce = _TCE(_EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size,
126
- dropout), n_tce)
143
+ ff = _PositionwiseFeedForward(d_model, d_ff, drop_prob, activation=activation)
144
+ tce = _TCE(
145
+ _EncoderLayer(
146
+ d_model, deepcopy(attn), deepcopy(ff), after_reduced_cnn_size, drop_prob
147
+ ),
148
+ n_tce,
149
+ )
127
150
 
128
151
  self.feature_extractor = nn.Sequential(mrcnn, tce)
129
152
  self.len_last_layer = self._len_last_layer(self.n_times)
@@ -133,7 +156,9 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
133
156
  """if return_feats:
134
157
  raise ValueError("return_feat == True is not accepted anymore")"""
135
158
  if not return_feats:
136
- self.final_layer = nn.Linear(d_model * after_reduced_cnn_size, self.n_outputs)
159
+ self.final_layer = nn.Linear(
160
+ d_model * after_reduced_cnn_size, self.n_outputs
161
+ )
137
162
 
138
163
  def _len_last_layer(self, input_size):
139
164
  self.feature_extractor.eval()
@@ -142,7 +167,7 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
142
167
  self.feature_extractor.train()
143
168
  return len(out.flatten())
144
169
 
145
- def forward(self, x):
170
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
146
171
  """
147
172
  Forward pass.
148
173
 
@@ -153,27 +178,41 @@ class SleepStagerEldele2021(EEGModuleMixin, nn.Module):
153
178
  """
154
179
 
155
180
  encoded_features = self.feature_extractor(x)
156
- encoded_features = encoded_features.contiguous().view(encoded_features.shape[0], -1)
181
+ encoded_features = encoded_features.contiguous().view(
182
+ encoded_features.shape[0], -1
183
+ )
157
184
 
158
185
  if self.return_feats:
159
186
  return encoded_features
160
- else:
161
- final_output = self.final_layer(encoded_features)
162
- return final_output
187
+
188
+ return self.final_layer(encoded_features)
163
189
 
164
190
 
165
191
  class _SELayer(nn.Module):
166
- def __init__(self, channel, reduction=16):
192
+ def __init__(self, channel, reduction=16, activation=nn.ReLU):
167
193
  super(_SELayer, self).__init__()
168
194
  self.avg_pool = nn.AdaptiveAvgPool1d(1)
169
195
  self.fc = nn.Sequential(
170
196
  nn.Linear(channel, channel // reduction, bias=False),
171
- nn.ReLU(inplace=True),
197
+ activation(inplace=True),
172
198
  nn.Linear(channel // reduction, channel, bias=False),
173
- nn.Sigmoid()
199
+ nn.Sigmoid(),
174
200
  )
175
201
 
176
- def forward(self, x):
202
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
203
+ """
204
+ Forward pass of the SE layer.
205
+
206
+ Parameters
207
+ ----------
208
+ x : torch.Tensor
209
+ Input tensor of shape (batch_size, channel, length).
210
+
211
+ Returns
212
+ -------
213
+ torch.Tensor
214
+ Output tensor after applying the SE recalibration.
215
+ """
177
216
  b, c, _ = x.size()
178
217
  y = self.avg_pool(x).view(b, c)
179
218
  y = self.fc(y).view(b, c, 1)
@@ -183,22 +222,43 @@ class _SELayer(nn.Module):
183
222
  class _SEBasicBlock(nn.Module):
184
223
  expansion = 1
185
224
 
186
- def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
187
- base_width=64, dilation=1, norm_layer=None,
188
- *, reduction=16):
225
+ def __init__(
226
+ self,
227
+ inplanes,
228
+ planes,
229
+ stride=1,
230
+ downsample=None,
231
+ activation: nn.Module = nn.ReLU,
232
+ *,
233
+ reduction=16,
234
+ ):
189
235
  super(_SEBasicBlock, self).__init__()
190
236
  self.conv1 = nn.Conv1d(inplanes, planes, stride)
191
237
  self.bn1 = nn.BatchNorm1d(planes)
192
- self.relu = nn.ReLU(inplace=True)
238
+ self.relu = activation(inplace=True)
193
239
  self.conv2 = nn.Conv1d(planes, planes, 1)
194
240
  self.bn2 = nn.BatchNorm1d(planes)
195
241
  self.se = _SELayer(planes, reduction)
196
242
  self.downsample = downsample
197
243
  self.stride = stride
198
- self.features = nn.Sequential(self.conv1, self.bn1, self.relu, self.conv2, self.bn2,
199
- self.se)
244
+ self.features = nn.Sequential(
245
+ self.conv1, self.bn1, self.relu, self.conv2, self.bn2, self.se
246
+ )
200
247
 
201
- def forward(self, x):
248
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
249
+ """
250
+ Forward pass of the SE layer.
251
+
252
+ Parameters
253
+ ----------
254
+ x : torch.Tensor
255
+ Input tensor of shape (batch_size, n_chans, n_times).
256
+
257
+ Returns
258
+ -------
259
+ torch.Tensor
260
+ Output tensor after applying the SE recalibration.
261
+ """
202
262
  residual = x
203
263
  out = self.features(x)
204
264
 
@@ -212,26 +272,29 @@ class _SEBasicBlock(nn.Module):
212
272
 
213
273
 
214
274
  class _MRCNN(nn.Module):
215
- def __init__(self, after_reduced_cnn_size, kernel_size=7):
275
+ def __init__(
276
+ self,
277
+ after_reduced_cnn_size,
278
+ kernel_size=7,
279
+ activation: nn.Module = nn.GELU,
280
+ activation_se: nn.Module = nn.ReLU,
281
+ ):
216
282
  super(_MRCNN, self).__init__()
217
283
  drate = 0.5
218
- self.GELU = nn.GELU()
284
+ self.GELU = activation()
219
285
  self.features1 = nn.Sequential(
220
286
  nn.Conv1d(1, 64, kernel_size=50, stride=6, bias=False, padding=24),
221
287
  nn.BatchNorm1d(64),
222
288
  self.GELU,
223
289
  nn.MaxPool1d(kernel_size=8, stride=2, padding=4),
224
290
  nn.Dropout(drate),
225
-
226
291
  nn.Conv1d(64, 128, kernel_size=8, stride=1, bias=False, padding=4),
227
292
  nn.BatchNorm1d(128),
228
293
  self.GELU,
229
-
230
294
  nn.Conv1d(128, 128, kernel_size=8, stride=1, bias=False, padding=4),
231
295
  nn.BatchNorm1d(128),
232
296
  self.GELU,
233
-
234
- nn.MaxPool1d(kernel_size=4, stride=4, padding=2)
297
+ nn.MaxPool1d(kernel_size=4, stride=4, padding=2),
235
298
  )
236
299
 
237
300
  self.features2 = nn.Sequential(
@@ -240,28 +303,38 @@ class _MRCNN(nn.Module):
240
303
  self.GELU,
241
304
  nn.MaxPool1d(kernel_size=4, stride=2, padding=2),
242
305
  nn.Dropout(drate),
243
-
244
- nn.Conv1d(64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3),
306
+ nn.Conv1d(
307
+ 64, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
308
+ ),
245
309
  nn.BatchNorm1d(128),
246
310
  self.GELU,
247
-
248
- nn.Conv1d(128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3),
311
+ nn.Conv1d(
312
+ 128, 128, kernel_size=kernel_size, stride=1, bias=False, padding=3
313
+ ),
249
314
  nn.BatchNorm1d(128),
250
315
  self.GELU,
251
-
252
- nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
316
+ nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
253
317
  )
254
318
 
255
319
  self.dropout = nn.Dropout(drate)
256
320
  self.inplanes = 128
257
- self.AFR = self._make_layer(_SEBasicBlock, after_reduced_cnn_size, 1)
321
+ self.AFR = self._make_layer(
322
+ _SEBasicBlock, after_reduced_cnn_size, 1, activate=activation_se
323
+ )
258
324
 
259
- def _make_layer(self, block, planes, blocks, stride=1): # makes residual SE block
325
+ def _make_layer(
326
+ self, block, planes, blocks, stride=1, activate: nn.Module = nn.ReLU
327
+ ): # makes residual SE block
260
328
  downsample = None
261
329
  if stride != 1 or self.inplanes != planes * block.expansion:
262
330
  downsample = nn.Sequential(
263
- nn.Conv1d(self.inplanes, planes * block.expansion,
264
- kernel_size=1, stride=stride, bias=False),
331
+ nn.Conv1d(
332
+ self.inplanes,
333
+ planes * block.expansion,
334
+ kernel_size=1,
335
+ stride=stride,
336
+ bias=False,
337
+ ),
265
338
  nn.BatchNorm1d(planes * block.expansion),
266
339
  )
267
340
 
@@ -269,11 +342,11 @@ class _MRCNN(nn.Module):
269
342
  layers.append(block(self.inplanes, planes, stride, downsample))
270
343
  self.inplanes = planes * block.expansion
271
344
  for i in range(1, blocks):
272
- layers.append(block(self.inplanes, planes))
345
+ layers.append(block(self.inplanes, planes, activate=activate))
273
346
 
274
347
  return nn.Sequential(*layers)
275
348
 
276
- def forward(self, x):
349
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
277
350
  x1 = self.features1(x)
278
351
  x2 = self.features2(x)
279
352
  x_concat = torch.cat((x1, x2), dim=2)
@@ -285,93 +358,107 @@ class _MRCNN(nn.Module):
285
358
  ##########################################################################################
286
359
 
287
360
 
288
- def _attention(query, key, value, dropout=None):
361
+ def _attention(
362
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
363
+ ) -> tuple[torch.Tensor, torch.Tensor]:
289
364
  """Implementation of Scaled dot product attention"""
290
365
  # d_k - dimension of the query and key vectors
291
366
  d_k = query.size(-1)
292
367
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
293
-
294
- p_attn = F.softmax(scores, dim=-1)
295
- if dropout is not None:
296
- p_attn = dropout(p_attn)
297
- return torch.matmul(p_attn, value), p_attn
298
-
299
-
300
- class _CausalConv1d(torch.nn.Conv1d):
301
- def __init__(self,
302
- in_channels,
303
- out_channels,
304
- kernel_size,
305
- stride=1,
306
- dilation=1,
307
- groups=1,
308
- bias=True):
309
- self.__padding = (kernel_size - 1) * dilation
310
-
311
- super(_CausalConv1d, self).__init__(
312
- in_channels,
313
- out_channels,
314
- kernel_size=kernel_size,
315
- stride=stride,
316
- padding=self.__padding,
317
- dilation=dilation,
318
- groups=groups,
319
- bias=bias)
320
-
321
- def forward(self, input):
322
- result = super(_CausalConv1d, self).forward(input)
323
- if self.__padding != 0:
324
- return result[:, :, :-self.__padding]
325
- return result
368
+ p_attn = F.softmax(scores, dim=-1) # attention weights
369
+ output = torch.matmul(p_attn, value) # (B, h, T, d_k)
370
+ return output, p_attn
326
371
 
327
372
 
328
373
  class _MultiHeadedAttention(nn.Module):
329
374
  def __init__(self, h, d_model, after_reduced_cnn_size, dropout=0.1):
330
375
  """Take in model size and number of heads."""
331
- super(_MultiHeadedAttention, self).__init__()
376
+ super().__init__()
332
377
  assert d_model % h == 0
333
378
  self.d_per_head = d_model // h
334
379
  self.h = h
335
380
 
336
- self.convs = _clones(_CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size,
337
- kernel_size=7, stride=1), 3)
381
+ base_conv = CausalConv1d(
382
+ in_channels=after_reduced_cnn_size,
383
+ out_channels=after_reduced_cnn_size,
384
+ kernel_size=7,
385
+ stride=1,
386
+ )
387
+ self.convs = nn.ModuleList([deepcopy(base_conv) for _ in range(3)])
388
+
338
389
  self.linear = nn.Linear(d_model, d_model)
339
390
  self.dropout = nn.Dropout(p=dropout)
340
391
 
341
- def forward(self, query, key, value):
392
+ def forward(self, query, key, value: torch.Tensor) -> torch.Tensor:
342
393
  """Implements Multi-head attention"""
343
394
  nbatches = query.size(0)
344
395
 
345
396
  query = query.view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
346
- key = self.convs[1](key).view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
347
- value = self.convs[2](value).view(nbatches, -1, self.h, self.d_per_head).transpose(1, 2)
397
+ key = (
398
+ self.convs[1](key)
399
+ .view(nbatches, -1, self.h, self.d_per_head)
400
+ .transpose(1, 2)
401
+ )
402
+ value = (
403
+ self.convs[2](value)
404
+ .view(nbatches, -1, self.h, self.d_per_head)
405
+ .transpose(1, 2)
406
+ )
407
+
408
+ x_raw, attn_weights = _attention(query, key, value)
409
+ # apply dropout to the *weights*
410
+ attn = self.dropout(attn_weights)
411
+ # recompute the weighted sum with dropped weights
412
+ x = torch.matmul(attn, value)
348
413
 
349
- x, self.attn = _attention(query, key, value, dropout=self.dropout)
414
+ # stash the pre‑dropout weights if you need them
415
+ self.attn = attn_weights
350
416
 
351
- x = x.transpose(1, 2).contiguous() \
352
- .view(nbatches, -1, self.h * self.d_per_head)
417
+ # merge heads and project
418
+ x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_per_head)
353
419
 
354
420
  return self.linear(x)
355
421
 
356
422
 
357
- class _SublayerOutput(nn.Module):
423
+ class _ResidualLayerNormAttn(nn.Module):
358
424
  """
359
425
  A residual connection followed by a layer norm.
360
426
  """
361
427
 
362
- def __init__(self, size, dropout):
363
- super(_SublayerOutput, self).__init__()
428
+ def __init__(self, size, dropout, fn_attn):
429
+ super().__init__()
430
+ self.norm = nn.LayerNorm(size, eps=1e-6)
431
+ self.dropout = nn.Dropout(dropout)
432
+ self.fn_attn = fn_attn
433
+
434
+ def forward(
435
+ self,
436
+ x: torch.Tensor,
437
+ key: torch.Tensor,
438
+ value: torch.Tensor,
439
+ ) -> torch.Tensor:
440
+ """Apply residual connection to any sublayer with the same size."""
441
+ x_norm = self.norm(x)
442
+
443
+ out = self.fn_attn(x_norm, key, value)
444
+
445
+ return x + self.dropout(out)
446
+
447
+
448
+ class _ResidualLayerNormFF(nn.Module):
449
+ def __init__(self, size, dropout, fn_ff):
450
+ super().__init__()
364
451
  self.norm = nn.LayerNorm(size, eps=1e-6)
365
452
  self.dropout = nn.Dropout(dropout)
453
+ self.fn_ff = fn_ff
366
454
 
367
- def forward(self, x, sublayer):
455
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
368
456
  """Apply residual connection to any sublayer with the same size."""
369
- return x + self.dropout(sublayer(self.norm(x)))
457
+ x_norm = self.norm(x)
370
458
 
459
+ out = self.fn_ff(x_norm)
371
460
 
372
- def _clones(module, n):
373
- """Produce n identical layers."""
374
- return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
461
+ return x + self.dropout(out)
375
462
 
376
463
 
377
464
  class _TCE(nn.Module):
@@ -381,11 +468,13 @@ class _TCE(nn.Module):
381
468
  """
382
469
 
383
470
  def __init__(self, layer, n):
384
- super(_TCE, self).__init__()
385
- self.layers = _clones(layer, n)
471
+ super().__init__()
472
+
473
+ self.layers = nn.ModuleList([deepcopy(layer) for _ in range(n)])
474
+
386
475
  self.norm = nn.LayerNorm(layer.size, eps=1e-6)
387
476
 
388
- def forward(self, x):
477
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
389
478
  for layer in self.layers:
390
479
  x = layer(x)
391
480
  return self.norm(x)
@@ -395,35 +484,53 @@ class _EncoderLayer(nn.Module):
395
484
  """
396
485
  An encoder layer
397
486
  Made up of self-attention and a feed forward layer.
398
- Each of these sublayers have residual and layer norm, implemented by _SublayerOutput.
487
+ Each of these sublayers have residual and layer norm, implemented by _ResidualLayerNorm.
399
488
  """
400
489
 
401
490
  def __init__(self, size, self_attn, feed_forward, after_reduced_cnn_size, dropout):
402
- super(_EncoderLayer, self).__init__()
491
+ super().__init__()
492
+ self.size = size
403
493
  self.self_attn = self_attn
404
494
  self.feed_forward = feed_forward
405
- self.sublayer_output = _clones(_SublayerOutput(size, dropout), 2)
406
- self.size = size
407
- self.conv = _CausalConv1d(after_reduced_cnn_size, after_reduced_cnn_size, kernel_size=7,
408
- stride=1, dilation=1)
409
495
 
410
- def forward(self, x_in):
496
+ self.residual_self_attn = _ResidualLayerNormAttn(
497
+ size=size,
498
+ dropout=dropout,
499
+ fn_attn=self_attn,
500
+ )
501
+ self.residual_ff = _ResidualLayerNormFF(
502
+ size=size,
503
+ dropout=dropout,
504
+ fn_ff=feed_forward,
505
+ )
506
+
507
+ self.conv = CausalConv1d(
508
+ in_channels=after_reduced_cnn_size,
509
+ out_channels=after_reduced_cnn_size,
510
+ kernel_size=7,
511
+ stride=1,
512
+ dilation=1,
513
+ )
514
+
515
+ def forward(self, x_in: torch.Tensor) -> torch.Tensor:
411
516
  """Transformer Encoder"""
412
517
  query = self.conv(x_in)
413
518
  # Encoder self-attention
414
- x = self.sublayer_output[0](query, lambda x: self.self_attn(query, x_in, x_in))
415
- return self.sublayer_output[1](x, self.feed_forward)
519
+ x = self.residual_self_attn(query, x_in, x_in)
520
+ x_ff = self.residual_ff(x)
521
+ return x_ff
416
522
 
417
523
 
418
524
  class _PositionwiseFeedForward(nn.Module):
419
525
  """Positionwise feed-forward network."""
420
526
 
421
- def __init__(self, d_model, d_ff, dropout=0.1):
422
- super(_PositionwiseFeedForward, self).__init__()
527
+ def __init__(self, d_model, d_ff, dropout=0.1, activation: nn.Module = nn.ReLU):
528
+ super().__init__()
423
529
  self.w_1 = nn.Linear(d_model, d_ff)
424
530
  self.w_2 = nn.Linear(d_ff, d_model)
425
531
  self.dropout = nn.Dropout(dropout)
532
+ self.activate = activation()
426
533
 
427
- def forward(self, x):
534
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
428
535
  """Implements FFN equation."""
429
- return self.w_2(self.dropout(F.relu(self.w_1(x))))
536
+ return self.w_2(self.dropout(self.activate(self.w_1(x))))